# SplatAtlas / methods / wrapper_3dgut.py
# Part of the SplatAtlas benchmark pipeline code (uploaded by KCBtheone, revision 23e73f9, 11.4 kB).
import os
import sys
import torch
import numpy as np
import torch.nn.functional as F
from omegaconf import OmegaConf
from hydra import initialize_config_dir, compose
from hydra.core.global_hydra import GlobalHydra
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../3dgrut')))
from threedgrut.trainer import Trainer3DGRUT
from threedgrut.optimizers import SelectiveAdam
class MockScene:
    """Placeholder scene object exposing empty camera lists.

    NOTE(review): presumably exists so pipeline code that expects a
    ``scene`` attribute with camera getters keeps working for this method —
    confirm against the callers.
    """

    def getTrainCameras(self):
        """Return an empty list of training cameras (a fresh list per call)."""
        return []

    def getTestCameras(self):
        """Return an empty list of test cameras (a fresh list per call)."""
        return []
@register_method("3dgut")
class ThreeDGUTWrapper(BaseMethod):
    """Benchmark adapter that steps a ``Trainer3DGRUT`` model one iteration at a time.

    The wrapper composes a Hydra configuration for the ``3dgut`` renderer,
    builds the trainer, and then exposes the training step, rendering,
    checkpointing and diagnostic metrics through the methods below.
    """

    # Iteration budget passed to the trainer config; ``save`` also uses it to
    # decide when a checkpoint is the "last" one.
    N_ITERATIONS = 30000
    # Default location of the 3DGRUT Hydra configs; override with the
    # ``threedgrut_config_dir`` hyperparameter on other machines.
    DEFAULT_CONFIG_DIR = "/root/autodl-tmp/3dgrut/configs"
    # Substring of ``source_path`` that selects the NeRF-synthetic code path.
    SYNTHETIC_MARKER = "Synthetic_NeRF"
    # How often (in steps) the gradient-decoupling probe / histograms are computed.
    DECOUPLING_INTERVAL = 100
    HISTOGRAM_INTERVAL = 1000

    def __init__(self, dataset_config, hyperparams):
        """Compose the 3DGUT configuration and build the trainer.

        Args:
            dataset_config: Mapping that provides ``source_path`` and
                ``model_path`` and optionally ``resolution`` (used as the
                dataset downsample factor, default 1).
            hyperparams: Mapping of options. ``track_decoupling`` (default
                False) enables the periodic gradient-decoupling probe;
                ``threedgrut_config_dir`` (optional) overrides the directory
                the Hydra configs are loaded from.
        """
        self.track_decoupling = hyperparams.get("track_decoupling", False)
        is_synthetic = self.SYNTHETIC_MARKER in str(dataset_config["source_path"])

        # Drop any Hydra state left behind by a previously constructed method.
        GlobalHydra.instance().clear()
        config_dir = os.path.abspath(
            hyperparams.get("threedgrut_config_dir", self.DEFAULT_CONFIG_DIR)
        )
        with initialize_config_dir(version_base=None, config_dir=config_dir):
            overrides = [
                "+render=3dgut",
                f"++dataset.path={dataset_config['source_path']}",
                f"++path={dataset_config['source_path']}",
                f"++out_dir={dataset_config['model_path']}",
                "++experiment_name=benchmark",
                f"++n_iterations={self.N_ITERATIONS}",
                "++loss.use_l1=True",
                "++loss.lambda_l1=0.8",
                "++loss.use_ssim=True",
                "++loss.lambda_ssim=0.2",
                "++with_gui=False",
                "++with_viser_gui=False",
            ]
            if is_synthetic:
                overrides.append(self._dataset_override("nerf"))
            else:
                overrides.append(self._dataset_override("colmap"))
                # NOTE(review): the source had lost its indentation; this line
                # is assumed to belong to the COLMAP branch only — confirm.
                overrides.append(
                    f"++dataset.downsample_factor={dataset_config.get('resolution', 1)}"
                )
            conf = compose(config_name="base_gs", overrides=overrides)

        # Allow adding keys that the composed config does not declare.
        OmegaConf.set_struct(conf, False)
        if "initialization" not in conf:
            conf.initialization = OmegaConf.create({})
        if is_synthetic:
            conf.initialization.method = "random"
            conf.initialization.num_gaussians = 100000
        else:
            conf.initialization.method = "colmap"
            # NOTE(review): assumed to apply to the COLMAP branch only (the
            # original indentation was lost) — confirm.
            conf.initialization.use_observation_points = False
        if not hasattr(conf.dataset, "test_split_interval"):
            conf.dataset.test_split_interval = 8
        if not hasattr(conf.dataset, "eval"):
            conf.dataset.eval = True

        self.trainer = Trainer3DGRUT(conf)
        self.model = self.trainer.model
        self.data_iter = iter(self.trainer.train_dataloader)
        self.last_n_gaussians = self.model.num_gaussians
        self.scene = MockScene()

    @staticmethod
    def _dataset_override(name):
        """Return the Hydra override selecting dataset ``name``.

        Tries ``dataset=<name>`` first and falls back to the appending form
        ``+dataset=<name>`` when that composition fails. Must be called while
        Hydra is initialised.
        """
        override = f"dataset={name}"
        try:
            compose(config_name="base_gs", overrides=[override])
        except Exception:
            # Deliberate best-effort probe: any composition failure means the
            # plain form is not accepted, so use the "+" form instead.
            return f"+{override}"
        return override

    @staticmethod
    def _first_two_scales(scales):
        """Return an ``(N, 2)`` view of the first two scale components.

        A 1-D tensor or a single-column tensor is broadcast to two identical
        columns so that the resulting anisotropy is 1.
        """
        if scales.dim() == 1:
            scales = scales.unsqueeze(-1)
        if scales.shape[1] >= 2:
            return scales[:, :2]
        return scales.expand(-1, 2)

    def _hook_kwargs(self, step, gpu_batch):
        """Build the keyword arguments shared by all strategy hooks."""
        return {
            "step": step,
            "scene_extent": self.trainer.scene_extent,
            "train_dataset": self.trainer.train_dataset,
            "batch": gpu_batch,
            "writer": None,
        }

    def _position_grad(self, loss):
        """Back-propagate ``loss`` alone and return a copy of the position gradient.

        The graph is retained so that further backward passes remain possible.
        """
        self.model.optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)
        grad = self.model.positions.grad
        if grad is None:
            return torch.zeros_like(self.model.positions)
        return grad.clone()

    def _measure_decoupling(self, loss_target, loss_parasitic):
        """Compare the position updates induced by the two loss terms.

        Returns:
            ``(grad_cos_sim, parasitic_ratio)``: mean cosine similarity between
            the per-Gaussian updates of both terms (over Gaussians where both
            are non-zero) and the ratio of their mean update norms. Both are
            0.0 when no Gaussian has two non-zero updates.

        Leaves gradients from the last backward pass in place; the caller is
        expected to zero them before the real backward pass.
        """
        grad_target = self._position_grad(loss_target)
        grad_parasitic = self._position_grad(loss_parasitic)

        optimizer = self.model.optimizer
        state = optimizer.state.get(self.model.positions, None)
        if state is not None and "exp_avg_sq" in state:
            # Scale raw gradients by lr / sqrt(second-moment estimate).
            # NOTE(review): assumes positions belong to param_groups[0] — confirm.
            lr = optimizer.param_groups[0]["lr"]
            precond = lr / (torch.sqrt(state["exp_avg_sq"]) + 1e-8)
            update_target = precond * grad_target
            update_parasitic = precond * grad_parasitic
        else:
            update_target = grad_target
            update_parasitic = grad_parasitic

        norm_target = torch.norm(update_target, dim=1)
        norm_parasitic = torch.norm(update_parasitic, dim=1)
        valid_mask = (norm_target > 0) & (norm_parasitic > 0)
        if not valid_mask.any():
            return 0.0, 0.0

        grad_cos_sim = float(
            F.cosine_similarity(
                update_target[valid_mask], update_parasitic[valid_mask], dim=1
            ).mean()
        )
        # NOTE(review): assumed to be computed only when valid_mask has hits
        # (the original indentation was lost) — confirm.
        parasitic_ratio = float(norm_parasitic.mean() / (norm_target.mean() + 1e-7))
        return grad_cos_sim, parasitic_ratio

    def train_iteration(self, step):
        """Run one optimisation step.

        Args:
            step: Current iteration index; forwarded to the model, the
                strategy hooks and the scheduler.

        Returns:
            ``(metrics, histograms)``: ``metrics`` is a dict of Python scalars;
            ``histograms`` is a dict of detached tensors that is only filled
            every ``HISTOGRAM_INTERVAL`` steps.
        """
        try:
            batch = next(self.data_iter)
        except StopIteration:
            # Dataloader exhausted: start another pass over the training set.
            self.data_iter = iter(self.trainer.train_dataloader)
            batch = next(self.data_iter)

        gpu_batch = self.trainer.train_dataset.get_gpu_batch_with_intrinsics(batch)
        outputs = self.model(gpu_batch, train=True, frame_id=step)
        batch_losses = self.trainer.get_losses(gpu_batch, outputs)

        # Fallback constants are only built when a term is actually missing.
        loss_target = batch_losses.get("l1_loss")
        if loss_target is None:
            loss_target = torch.zeros((), device=self.model.positions.device)
        loss_parasitic = batch_losses.get("ssim_loss")
        if loss_parasitic is None:
            loss_parasitic = torch.zeros((), device=self.model.positions.device)
        loss = batch_losses.get("total_loss")
        if loss is None:
            loss = loss_target + loss_parasitic

        self.trainer.strategy.pre_backward(**self._hook_kwargs(step, gpu_batch))

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        # The probe back-propagates each term separately, which is only
        # possible when both terms are attached to the autograd graph.
        probe_now = (
            self.track_decoupling
            and step % self.DECOUPLING_INTERVAL == 0
            and getattr(loss_target, "requires_grad", False)
            and getattr(loss_parasitic, "requires_grad", False)
        )
        if probe_now:
            grad_cos_sim, parasitic_ratio = self._measure_decoupling(
                loss_target, loss_parasitic
            )
            # Discard the probe gradients before the real backward pass.
            self.model.optimizer.zero_grad(set_to_none=True)
        loss.backward()

        self.trainer.strategy.post_backward(**self._hook_kwargs(step, gpu_batch))

        optimizer = self.model.optimizer
        if "mog_visibility" in outputs and isinstance(optimizer, SelectiveAdam):
            optimizer.step(outputs["mog_visibility"])
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.model.scheduler_step(step)

        self.trainer.strategy.post_optimizer_step(**self._hook_kwargs(step, gpu_batch))

        num_gaussians = self.model.num_gaussians
        metrics = {
            "loss": float(loss),
            "loss_l1": float(loss_target),
            "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians),
            "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim),
            "parasitic_ratio": float(parasitic_ratio),
            "hits_mean": float(
                outputs.get("hits_count", torch.tensor(0.0)).float().mean()
            ),
        }
        self.last_n_gaussians = num_gaussians

        histograms = {}
        if step % self.HISTOGRAM_INTERVAL == 0:
            histograms["opacity"] = self.model.get_density().clone().detach()
            scales = self.model.get_scale().clone().detach()
            histograms["scaling"] = scales
            scales_2d = self._first_two_scales(scales)
            histograms["anisotropy"] = scales_2d.max(dim=-1)[0] / (
                scales_2d.min(dim=-1)[0] + 1e-7
            )
            histograms["sh_dc_mag"] = self.model.features_albedo.detach().norm(dim=-1)
        return metrics, histograms

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns:
            Dict with ``image`` (the model's ``pred_rgb`` output) and ``depth``
            (the ``pred_dist`` output, or None when the model does not return it).
        """
        with torch.no_grad():
            gpu_batch = self.trainer.val_dataset.get_gpu_batch_with_intrinsics(camera)
            outputs = self.model(gpu_batch, train=False)
        return {"image": outputs["pred_rgb"], "depth": outputs.get("pred_dist", None)}

    def save(self, save_dir, step):
        """Ask the trainer to write a checkpoint.

        ``save_dir`` is accepted for interface compatibility but not forwarded.
        NOTE(review): presumably the trainer writes below the ``out_dir``
        configured in ``__init__`` and names files from its own step counter —
        confirm that this matches the layout ``load`` expects.
        """
        self.trainer.save_checkpoint(last_checkpoint=(step >= self.N_ITERATIONS))

    def load(self, model_path, iteration):
        """Restore model parameters from a checkpoint.

        Looks for ``<model_path>/ours_<iteration>/ckpt_<iteration>.pt`` and
        falls back to ``<model_path>/ckpt_last.pt``.

        Raises:
            FileNotFoundError: If neither checkpoint file exists.
        """
        ckpt_path = os.path.join(model_path, f"ours_{iteration}", f"ckpt_{iteration}.pt")
        if not os.path.exists(ckpt_path):
            fallback_path = os.path.join(model_path, "ckpt_last.pt")
            if not os.path.exists(fallback_path):
                raise FileNotFoundError(
                    f"No checkpoint found at {ckpt_path} or {fallback_path}"
                )
            ckpt_path = fallback_path
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(ckpt_path, map_location="cuda", weights_only=False)
        self.model.init_from_checkpoint(checkpoint, setup_optimizer=False)

    def get_spatial_centers(self):
        """Return the model's Gaussian position tensor."""
        return self.model.positions

    def compute_physical_metrics(self, cameras=None):
        """Compute summary statistics of the current Gaussian parameters.

        Args:
            cameras: Optional sequence of cameras exposing
                ``world_view_transform``; when given and non-empty, the
                ``billboard_bias_ratio`` metric is added.

        Returns:
            Dict of Python floats: anisotropy median / 90th percentile, mean
            scale, mean density, and, when available, the specular-to-albedo
            feature norm ratio and the billboard bias ratio.
        """
        metrics = {}
        with torch.no_grad():
            # Anisotropy is measured on the first two scale components only.
            scales_2d = self._first_two_scales(self.model.get_scale())
            gamma = scales_2d.max(dim=1)[0] / (scales_2d.min(dim=1)[0] + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(self.model.get_density()))

            dc = self.model.features_albedo
            rest = self.model.features_specular
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(
                    rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7)
                )

            if cameras is not None and len(cameras) > 0:
                rots = F.normalize(self.model.rotation, dim=1)
                view_dirs = F.normalize(
                    torch.tensor(
                        [c.world_view_transform[:3, 2].tolist() for c in cameras],
                        dtype=torch.float32,
                        device=rots.device,
                    ),
                    dim=1,
                )
                # Third column of the rotation matrix of a (w, x, y, z) quaternion.
                w, x, y, z = rots.unbind(dim=-1)
                normals = F.normalize(
                    torch.stack(
                        [2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)],
                        dim=-1,
                    ),
                    dim=1,
                )
                max_cos, _ = torch.max(
                    torch.abs(torch.matmul(normals, view_dirs.T)), dim=1
                )
                # Fraction of Gaussians whose axis is within acos(0.90) of
                # some camera's direction vector.
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic density proxy at ``query_points``.

        Each Gaussian contributes ``density * exp(-0.5 * d^2 / sigma^2)`` where
        ``d`` is the distance to its centre and ``sigma`` is the larger of its
        first two scale components.

        Args:
            query_points: ``(V, D)`` tensor of query locations; moved to the
                device and dtype of the model positions.
            cameras: Unused; kept for interface compatibility.

        Returns:
            ``(V,)`` tensor of summed contributions on the model's device.
        """
        with torch.no_grad():
            xyz = self.model.positions
            query_points = query_points.to(device=xyz.device, dtype=xyz.dtype)
            num_queries = query_points.shape[0]
            densities = torch.zeros(num_queries, device=xyz.device)

            opacities = self.model.get_density().reshape(-1)
            sigma_sq = self._first_two_scales(self.model.get_scale()).max(dim=1)[0].pow(2)

            # Bound the (chunk x N) distance matrix to roughly 30M entries.
            chunk_size = max(1, 30_000_000 // (xyz.shape[0] + 1))
            for start in range(0, num_queries, chunk_size):
                end = min(start + chunk_size, num_queries)
                dist_sq = torch.cdist(query_points[start:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[start:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
            return densities