# SplatAtlas / methods/wrapper_hac_plus.py
# Uploaded by KCBtheone ("Upload SplatAtlas benchmark pipeline code", commit 23e73f9, 12.4 kB).
import os
import sys
import random
import torch
import numpy as np
import torch.nn.functional as F
from argparse import ArgumentParser
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../hac_plus_official')))
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import prefilter_voxel, render as native_render, generate_neural_gaussians
from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.encodings import get_binary_vxl_size
@register_method("hac_plus")
class HACPlusWrapper(BaseMethod):
    """Benchmark adapter around the official HAC++ training/rendering code.

    Builds the HAC ``GaussianModel`` and ``Scene`` from a dataset config and
    exposes per-step training, rendering, checkpointing and a few geometric
    diagnostics through the ``BaseMethod`` interface.
    """

    # Outdoor 360-scene names: the command-line resolution override trains
    # these at resolution 4 and every other named scene at resolution 2.
    _OUTDOOR_360_SCENES = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})

    def __init__(self, dataset_config, hyperparams):
        """Build HAC++ parameters, model and scene from the benchmark config.

        Args:
            dataset_config: mapping with ``source_path`` and ``model_path``
                and an optional ``resolution`` (default 1).
            hyperparams: mapping; ``track_decoupling`` (default False)
                enables the gradient-decoupling diagnostics in
                ``train_iteration`` and ``lmbda`` (default 0.001) weights the
                rate term of the loss.
        """
        # Start from the official argparse defaults, then override the fields
        # this benchmark controls.
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        self.args = self.parser.parse_args([])
        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)
        # Hash-grid configuration passed to GaussianModel below; lmbda weights
        # the rate term in train_iteration.
        self.args.n_features = 4
        self.args.log2 = 13
        self.args.log2_2D = 15
        self.args.lmbda = hyperparams.get("lmbda", 0.001)
        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)
        # The model is told the scene is synthetic-NeRF when a
        # transforms_train.json exists at the source root.
        is_synthetic_nerf = os.path.exists(os.path.join(self.args.source_path, "transforms_train.json"))
        self.gaussians = GaussianModel(
            self.dataset.feat_dim,
            self.dataset.n_offsets,
            self.dataset.voxel_size,
            self.dataset.update_depth,
            self.dataset.update_init_factor,
            self.dataset.update_hierachy_factor,
            self.dataset.use_feat_bank,
            n_features_per_level=self.args.n_features,
            log2_hashmap_size=self.args.log2,
            log2_hashmap_size_2D=self.args.log2_2D,
            is_synthetic_nerf=is_synthetic_nerf,
        )
        # Must run before Scene() so the chosen resolution is used for loading.
        self._apply_resolution_override()
        self.scene = Scene(self.dataset, self.gaussians)
        self.gaussians.update_anchor_bound()
        self.gaussians.training_setup(self.opt)
        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.viewpoint_stack = self.scene.getTrainCameras().copy()
        # Previous Gaussian count, used to report delta_N per iteration.
        self.last_n_gaussians = self.gaussians.get_anchor.shape[0] * self.gaussians.n_offsets

    @classmethod
    def _parse_resolution_override(cls, argv):
        """Derive the training-resolution override from command-line tokens.

        Recognizes ``--scene``, ``--source_path`` (scene = its last path
        component) and ``--resolution``, in both ``--flag value`` and
        ``--flag=value`` forms; a later occurrence wins and an unparseable
        ``--resolution`` value is ignored. A positive ``--resolution`` takes
        precedence; otherwise a known scene name maps to 4 for the outdoor 360
        scenes and 2 for anything else.

        Returns:
            ``(scene, explicit_res, res)``; ``res`` is None when neither a
            positive ``--resolution`` nor a scene name was found.
        """
        scene, explicit_res = None, None
        tokens = list(argv)
        for i, token in enumerate(tokens):
            if token.startswith("--") and "=" in token:
                flag, value = token.split("=", 1)
            elif i + 1 < len(tokens):
                flag, value = token, tokens[i + 1]
            else:
                continue
            if flag == "--scene":
                scene = value
            elif flag == "--source_path":
                scene = os.path.basename(os.path.normpath(value))
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    pass  # leave any previously parsed value in place

        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene is not None:
            res = 4 if scene in cls._OUTDOOR_360_SCENES else 2
        else:
            res = None
        return scene, explicit_res, res

    def _apply_resolution_override(self):
        """Override ``self.dataset.resolution`` from ``sys.argv`` (best effort, logged)."""
        scene, explicit_res, res = self._parse_resolution_override(sys.argv)
        try:
            if res is not None:
                self.dataset.resolution = res
            print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene, explicit_res, res, __file__))
        except Exception as err:
            print("[res-fix] FAILED:", err)

    def _reference_camera(self, cameras=None):
        """Return the camera passed to ``generate_neural_gaussians`` for diagnostics.

        Prefers ``cameras[0]``, then the head of the training viewpoint stack,
        and finally the first training camera. The stack is empty right after
        an epoch's last camera is popped, because it is only refilled at the
        start of the next ``train_iteration``.
        """
        if cameras is not None and len(cameras) > 0:
            return cameras[0]
        if self.viewpoint_stack:
            return self.viewpoint_stack[0]
        return self.scene.getTrainCameras()[0]

    def _measure_decoupling(self, loss_target, loss_parasitic):
        """Return ``(grad_cos_sim, parasitic_ratio)`` for the anchor parameters.

        Back-propagates each loss term separately (keeping the graph alive) and
        compares the per-anchor updates ``lr / (sqrt(exp_avg_sq) + 1e-8) * grad``.
        ``exp_avg_sq`` is the optimizer's second-moment state (the key Adam
        uses); before the first step a tiny constant stands in for it.
        Gradients are left populated; the caller must clear them before the
        real backward pass.
        """
        optimizer = self.gaussians.optimizer
        anchor = self.gaussians._anchor

        def _anchor_grad(term):
            optimizer.zero_grad(set_to_none=True)
            term.backward(retain_graph=True)
            return anchor.grad.clone() if anchor.grad is not None else torch.zeros_like(anchor)

        grad_t = _anchor_grad(loss_target)
        grad_p = _anchor_grad(loss_parasitic)

        state = optimizer.state.get(anchor, {})
        v = state.get("exp_avg_sq", torch.ones_like(grad_t) * 1e-8)
        lr = next((pg["lr"] for pg in optimizer.param_groups if pg.get("name") == "anchor"), 0.0)
        precond = lr / (torch.sqrt(v) + 1e-8)
        u_t = precond * grad_t
        u_p = precond * grad_p
        norm_t = torch.norm(u_t, dim=1)
        norm_p = torch.norm(u_p, dim=1)

        grad_cos_sim = 0.0
        # Cosine similarity is only meaningful where both updates are non-zero.
        valid_mask = (norm_t > 0) & (norm_p > 0)
        if valid_mask.any():
            grad_cos_sim = float(F.cosine_similarity(u_t[valid_mask], u_p[valid_mask], dim=1).mean())
        parasitic_ratio = float(norm_p.mean() / (norm_t.mean() + 1e-7))
        return grad_cos_sim, parasitic_ratio

    def _densify_and_step(self, step, render_pkg, voxel_visible_mask):
        """Update anchor statistics/densification and take the optimizer step (no grad)."""
        opt = self.opt
        with torch.no_grad():
            if opt.start_stat < step < opt.update_until:
                self.gaussians.training_statis(
                    render_pkg["viewspace_points"],
                    render_pkg["neural_opacity"],
                    render_pkg["visibility_filter"],
                    render_pkg["selection_mask"],
                    voxel_visible_mask,
                )
                # Anchor adjustment is paused for steps 3000-3999 (hard-coded;
                # presumably lets the model settle into quantization -- confirm).
                if step not in range(3000, 4000) and step > opt.update_from and step % opt.update_interval == 0:
                    self.gaussians.adjust_anchor(
                        check_interval=opt.update_interval,
                        success_threshold=opt.success_threshold,
                        grad_threshold=opt.densify_grad_threshold,
                        min_opacity=opt.min_opacity,
                    )
            elif step == opt.update_until:
                # Densification is over: free the statistics buffers.
                for attr in ("opacity_accum", "offset_gradient_accum", "offset_denom"):
                    if hasattr(self.gaussians, attr):
                        delattr(self.gaussians, attr)
                torch.cuda.empty_cache()
            if step < opt.iterations:
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

    def train_iteration(self, step):
        """Run one HAC++ optimization step on a randomly drawn training view.

        The loss is split into ``loss_target`` (the weighted L1 term) and
        ``loss_parasitic`` (D-SSIM, scale regularization and, when the renderer
        reports ``bit_per_param``, the rate term). Returns ``(metrics, {})``.
        Note that ``metrics["loss_ssim"]`` holds the whole parasitic term,
        not only the SSIM part.
        """
        self.gaussians.update_learning_rate(step)
        # Views are drawn without replacement; the stack is refilled per epoch.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))
        voxel_visible_mask = prefilter_voxel(viewpoint_cam, self.gaussians, self.pipe, self.background)
        # Request retained screen-space gradients only before update_until,
        # the same window in which _densify_and_step collects statistics.
        retain_grad = 0 <= step < self.opt.update_until
        render_pkg = native_render(
            viewpoint_cam, self.gaussians, self.pipe, self.background,
            visible_mask=voxel_visible_mask, retain_grad=retain_grad, step=step,
        )
        image = render_pkg["render"]
        gt_image = viewpoint_cam.original_image.cuda()
        Ll1 = l1_loss(image, gt_image)
        ssim_value = ssim(image, gt_image)
        scaling_reg = render_pkg["scaling"].prod(dim=1).mean()
        lambda_dssim = self.opt.lambda_dssim
        loss_target = (1.0 - lambda_dssim) * Ll1
        loss_parasitic = lambda_dssim * (1.0 - ssim_value) + 0.01 * scaling_reg

        bit_per_param = render_pkg.get("bit_per_param", None)
        bit_hash_grid_val = 0.0
        if bit_per_param is not None:
            # Rate term: bits per parameter reported by the renderer plus the
            # hash-grid size from get_binary_vxl_size spread over the anchor
            # parameter count (denom).
            _, bit_hash_grid, _, _ = get_binary_vxl_size((self.gaussians.get_encoding_params() + 1) / 2)
            bit_hash_grid_val = float(bit_hash_grid)
            denom = self.gaussians._anchor.shape[0] * (self.gaussians.feat_dim + 6 + 3 * self.gaussians.n_offsets)
            loss_parasitic = loss_parasitic + self.args.lmbda * (bit_per_param + bit_hash_grid / denom)
        loss = loss_target + loss_parasitic

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio = self._measure_decoupling(loss_target, loss_parasitic)
            self.gaussians.optimizer.zero_grad(set_to_none=True)
            # The diagnostic passes also accumulated grad(loss_target) +
            # grad(loss_parasitic) into the screen-space points, which are not
            # optimizer parameters and so survive zero_grad. Clear them so
            # training_statis sees one gradient, not twice the real one.
            render_pkg["viewspace_points"].grad = None
        loss.backward()

        self._densify_and_step(step, render_pkg, voxel_visible_mask)

        num_gaussians = self.gaussians.get_anchor.shape[0] * self.gaussians.n_offsets
        metrics = {
            "loss": float(loss), "loss_l1": float(loss_target), "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians), "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim), "parasitic_ratio": float(parasitic_ratio),
            "bit_per_param": float(bit_per_param) if bit_per_param is not None else 0.0,
            "bit_hash_grid": float(bit_hash_grid_val)
        }
        self.last_n_gaussians = num_gaussians
        return metrics, {}

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns ``{"image": ..., "depth": ...}``; depth is None when the
        renderer does not provide one.
        """
        with torch.no_grad():
            voxel_visible_mask = prefilter_voxel(camera, self.gaussians, self.pipe, self.background)
            render_pkg = native_render(camera, self.gaussians, self.pipe, self.background, visible_mask=voxel_visible_mask)
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save via ``Scene.save(step)``; ``save_dir`` is not used by this wrapper."""
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Reload the scene at ``iteration`` and switch the model to eval mode.

        ``model_path`` is not used; the scene is rebuilt from ``self.dataset``.
        """
        self.scene = Scene(self.dataset, self.gaussians, load_iteration=iteration, shuffle=False)
        self.gaussians.eval()

    def get_spatial_centers(self):
        """Return the anchor positions (``gaussians.get_anchor``)."""
        return self.gaussians.get_anchor

    def compute_physical_metrics(self, cameras=None):
        """Summarize shape/opacity statistics of the decoded neural Gaussians.

        Decodes all anchors (``visible_mask=None``) from a reference camera
        and reports:

        - the anisotropy ratio ``gamma`` of the first two scale axes
          (median and 90th percentile);
        - the mean of those scales and the mean opacity;
        - when ``cameras`` is non-empty, ``billboard_bias_ratio``: the
          fraction of Gaussians whose rotation z-axis has ``|cos| > 0.9``
          with column 2 of some camera's ``world_view_transform``.

        Assumes ``rot`` quaternions are ordered (w, x, y, z) -- confirm
        against generate_neural_gaussians.
        """
        metrics = {}
        with torch.no_grad():
            cam = self._reference_camera(cameras)
            # step=30000 is a hard-coded "late training" step for decoding.
            _xyz, _color, opacity, scaling, rot, _ = generate_neural_gaussians(
                cam, self.gaussians, visible_mask=None, is_training=False, step=30000
            )
            if scaling.dim() > 1 and scaling.shape[1] >= 2:
                scales_2d = scaling[:, :2]
            else:
                # Single-scale Gaussians (shape (N,) or (N, 1)) are isotropic.
                scales_2d = scaling.reshape(-1, 1).expand(-1, 2)
            max_s = scales_2d.max(dim=1).values
            min_s = scales_2d.min(dim=1).values
            gamma = max_s / (min_s + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(opacity))
            if cameras is not None and len(cameras) > 0:
                view_dirs = F.normalize(
                    torch.tensor(
                        [c.world_view_transform[:3, 2].tolist() for c in cameras],
                        dtype=torch.float32, device="cuda",
                    ),
                    dim=1,
                )
                w, x, y, z = F.normalize(rot, dim=1).unbind(dim=-1)
                # Third column of the rotation matrix built from the quaternion.
                normals = F.normalize(
                    torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1),
                    dim=1,
                )
                max_cos = torch.abs(normals @ view_dirs.T).max(dim=1).values
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic Gaussian density proxy at ``query_points``.

        ``density(q) = sum_i opacity_i * exp(-0.5 * |q - mu_i|^2 / sigma_i^2)``
        with ``sigma_i`` the larger of the first two scale axes. Queries are
        processed in chunks so each (chunk, N) distance matrix stays around
        30M entries.
        """
        with torch.no_grad():
            num_queries = query_points.shape[0]
            densities = torch.zeros(num_queries, device="cuda")
            cam = self._reference_camera(cameras)
            xyz, _color, opacity, scaling, _rot, _ = generate_neural_gaussians(
                cam, self.gaussians, visible_mask=None, is_training=False, step=30000
            )
            opacities = opacity.reshape(-1)
            if scaling.dim() > 1 and scaling.shape[1] >= 2:
                sigma_sq = scaling[:, :2].max(dim=1).values.pow(2)
            else:
                sigma_sq = scaling.reshape(-1).pow(2)
            # Loop-invariant (1, N) rows broadcast against each (chunk, N) block.
            sigma_row = sigma_sq.unsqueeze(0) + 1e-7
            opacity_row = opacities.unsqueeze(0)
            n_gaussians = xyz.shape[0]
            chunk_size = max(1, 30_000_000 // (n_gaussians + 1))
            for start in range(0, num_queries, chunk_size):
                end = min(start + chunk_size, num_queries)
                dist_sq = torch.cdist(query_points[start:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / sigma_row)
                densities[start:end] = torch.sum(weights * opacity_row, dim=1)
        return densities