| import os |
| import sys |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| from omegaconf import OmegaConf |
| from hydra import initialize_config_dir, compose |
| from hydra.core.global_hydra import GlobalHydra |
|
|
| from core.registry import register_method |
| from core.base_method import BaseMethod |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../3dgrut'))) |
| from threedgrut.trainer import Trainer3DGRUT |
| from threedgrut.optimizers import SelectiveAdam |
|
|
class MockScene:
    """Stand-in scene object whose camera accessors always return empty lists."""

    def getTrainCameras(self):
        """Return a new, empty list of training cameras."""
        cameras = list()
        return cameras

    def getTestCameras(self):
        """Return a new, empty list of test cameras."""
        return list()
|
|
@register_method("3dgut")
class ThreeDGUTWrapper(BaseMethod):
    """Benchmark adapter around ``Trainer3DGRUT`` configured with ``render=3dgut``.

    The Hydra config is composed from ``base_gs`` plus benchmark overrides. The
    wrapper then steps the optimisation itself (one ``train_iteration`` call per
    step) using the trainer's model, dataloader, losses and strategy hooks.
    """

    # Hydra config directory used unless ``hyperparams["config_dir"]`` is given.
    _DEFAULT_CONFIG_DIR = "/root/autodl-tmp/3dgrut/configs"
    # Training length written into the config; a save at or after this step is
    # stored as the "last" checkpoint.
    _N_ITERATIONS = 30000

    def __init__(self, dataset_config, hyperparams):
        """Compose the config and build the trainer/model.

        Args:
            dataset_config: mapping with ``source_path`` and ``model_path``;
                ``resolution`` (default 1) is used as the COLMAP
                ``dataset.downsample_factor``.
            hyperparams: mapping; ``track_decoupling`` (default False) enables
                the periodic L1-vs-SSIM gradient analysis, and ``config_dir``
                (default ``_DEFAULT_CONFIG_DIR``) selects the Hydra config dir.
        """
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        conf = self._build_config(dataset_config, hyperparams)

        self.trainer = Trainer3DGRUT(conf)
        self.model = self.trainer.model
        self.data_iter = iter(self.trainer.train_dataloader)
        self.last_n_gaussians = self.model.num_gaussians
        self.scene = MockScene()

    def _build_config(self, dataset_config, hyperparams):
        """Return the composed, non-struct OmegaConf config for this run."""
        source_path = dataset_config["source_path"]
        # Paths containing "Synthetic_NeRF" use the nerf dataset config and
        # random initialisation; everything else uses colmap.
        is_synthetic = "Synthetic_NeRF" in str(source_path)
        config_dir = os.path.abspath(hyperparams.get("config_dir", self._DEFAULT_CONFIG_DIR))

        GlobalHydra.instance().clear()
        with initialize_config_dir(version_base=None, config_dir=config_dir):
            overrides = [
                "+render=3dgut",
                f"++dataset.path={source_path}",
                f"++path={source_path}",
                f"++out_dir={dataset_config['model_path']}",
                "++experiment_name=benchmark",
                f"++n_iterations={self._N_ITERATIONS}",
                "++loss.use_l1=True",
                "++loss.lambda_l1=0.8",
                "++loss.use_ssim=True",
                "++loss.lambda_ssim=0.2",
                "++with_gui=False",
                "++with_viser_gui=False",
            ]
            if is_synthetic:
                overrides.append(self._dataset_override("nerf"))
            else:
                overrides.append(self._dataset_override("colmap"))
                overrides.append(f"++dataset.downsample_factor={dataset_config.get('resolution', 1)}")

            conf = compose(config_name="base_gs", overrides=overrides)

        # Allow adding keys that base_gs does not declare.
        OmegaConf.set_struct(conf, False)

        if "initialization" not in conf:
            conf.initialization = OmegaConf.create({})
        if is_synthetic:
            conf.initialization.method = "random"
            conf.initialization.num_gaussians = 100000
        else:
            conf.initialization.method = "colmap"
        conf.initialization.use_observation_points = False

        # Membership tests instead of hasattr: missing keys do not reliably raise
        # AttributeError on a non-struct DictConfig across OmegaConf versions.
        if "test_split_interval" not in conf.dataset:
            conf.dataset.test_split_interval = 8
        if "eval" not in conf.dataset:
            conf.dataset.eval = True
        return conf

    @staticmethod
    def _dataset_override(option):
        """Return the ``dataset=<option>`` override form that ``base_gs`` accepts.

        Must be called inside an active Hydra ``initialize_config_dir`` context.
        The plain form is tried first; if composing with it raises, the
        ``+dataset=<option>`` (append) form is returned instead.
        """
        override = f"dataset={option}"
        try:
            compose(config_name="base_gs", overrides=[override])
        except Exception:
            return "+" + override
        return override

    def _next_batch(self):
        """Return the next training batch, restarting the dataloader when exhausted."""
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.trainer.train_dataloader)
            return next(self.data_iter)

    def _strategy_kwargs(self, step, gpu_batch):
        """Keyword arguments shared by the strategy hooks (``writer`` is None)."""
        return {
            "step": step,
            "scene_extent": self.trainer.scene_extent,
            "train_dataset": self.trainer.train_dataset,
            "batch": gpu_batch,
            "writer": None,
        }

    def train_iteration(self, step):
        """Run one training step and return ``(metrics, histograms)``.

        Args:
            step: iteration index; forwarded to the model (``frame_id``), the
                strategy hooks and ``model.scheduler_step``.

        Returns:
            metrics: dict of Python scalars: losses, Gaussian count and its change
                since the previous call, peak allocated VRAM (GB), decoupling
                diagnostics and the mean of ``hits_count`` (0 if absent).
            histograms: dict of detached per-Gaussian tensors, filled only when
                ``step % 1000 == 0`` and empty otherwise.
        """
        gpu_batch = self.trainer.train_dataset.get_gpu_batch_with_intrinsics(self._next_batch())
        outputs = self.model(gpu_batch, train=True, frame_id=step)
        batch_losses = self.trainer.get_losses(gpu_batch, outputs)

        loss_target = batch_losses.get("l1_loss", torch.tensor(0.0, device="cuda"))
        loss_parasitic = batch_losses.get("ssim_loss", torch.tensor(0.0, device="cuda"))
        if "total_loss" in batch_losses:
            loss = batch_losses["total_loss"]
        else:
            loss = loss_target + loss_parasitic

        self.trainer.strategy.pre_backward(**self._strategy_kwargs(step, gpu_batch))

        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio = self._measure_decoupling(loss_target, loss_parasitic)
            # Drop the per-term gradients so only the total loss drives the update.
            self.model.optimizer.zero_grad(set_to_none=True)
        loss.backward()

        self.trainer.strategy.post_backward(**self._strategy_kwargs(step, gpu_batch))

        # Read the optimizer after post_backward in case the strategy rebuilt it.
        optimizer = self.model.optimizer
        if "mog_visibility" in outputs and isinstance(optimizer, SelectiveAdam):
            optimizer.step(outputs["mog_visibility"])
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.model.scheduler_step(step)

        self.trainer.strategy.post_optimizer_step(**self._strategy_kwargs(step, gpu_batch))

        num_gaussians = self.model.num_gaussians
        metrics = {
            "loss": float(loss),
            "loss_l1": float(loss_target),
            "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians),
            "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim),
            "parasitic_ratio": float(parasitic_ratio),
            "hits_mean": float(outputs.get("hits_count", torch.tensor(0.0)).float().mean()),
        }
        self.last_n_gaussians = num_gaussians

        histograms = self._collect_histograms() if step % 1000 == 0 else {}
        return metrics, histograms

    def _position_grad(self, loss):
        """Back-propagate ``loss`` alone (graph retained) and return a copy of the position gradient."""
        self.model.optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        grad = self.model.positions.grad
        return grad.clone() if grad is not None else torch.zeros_like(self.model.positions)

    def _param_lr(self, param):
        """Return the learning rate of the optimizer param group holding ``param``.

        Falls back to the first group if ``param`` is not found.
        """
        groups = self.model.optimizer.param_groups
        for group in groups:
            if any(p is param for p in group["params"]):
                return group["lr"]
        return groups[0]["lr"]

    def _measure_decoupling(self, loss_target, loss_parasitic):
        """Compare the position updates induced by the L1 and SSIM loss terms.

        Each term is back-propagated separately. When the optimizer state for
        ``positions`` has ``exp_avg_sq``, each gradient is scaled by
        ``lr / (sqrt(exp_avg_sq) + 1e-8)``. Both returned values are then
        invariant to the scalar ``lr``.

        Returns:
            (cos_sim, ratio). ``cos_sim`` is the mean cosine similarity over
            Gaussians where both updates are non-zero. ``ratio`` is
            mean |u_parasitic| / mean |u_target| over all Gaussians. Returns
            (0.0, 0.0) if either loss has no autograd graph or no Gaussian has
            both updates non-zero. Gradients are left populated; the caller must
            zero them.
        """
        # A loss missing from get_losses falls back to a plain zero tensor, and
        # backward() on it would raise.
        if not (loss_target.requires_grad and loss_parasitic.requires_grad):
            return 0.0, 0.0

        positions = self.model.positions
        grad_target = self._position_grad(loss_target)
        grad_parasitic = self._position_grad(loss_parasitic)

        state = self.model.optimizer.state.get(positions, None)
        if state is not None and "exp_avg_sq" in state:
            precond = self._param_lr(positions) / (torch.sqrt(state["exp_avg_sq"]) + 1e-8)
            u_t = precond * grad_target
            u_p = precond * grad_parasitic
        else:
            u_t = grad_target
            u_p = grad_parasitic

        norm_t = torch.norm(u_t, dim=1)
        norm_p = torch.norm(u_p, dim=1)
        valid_mask = (norm_t > 0) & (norm_p > 0)
        if not valid_mask.any():
            return 0.0, 0.0
        cos_sim = float(F.cosine_similarity(u_t[valid_mask], u_p[valid_mask], dim=1).mean())
        ratio = float(norm_p.mean() / (norm_t.mean() + 1e-7))
        return cos_sim, ratio

    @staticmethod
    def _scales_2d(scales):
        """Return the first two scale axes as an (N, 2) tensor.

        Single-axis inputs, shaped (N,) or (N, 1), are duplicated so that
        max/min over the axis pair are still defined.
        """
        if scales.dim() == 1:
            scales = scales.unsqueeze(-1)
        if scales.shape[1] >= 2:
            return scales[:, :2]
        return scales[:, :1].expand(-1, 2)

    def _collect_histograms(self):
        """Return detached per-Gaussian tensors for histogram logging."""
        histograms = {"opacity": self.model.get_density().clone().detach()}
        scales = self.model.get_scale().clone().detach()
        histograms["scaling"] = scales
        scales_2d = self._scales_2d(scales)
        # Anisotropy: ratio of the larger to the smaller of the first two scales.
        histograms["anisotropy"] = scales_2d.max(dim=-1)[0] / (scales_2d.min(dim=-1)[0] + 1e-7)
        histograms["sh_dc_mag"] = self.model.features_albedo.detach().norm(dim=-1)
        return histograms

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns a dict with ``image`` (``pred_rgb``) and ``depth`` (``pred_dist``,
        or None when the model output lacks it).
        """
        with torch.no_grad():
            gpu_batch = self.trainer.val_dataset.get_gpu_batch_with_intrinsics(camera)
            outputs = self.model(gpu_batch, train=False)
        return {"image": outputs["pred_rgb"], "depth": outputs.get("pred_dist", None)}

    def save(self, save_dir, step):
        """Write a trainer checkpoint for ``step``.

        ``save_dir`` is unused; the trainer chooses the output location.
        """
        # This wrapper bypasses the trainer's own loop, so its step counter is
        # never advanced. Sync it so the checkpoint is tagged with ``step``.
        # NOTE(review): assumes save_checkpoint derives ``ours_{step}/ckpt_{step}.pt``
        # (the path ``load`` looks for) from ``trainer.global_step``.
        if hasattr(self.trainer, "global_step"):
            self.trainer.global_step = step
        self.trainer.save_checkpoint(last_checkpoint=(step >= self._N_ITERATIONS))

    def load(self, model_path, iteration):
        """Restore model parameters (without optimizer state) from a checkpoint.

        Tries ``<model_path>/ours_<iteration>/ckpt_<iteration>.pt`` and falls
        back to ``<model_path>/ckpt_last.pt``.
        """
        ckpt_path = os.path.join(model_path, f"ours_{iteration}", f"ckpt_{iteration}.pt")
        if not os.path.exists(ckpt_path):
            ckpt_path = os.path.join(model_path, "ckpt_last.pt")
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cuda", weights_only=False)
        self.model.init_from_checkpoint(checkpoint, setup_optimizer=False)

    def get_spatial_centers(self):
        """Return the Gaussian centre positions tensor of the model."""
        return self.model.positions

    def compute_physical_metrics(self, cameras=None):
        """Return scalar shape/appearance statistics of the current Gaussians.

        Keys:
            ``gamma_median`` / ``gamma_90th_percentile``: anisotropy (max/min of
            the first two scales).
            ``scale_mean``, ``alpha_mean``.
            ``sh_energy_ratio``: only when specular features exist.
            ``billboard_bias_ratio``: only when ``cameras`` is non-empty; the
            fraction of Gaussians whose local z-axis is within ~25.8 degrees
            (|cos| > 0.9) of some camera's viewing axis.
        """
        metrics = {}
        with torch.no_grad():
            scales_2d = self._scales_2d(self.model.get_scale())
            max_s, _ = torch.max(scales_2d, dim=1)
            min_s, _ = torch.min(scales_2d, dim=1)
            gamma = max_s / (min_s + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(self.model.get_density()))

            dc = self.model.features_albedo
            rest = self.model.features_specular
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))

            if cameras is not None and len(cameras) > 0:
                # NOTE(review): expects 3DGS-style cameras whose
                # world_view_transform[:3, 2] is the viewing axis in world space.
                view_dirs = [c.world_view_transform[:3, 2].tolist() for c in cameras]
                view_dirs = F.normalize(torch.tensor(view_dirs, dtype=torch.float32, device="cuda"), dim=1)
                # Third column of the rotation matrix of quaternion (w, x, y, z);
                # assumes the model stores rotations in that order.
                w, x, y, z = F.normalize(self.model.rotation, dim=1).unbind(dim=-1)
                normals = F.normalize(
                    torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1),
                    dim=1,
                )
                max_cos, _ = torch.max(torch.abs(torch.matmul(normals, view_dirs.T)), dim=1)
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic Gaussian-mixture density at ``query_points``.

        Each Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)``.
        ``d`` is the distance to its centre and ``sigma`` is the larger of its
        first two scales. Queries are processed in chunks so each distance
        matrix holds roughly 30M entries. ``cameras`` is unused.

        Args:
            query_points: (V, D) points, where D must match ``model.positions``.
                They are moved to the model's device and dtype if needed.

        Returns:
            (V,) float32 tensor of densities on the device of ``model.positions``.
        """
        with torch.no_grad():
            xyz = self.model.positions
            # cdist requires both operands on the same device and dtype.
            query_points = query_points.to(device=xyz.device, dtype=xyz.dtype)
            num_queries = query_points.shape[0]
            densities = torch.zeros(num_queries, device=xyz.device)

            opacities = self.model.get_density().reshape(-1)
            sigma_sq = self._scales_2d(self.model.get_scale()).max(dim=1)[0].pow(2)

            chunk_size = max(1, 30_000_000 // (xyz.shape[0] + 1))
            for start in range(0, num_queries, chunk_size):
                end = min(start + chunk_size, num_queries)
                dist_sq = torch.cdist(query_points[start:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[start:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities
|
|