| import argparse | |
| import torch | |
| from model import GaussianSplatting2D | |
| from utils.misc_utils import load_cfg | |
| def get_gaussian_cfg(args): | |
| gaussian_cfg = f"num-{args.num_gaussians:d}" | |
| if args.disable_inverse_scale: | |
| gaussian_cfg += f"_scale-{args.init_scale:.1f}" | |
| else: | |
| gaussian_cfg += f"_inv-scale-{args.init_scale:.1f}" | |
| if not args.quantize: | |
| args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits = 32, 32, 32, 32 | |
| min_bits = min(args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits) | |
| max_bits = max(args.pos_bits, args.scale_bits, args.rot_bits, args.feat_bits) | |
| if min_bits < 4 or max_bits > 32: | |
| raise ValueError( | |
| f"Bit precision must be between 4 and 32 but got: {args.pos_bits:d}, {args.scale_bits:d}, {args.rot_bits:d}, {args.feat_bits:d}" | |
| ) | |
| gaussian_cfg += f"_bits-{args.pos_bits:d}-{args.scale_bits:d}-{args.rot_bits:d}-{args.feat_bits:d}" | |
| if not args.disable_topk_norm: | |
| gaussian_cfg += f"_top-{args.topk:d}" | |
| gaussian_cfg += f"_{args.init_mode[0]}-{args.init_random_ratio:.1f}" | |
| return gaussian_cfg | |
| def get_log_dir(args): | |
| gaussian_cfg = get_gaussian_cfg(args) | |
| loss_cfg = f"l1-{args.l1_loss_ratio:.1f}_l2-{args.l2_loss_ratio:.1f}_ssim-{args.ssim_loss_ratio:.1f}" | |
| folder = f"{gaussian_cfg}_{loss_cfg}" | |
| if args.downsample: | |
| folder += f"_ds-{args.downsample_ratio:.1f}" | |
| if not args.disable_lr_schedule: | |
| folder += f"_decay-{args.max_decay_times:d}-{args.decay_ratio:.1f}" | |
| if not args.disable_prog_optim: | |
| folder += "_prog" | |
| return f"{args.log_root}/{args.exp_name}/{folder}" | |
| def main(args): | |
| args.log_dir = get_log_dir(args) | |
| ImageGS = GaussianSplatting2D(args) | |
| if args.eval: | |
| ImageGS.render(render_height=args.render_height) | |
| else: | |
| ImageGS.optimize() | |
| if __name__ == "__main__": | |
| torch.hub.set_dir("models/torch") | |
| parser = argparse.ArgumentParser() | |
| parser = load_cfg(cfg_path="cfgs/default.yaml", parser=parser) | |
| arguments = parser.parse_args() | |
| main(arguments) | |