| import os |
| import sys |
| import random |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| from argparse import ArgumentParser |
|
|
| from core.registry import register_method |
| from core.base_method import BaseMethod |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../hac_plus_official'))) |
| from utils.loss_utils import l1_loss, ssim |
| from gaussian_renderer import prefilter_voxel, render as native_render, generate_neural_gaussians |
| from scene import Scene, GaussianModel |
| from arguments import ModelParams, PipelineParams, OptimizationParams |
| from utils.encodings import get_binary_vxl_size |
|
|
@register_method("hac_plus")
class HACPlusWrapper(BaseMethod):
    """Adapter exposing the official HAC++ code (``hac_plus_official``) through
    the ``BaseMethod`` interface, registered under the name ``"hac_plus"``.

    The wrapper owns a HAC++ ``GaussianModel`` and ``Scene`` built from the
    official argument groups, performs one optimisation step per
    ``train_iteration`` call, and exposes diagnostics: gradient decoupling
    between the L1 term and the remaining ("parasitic") loss terms, scale
    anisotropy statistics, and a sampled opacity-weighted density field.
    """

    # Scene names that receive resolution 4 in the argv-based override;
    # any other detected scene receives resolution 2.
    _OUTDOOR_360_SCENES = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})

    def __init__(self, dataset_config, hyperparams):
        """Build the HAC++ model, scene, background and training-view stack.

        Args:
            dataset_config: mapping with required keys ``"source_path"`` and
                ``"model_path"`` and optional ``"resolution"`` (default 1).
            hyperparams: mapping (``None`` is treated as empty) with optional
                keys ``"track_decoupling"`` (default False), ``"n_features"``
                (4), ``"log2"`` (13), ``"log2_2D"`` (15) and ``"lmbda"`` (0.001).
        """
        hyperparams = hyperparams or {}

        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        # Parse an empty argv so the official defaults are used without
        # consuming the host program's command line.
        self.args = self.parser.parse_args([])

        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        # Hash-grid and rate-loss settings; the defaults are the values that
        # were previously hard-coded here.
        self.args.n_features = hyperparams.get("n_features", 4)
        self.args.log2 = hyperparams.get("log2", 13)
        self.args.log2_2D = hyperparams.get("log2_2D", 15)
        self.args.lmbda = hyperparams.get("lmbda", 0.001)

        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)

        # A transforms_train.json in the source directory marks a synthetic NeRF dataset.
        is_synthetic_nerf = os.path.exists(os.path.join(self.args.source_path, "transforms_train.json"))

        self.gaussians = GaussianModel(
            self.dataset.feat_dim,
            self.dataset.n_offsets,
            self.dataset.voxel_size,
            self.dataset.update_depth,
            self.dataset.update_init_factor,
            self.dataset.update_hierachy_factor,
            self.dataset.use_feat_bank,
            n_features_per_level=self.args.n_features,
            log2_hashmap_size=self.args.log2,
            log2_hashmap_size_2D=self.args.log2_2D,
            is_synthetic_nerf=is_synthetic_nerf,
        )

        # NOTE(review): this override reads the host process's sys.argv and,
        # when it applies, replaces the resolution taken from dataset_config.
        scene_name, explicit_res, res = self._resolution_override_from_argv(sys.argv)
        if res is not None:
            self.dataset.resolution = res
            print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene_name, explicit_res, res, __file__))

        self.scene = Scene(self.dataset, self.gaussians)
        self.gaussians.update_anchor_bound()
        self.gaussians.training_setup(self.opt)

        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.viewpoint_stack = self.scene.getTrainCameras().copy()
        # Neural Gaussian count = anchors * offsets per anchor; used for delta_N.
        self.last_n_gaussians = self.gaussians.get_anchor.shape[0] * self.gaussians.n_offsets

    @classmethod
    def _resolution_override_from_argv(cls, argv):
        """Derive a resolution override from ``--scene``, ``--source_path`` and ``--resolution``.

        Only the space-separated ``--flag value`` form is recognised. When a flag
        repeats, the last occurrence wins; ``--scene`` and ``--source_path``
        (its last path component) share a single scene slot. A positive integer
        ``--resolution`` takes priority; otherwise a detected scene maps to 4 if
        it is in ``_OUTDOOR_360_SCENES`` and to 2 otherwise.

        Returns:
            ``(scene, explicit_res, res)``; ``res`` is ``None`` when no override applies.
        """
        scene, explicit_res = None, None
        for flag, value in zip(argv[:-1], argv[1:]):
            if flag == "--scene":
                scene = value
            elif flag == "--source_path":
                scene = os.path.basename(os.path.normpath(value))
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    pass  # non-integer value: keep any earlier explicit resolution

        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene is not None:
            res = 4 if scene in cls._OUTDOOR_360_SCENES else 2
        else:
            res = None
        return scene, explicit_res, res

    def train_iteration(self, step):
        """Run one HAC++ optimisation step on a randomly drawn training view.

        Returns:
            ``(metrics, {})``. Metric keys are kept for compatibility:
            ``loss_l1`` is the weighted L1 term and ``loss_ssim`` is the whole
            parasitic loss (D-SSIM + scaling regulariser + rate term when
            ``bit_per_param`` is present), not only the SSIM part.
        """
        self.gaussians.update_learning_rate(step)
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        # Sample views without replacement until the stack is exhausted.
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        voxel_visible_mask = prefilter_voxel(viewpoint_cam, self.gaussians, self.pipe, self.background)
        retain_grad = 0 <= step < self.opt.update_until

        render_pkg = native_render(viewpoint_cam, self.gaussians, self.pipe, self.background,
                                   visible_mask=voxel_visible_mask, retain_grad=retain_grad, step=step)
        image = render_pkg["render"]
        gt_image = viewpoint_cam.original_image.cuda()

        Ll1 = l1_loss(image, gt_image)
        ssim_value = ssim(image, gt_image)
        scaling_reg = render_pkg["scaling"].prod(dim=1).mean()

        loss_target = (1.0 - self.opt.lambda_dssim) * Ll1
        loss_parasitic = self.opt.lambda_dssim * (1.0 - ssim_value) + 0.01 * scaling_reg

        bit_per_param = render_pkg.get("bit_per_param", None)
        bit_hash_grid_val = 0.0
        if bit_per_param is not None:
            # Rate term: per-parameter entropy bits plus hash-grid bits amortised
            # over every anchor parameter (features, 6 extra values, 3 per offset).
            _, bit_hash_grid, _, _ = get_binary_vxl_size((self.gaussians.get_encoding_params() + 1) / 2)
            bit_hash_grid_val = float(bit_hash_grid)
            denom = self.gaussians._anchor.shape[0] * (self.gaussians.feat_dim + 6 + 3 * self.gaussians.n_offsets)
            loss_parasitic = loss_parasitic + self.args.lmbda * (bit_per_param + bit_hash_grid / denom)

        loss = loss_target + loss_parasitic

        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio = self._backward_with_decoupling(loss_target, loss_parasitic)
        else:
            loss.backward()

        with torch.no_grad():
            if self.opt.start_stat < step < self.opt.update_until:
                self.gaussians.training_statis(render_pkg["viewspace_points"], render_pkg["neural_opacity"],
                                               render_pkg["visibility_filter"], render_pkg["selection_mask"],
                                               voxel_visible_mask)
                # Anchor adjustment is skipped for steps 3000-3999.
                if (step not in range(3000, 4000)
                        and step > self.opt.update_from
                        and step % self.opt.update_interval == 0):
                    self.gaussians.adjust_anchor(check_interval=self.opt.update_interval,
                                                 success_threshold=self.opt.success_threshold,
                                                 grad_threshold=self.opt.densify_grad_threshold,
                                                 min_opacity=self.opt.min_opacity)
            elif step == self.opt.update_until:
                # Densification is over: release its accumulators.
                for name in ("opacity_accum", "offset_gradient_accum", "offset_denom"):
                    if hasattr(self.gaussians, name):
                        delattr(self.gaussians, name)
                torch.cuda.empty_cache()

            if step < self.opt.iterations:
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

        num_gaussians = self.gaussians.get_anchor.shape[0] * self.gaussians.n_offsets
        metrics = {
            "loss": float(loss), "loss_l1": float(loss_target), "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians), "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim), "parasitic_ratio": float(parasitic_ratio),
            "bit_per_param": float(bit_per_param) if bit_per_param is not None else 0.0,
            "bit_hash_grid": float(bit_hash_grid_val),
        }
        self.last_n_gaussians = num_gaussians

        return metrics, {}

    def _backward_with_decoupling(self, loss_target, loss_parasitic):
        """Backpropagate ``loss_target + loss_parasitic`` and measure how the two
        terms' anchor gradients interact.

        Two backward passes run without clearing gradients in between, so every
        ``.grad`` ends up holding exactly the gradient of the summed loss.
        ``optimizer.zero_grad`` only clears optimizer-owned parameters. A gradient
        retained on any other tensor would therefore accumulate over every extra
        backward pass. Presumably ``render_pkg["viewspace_points"]`` is such a
        tensor: it is retained when ``retain_grad`` is set and consumed by
        ``training_statis``. The parasitic anchor gradient is recovered by
        subtraction.

        Returns:
            ``(grad_cos_sim, parasitic_ratio)`` computed on Adam-preconditioned
            anchor updates ``g / (sqrt(exp_avg_sq) + 1e-8)``; ``(0.0, 0.0)``
            when no anchor row has both updates non-zero.
        """
        anchor = self.gaussians._anchor
        self.gaussians.optimizer.zero_grad(set_to_none=True)

        loss_target.backward(retain_graph=True)
        grad_t = anchor.grad.clone() if anchor.grad is not None else torch.zeros_like(anchor)

        # Accumulates on top of the target gradients: afterwards .grad == d(loss)/d(param).
        loss_parasitic.backward()
        grad_total = anchor.grad if anchor.grad is not None else torch.zeros_like(anchor)
        grad_p = grad_total - grad_t

        state = self.gaussians.optimizer.state.get(anchor, {})
        v = state.get("exp_avg_sq", torch.ones_like(grad_t) * 1e-8)
        # The learning rate is deliberately left out: a positive scalar cancels in
        # both the cosine and the norm ratio, while multiplying by an lr of 0
        # would zero both metrics.
        precond = 1.0 / (torch.sqrt(v) + 1e-8)
        u_t = precond * grad_t
        u_p = precond * grad_p

        norm_t = torch.norm(u_t, dim=1)
        norm_p = torch.norm(u_p, dim=1)
        valid_mask = (norm_t > 0) & (norm_p > 0)
        if not valid_mask.any():
            return 0.0, 0.0

        grad_cos_sim = float(F.cosine_similarity(u_t[valid_mask], u_p[valid_mask], dim=1).mean())
        parasitic_ratio = float(norm_p.mean() / (norm_t.mean() + 1e-7))
        return grad_cos_sim, parasitic_ratio

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns:
            dict with ``"image"`` and ``"depth"`` (``None`` if the renderer
            produced no depth).
        """
        with torch.no_grad():
            voxel_visible_mask = prefilter_voxel(camera, self.gaussians, self.pipe, self.background)
            render_pkg = native_render(camera, self.gaussians, self.pipe, self.background, visible_mask=voxel_visible_mask)
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save a checkpoint through ``Scene.save(step)``.

        NOTE(review): ``save_dir`` is unused; the output location is decided by
        ``Scene`` (presumably from ``model_path``) — confirm.
        """
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Rebuild the scene from saved ``iteration`` and put the model in eval mode.

        NOTE(review): ``model_path`` is unused; ``Scene`` receives ``self.dataset``,
        whose ``model_path`` was set in ``__init__``.
        """
        self.scene = Scene(self.dataset, self.gaussians, load_iteration=iteration, shuffle=False)
        self.gaussians.eval()

    def get_spatial_centers(self):
        """Return the anchor positions (``gaussians.get_anchor``)."""
        return self.gaussians.get_anchor

    def _reference_camera(self, cameras=None):
        """Camera used to decode neural Gaussians.

        Uses the first of ``cameras`` if any are given, else the first remaining
        training view, else the first training camera. The last fallback is
        needed because ``viewpoint_stack`` is empty right after an epoch's final
        view was popped.
        """
        if cameras is not None and len(cameras) > 0:
            return cameras[0]
        if self.viewpoint_stack:
            return self.viewpoint_stack[0]
        return self.scene.getTrainCameras()[0]

    @staticmethod
    def _planar_scales(scaling):
        """Return an (N, 2) tensor of the first two per-Gaussian scales.

        A 1-D input or a single scale column is duplicated into both columns.
        """
        if scaling.dim() == 1:
            scaling = scaling.unsqueeze(-1)
        if scaling.shape[1] >= 2:
            return scaling[:, :2]
        return scaling[:, :1].expand(-1, 2)

    def compute_physical_metrics(self, cameras=None):
        """Summarise the shape of the decoded neural Gaussians.

        Metrics:
            ``gamma_median`` / ``gamma_90th_percentile``: max/min ratio of the two
            planar scales. ``scale_mean``: mean planar scale. ``alpha_mean``: mean
            opacity. ``billboard_bias_ratio`` (only when ``cameras`` is non-empty):
            fraction of Gaussians whose rotated local z-axis has |cos| > 0.9 with
            some camera's view axis (third column of ``world_view_transform``).
        """
        metrics = {}
        with torch.no_grad():
            cam = self._reference_camera(cameras)
            xyz, color, opacity, scaling, rot, _ = generate_neural_gaussians(
                cam, self.gaussians, visible_mask=None, is_training=False, step=30000)

            scales_2d = self._planar_scales(scaling)
            max_s, _ = torch.max(scales_2d, dim=1)
            min_s, _ = torch.min(scales_2d, dim=1)
            gamma = max_s / (min_s + 1e-7)

            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(opacity))

            if cameras is not None and len(cameras) > 0:
                # Third column of the rotation matrix of the (w, x, y, z) quaternion.
                w, x, y, z = F.normalize(rot, dim=1).unbind(dim=-1)
                normals = F.normalize(
                    torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1), dim=1)

                view_dirs = torch.stack(
                    [torch.as_tensor(c.world_view_transform[:3, 2], dtype=torch.float32) for c in cameras])
                view_dirs = F.normalize(view_dirs.to(normals.device), dim=1)

                max_cos, _ = torch.max(torch.abs(normals @ view_dirs.T), dim=1)
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())

        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an opacity-weighted isotropic Gaussian density at each query point.

        Each decoded Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)``,
        with ``sigma`` its largest planar scale and ``d`` the Euclidean distance
        to its centre.

        Args:
            query_points: (V, 3) points; moved to the Gaussians' device and dtype.
            cameras: optional cameras; the first one decodes the neural Gaussians.

        Returns:
            (V,) tensor of densities.
        """
        with torch.no_grad():
            cam = self._reference_camera(cameras)
            xyz, _, opacity, scaling, _, _ = generate_neural_gaussians(
                cam, self.gaussians, visible_mask=None, is_training=False, step=30000)

            opacities = opacity.reshape(-1)
            sigma_sq = self._planar_scales(scaling).max(dim=1)[0].pow(2)

            points = query_points.to(device=xyz.device, dtype=xyz.dtype)
            n_queries = points.shape[0]
            densities = torch.zeros(n_queries, device=xyz.device)

            # Bound each (chunk x n_gaussians) distance matrix to ~30M entries.
            chunk_size = max(1, 30_000_000 // (xyz.shape[0] + 1))
            for start in range(0, n_queries, chunk_size):
                end = min(start + chunk_size, n_queries)
                dist_sq = torch.cdist(points[start:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[start:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities
|
|