Image-GS / main.py
Julien Blanchon
Deploy optimized Image-GS with dynamic dependencies
d62394f
raw
history blame
2.08 kB
import argparse
import torch
from model import GaussianSplatting2D
from utils.misc_utils import load_cfg
def get_gaussian_cfg(args):
    """Build the Gaussian-related portion of the experiment folder name.

    The returned string concatenates, in order: the number of Gaussians, the
    initial scale (tagged ``scale`` or ``inv-scale``), the four bit widths,
    the optional top-k setting, and the first letter of the init mode with
    the random-init ratio.

    Note:
        This function mutates ``args``: when ``args.quantize`` is falsy, the
        ``pos_bits``, ``scale_bits``, ``rot_bits`` and ``feat_bits``
        attributes are all overwritten with 32 before the name is built.

    Args:
        args: Namespace-like object providing ``num_gaussians``,
            ``disable_inverse_scale``, ``init_scale``, ``quantize``,
            ``pos_bits``, ``scale_bits``, ``rot_bits``, ``feat_bits``,
            ``disable_topk_norm``, ``topk``, ``init_mode`` and
            ``init_random_ratio``.

    Returns:
        The configuration string, e.g.
        ``"num-10000_inv-scale-5.0_bits-32-32-32-32_top-10_g-0.3"``.

    Raises:
        ValueError: If any bit width is below 4 or above 32.
    """
    gaussian_cfg = f"num-{args.num_gaussians:d}"

    if args.disable_inverse_scale:
        gaussian_cfg += f"_scale-{args.init_scale:.1f}"
    else:
        gaussian_cfg += f"_inv-scale-{args.init_scale:.1f}"

    # Without quantization every attribute is forced to 32 bits; the override
    # is written back onto args so later readers of args see the same values.
    if not args.quantize:
        args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits = 32, 32, 32, 32

    # Validate after the override so only user-supplied quantized widths can fail.
    bit_widths = (args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits)
    if min(bit_widths) < 4 or max(bit_widths) > 32:
        raise ValueError(
            f"Bit precision must be between 4 and 32 but got: {args.pos_bits:d}, {args.scale_bits:d}, {args.rot_bits:d}, {args.feat_bits:d}"
        )
    gaussian_cfg += f"_bits-{args.pos_bits:d}-{args.scale_bits:d}-{args.rot_bits:d}-{args.feat_bits:d}"

    if not args.disable_topk_norm:
        gaussian_cfg += f"_top-{args.topk:d}"

    # Only the first character of the init mode is used as its tag.
    gaussian_cfg += f"_{args.init_mode[0]}-{args.init_random_ratio:.1f}"
    return gaussian_cfg
def get_log_dir(args):
    """Compose the log directory path for the current experiment settings.

    The final folder name is the Gaussian configuration string from
    ``get_gaussian_cfg`` followed by the loss weights and, when enabled, the
    downsample ratio, the learning-rate decay settings, and a ``_prog``
    marker for progressive optimization.

    Note:
        ``get_gaussian_cfg`` is called on ``args`` and may overwrite its
        bit-width attributes.

    Args:
        args: Namespace-like object providing the fields read by
            ``get_gaussian_cfg`` plus ``l1_loss_ratio``, ``l2_loss_ratio``,
            ``ssim_loss_ratio``, ``downsample``, ``downsample_ratio``,
            ``disable_lr_schedule``, ``max_decay_times``, ``decay_ratio``,
            ``disable_prog_optim``, ``log_root`` and ``exp_name``.

    Returns:
        A string of the form ``"<log_root>/<exp_name>/<folder>"``.
    """
    gaussian_cfg = get_gaussian_cfg(args)
    loss_cfg = f"l1-{args.l1_loss_ratio:.1f}_l2-{args.l2_loss_ratio:.1f}_ssim-{args.ssim_loss_ratio:.1f}"
    folder = f"{gaussian_cfg}_{loss_cfg}"

    # Each optional suffix is independent of the others.
    if args.downsample:
        folder += f"_ds-{args.downsample_ratio:.1f}"
    if not args.disable_lr_schedule:
        folder += f"_decay-{args.max_decay_times:d}-{args.decay_ratio:.1f}"
    if not args.disable_prog_optim:
        folder += "_prog"

    return f"{args.log_root}/{args.exp_name}/{folder}"
def main(args):
    """Run Image-GS in evaluation or optimization mode.

    The derived log directory is stored on ``args.log_dir`` before the model
    is constructed, so the model receives it as part of ``args``.

    Args:
        args: Parsed arguments. ``args.eval`` selects rendering (using
            ``args.render_height``) over optimization; the remaining fields
            are consumed by ``get_log_dir`` and ``GaussianSplatting2D``.
    """
    args.log_dir = get_log_dir(args)
    image_gs = GaussianSplatting2D(args)
    if args.eval:
        image_gs.render(render_height=args.render_height)
    else:
        image_gs.optimize()
if __name__ == "__main__":
    # Point torch.hub at a project-local directory for its downloads.
    torch.hub.set_dir("models/torch")
    parser = argparse.ArgumentParser()
    # load_cfg returns the parser to use from here on; presumably it registers
    # arguments based on the YAML file -- confirm in utils.misc_utils.
    parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser)
    arguments = parser.parse_args()
    main(arguments)