# SplatAtlas / methods/wrapper_edgeloss.py
# Uploaded by KCBtheone: "Upload SplatAtlas benchmark pipeline code" (commit 23e73f9, 14.3 kB).
import os
import sys
import math
import time
import random
import torch
import numpy as np
import torch.nn.functional as F
from argparse import ArgumentParser
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../3dgsAtlas_official')))
from utils.loss_utils import l1_loss, ssim
from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
def native_render(viewpoint_camera, pc, pipe, bg_color, scaling_modifier=1.0):
    """Rasterize the Gaussian model ``pc`` as seen from ``viewpoint_camera``.

    Args:
        viewpoint_camera: camera exposing ``FoVx``/``FoVy``, ``image_height``/
            ``image_width``, ``world_view_transform``, ``full_proj_transform``
            and ``camera_center``.
        pc: Gaussian model providing ``get_xyz``, ``get_opacity``,
            ``get_scaling``/``get_rotation`` (or ``get_covariance``),
            ``get_features`` and ``active_sh_degree``.
        pipe: pipeline params; ``debug`` and ``compute_cov3D_python`` are read.
        bg_color: background color tensor passed to the rasterizer.
        scaling_modifier: multiplier applied to the Gaussian scales.

    Returns:
        dict with ``render`` (rendered image clamped to [0, 1]),
        ``viewspace_points`` (screen-space means whose retained ``.grad`` is
        used downstream), ``visibility_filter`` (``radii > 0``) and ``radii``.
    """
    # Zero tensor shaped like the means. The "+ 0" makes it a non-leaf so
    # retain_grad() is needed to keep its gradient after backward.
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except RuntimeError:
        # Under torch.no_grad() (e.g. evaluation renders) the "+ 0" result
        # does not require grad and retain_grad() raises; nothing to retain.
        pass

    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # Either hand the rasterizer precomputed 3D covariances, or let it build
    # them from scales + rotations.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    shs = pc.get_features
    rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=means2D,
        shs=shs,
        colors_precomp=None,
        opacities=opacity,
        scales=scales,
        rotations=rotations,
        cov3D_precomp=cov3D_precomp,
    )
    rendered_image = rendered_image.clamp(0, 1)
    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
    }
def _sobel(img):
"""Single-scale luma Sobel response. NO stop-gradient — keeps autograd flowing."""
luma = (0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]).unsqueeze(0).unsqueeze(0)
wx = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32, device=img.device).view(1, 1, 3, 3)
wy = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32, device=img.device).view(1, 1, 3, 3)
gx = F.conv2d(luma, wx, padding=1)
gy = F.conv2d(luma, wy, padding=1)
return torch.sqrt(gx**2 + gy**2 + 1e-8)
@register_method("edgeloss")
class EdgeLossWrapper(BaseMethod):
    """
    Ablation baseline answering Reviewer vm48:
        "direct edge-weighted photometric losses or alternative
        gradient-domain weighting strategies"

    Key contrast with SGF: NO stop-gradient on the weight map.
    The Sobel weight is part of the autograd graph, so its gradient
    flows back through all parameters (position, color, opacity, scale, rot),
    not just positional gradients. This is the "trivial" edge-weighted
    photometric loss SGF distinguishes itself from.

    Uses vanilla 3DGS backbone (3dgsAtlas_official) and standard
    densification logic — only the loss is modified.
    """

    # Scene names (outdoor 360 set) that default to resolution 4 in the
    # command-line resolution fix; any other named scene defaults to 2.
    _OUTDOOR_360 = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})

    def __init__(self, dataset_config, hyperparams):
        """Build the 3DGS parameter groups, scene, model and optimizer.

        Args:
            dataset_config: dict with required ``source_path`` and
                ``model_path`` and optional ``resolution`` (default 1). The
                resolution may be overridden from ``sys.argv`` (see
                ``_apply_resolution_fix``).
            hyperparams: dict; only ``track_decoupling`` (default False) is read.
        """
        # Parse the official argument groups with an empty argv to get their
        # defaults, then override the dataset-specific fields.
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        self.args = self.parser.parse_args([])
        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)
        self.gaussians = GaussianModel(self.dataset.sh_degree)

        # INJECTED_RES_FIX: must run before Scene() loads the cameras.
        self._apply_resolution_fix()

        self.scene = Scene(self.dataset, self.gaussians)
        self.gaussians.training_setup(self.opt)
        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.viewpoint_stack = self.scene.getTrainCameras().copy()
        self.last_n_gaussians = len(self.gaussians.get_xyz)

    def _apply_resolution_fix(self):
        """Override ``self.dataset.resolution`` from the process command line.

        Scans ``sys.argv`` for ``--scene``, ``--source_path`` (its last path
        component is taken as the scene name) and ``--resolution``. A positive
        explicit ``--resolution`` wins. Otherwise a known scene name selects 4
        for outdoor 360 scenes and 2 for any other scene. If neither is found,
        the resolution is left untouched. Failures are reported, not raised.
        """
        argv = sys.argv
        scene_name, explicit_res = None, None
        for flag, value in zip(argv[:-1], argv[1:]):
            if flag == "--scene":
                scene_name = value
            elif flag == "--source_path":
                scene_name = value.rstrip("/").split("/")[-1]
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    # Non-integer value: ignore and fall back to scene defaults.
                    pass

        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene_name is not None:
            res = 4 if scene_name in self._OUTDOOR_360 else 2
        else:
            res = None

        try:
            if res is not None:
                self.dataset.resolution = res
            print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene_name, explicit_res, res, __file__))
        except Exception as exc:
            print("[res-fix] FAILED:", exc)

    def _edge_weighted_losses(self, image, gt_image):
        """Return ``(loss_l1, loss_ssim, W_edge)`` for one rendered view.

        ``W_edge`` is the absolute Sobel-magnitude difference between render
        and ground truth, divided by its max so it lies in [0, 1). It is NOT
        detached (the key contrast with SGF), so the gradient also flows
        through the render's Sobel response.
        """
        edge_pred = _sobel(image)
        edge_gt = _sobel(gt_image)
        E = torch.abs(edge_pred - edge_gt).squeeze()
        # Normalize without detaching: the weight map participates in autograd.
        W_edge = E / (E.max() + 1e-8)
        # Channel-averaged L1 map (image assumed (C, H, W)), up-weighted by up
        # to 1.7x where the edge discrepancy is largest.
        L1_map = torch.abs(image - gt_image).mean(dim=0)
        loss_l1 = ((1.0 - self.opt.lambda_dssim) * L1_map * (1.0 + 0.7 * W_edge)).mean()
        # SSIM stays the standard scalar term (no spatial reweighting), so the
        # gradient through W_edge is isolated to the L1 component.
        loss_ssim = self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        return loss_l1, loss_ssim, W_edge

    def _measure_decoupling(self, loss_target, loss_parasitic):
        """Compare the ``_xyz`` gradients of the two loss terms.

        Backpropagates each term separately, keeping the graph for the
        caller's final backward. Returns ``(grad_cos_sim, parasitic_ratio)``:
        the mean per-Gaussian cosine similarity over Gaussians where both
        gradients are nonzero, and the ratio of mean gradient norms
        (parasitic / target). Both are 0.0 if no Gaussian has both gradients
        nonzero. Optimizer-managed grads are cleared before returning.
        """
        optimizer = self.gaussians.optimizer
        xyz = self.gaussians._xyz

        def _xyz_grad():
            return xyz.grad.clone() if xyz.grad is not None else torch.zeros_like(xyz)

        optimizer.zero_grad(set_to_none=True)
        loss_target.backward(retain_graph=True)
        grad_target = _xyz_grad()
        optimizer.zero_grad(set_to_none=True)
        loss_parasitic.backward(retain_graph=True)
        grad_parasitic = _xyz_grad()

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        norm_target = torch.norm(grad_target, dim=1)
        norm_parasitic = torch.norm(grad_parasitic, dim=1)
        valid_mask = (norm_target > 0) & (norm_parasitic > 0)
        if valid_mask.any():
            grad_cos_sim = float(F.cosine_similarity(grad_target[valid_mask], grad_parasitic[valid_mask], dim=1).mean())
            parasitic_ratio = float(norm_parasitic.mean() / (norm_target.mean() + 1e-7))
        optimizer.zero_grad(set_to_none=True)
        return grad_cos_sim, parasitic_ratio

    def _update_densification(self, step, viewspace_point_tensor, visibility_filter, radii):
        """Vanilla 3DGS densify/prune and opacity-reset schedule (call under no_grad)."""
        if step >= self.opt.densify_until_iter:
            return
        self.gaussians.max_radii2D[visibility_filter] = torch.max(
            self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
        )
        # vanilla backbone densification: 2-arg
        self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
        if step > self.opt.densify_from_iter and step % self.opt.densification_interval == 0:
            size_threshold = 20 if step > self.opt.opacity_reset_interval else None
            # vanilla backbone signature with radii
            self.gaussians.densify_and_prune(
                self.opt.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold, radii
            )
        if step % self.opt.opacity_reset_interval == 0 or (
            self.dataset.white_background and step == self.opt.densify_from_iter
        ):
            self.gaussians.reset_opacity()

    def train_iteration(self, step):
        """Run one optimization step on a randomly drawn training camera.

        Returns:
            ``(metrics, histograms)``. ``metrics`` holds scalar logging values.
            ``histograms`` is empty except every 1000 steps, when it holds
            detached per-Gaussian opacity, scaling, anisotropy and SH-DC
            magnitude tensors.
        """
        iter_start = time.perf_counter()
        self.gaussians.update_learning_rate(step)
        if step % 1000 == 0:
            self.gaussians.oneupSHdegree()

        # Draw cameras without replacement; refill once the stack is exhausted.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        render_pkg = native_render(viewpoint_cam, self.gaussians, self.pipe, self.background)
        image = render_pkg["render"]
        viewspace_point_tensor = render_pkg["viewspace_points"]
        visibility_filter = render_pkg["visibility_filter"]
        radii = render_pkg["radii"]
        gt_image = viewpoint_cam.original_image.cuda()

        # === Edge-weighted loss (NO stop-gradient — key contrast with SGF) ===
        loss_target, loss_parasitic, W_edge = self._edge_weighted_losses(image, gt_image)
        loss = loss_target + loss_parasitic

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio = self._measure_decoupling(loss_target, loss_parasitic)
            # The probe backward passes also accumulated into the retained
            # screen-space gradient, which optimizer.zero_grad() does not
            # touch. Clear it so densification stats see only d(loss), as on
            # non-tracking steps (otherwise they would be roughly doubled).
            viewspace_point_tensor.grad = None
        loss.backward()

        with torch.no_grad():
            self._update_densification(step, viewspace_point_tensor, visibility_filter, radii)
            self.gaussians.optimizer.step()
            self.gaussians.optimizer.zero_grad(set_to_none=True)

        num_gaussians = self.gaussians.get_xyz.shape[0]
        metrics = {
            "loss": float(loss), "loss_l1": float(loss_target), "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians), "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim), "parasitic_ratio": float(parasitic_ratio),
            "edge_weight_mean": float(W_edge.mean()),
        }
        self.last_n_gaussians = num_gaussians
        metrics["iter_time_ms"] = float((time.perf_counter() - iter_start) * 1000)
        metrics["vram_allocated_GB"] = float(torch.cuda.memory_allocated() / 1024**3)

        histograms = {}
        if step % 1000 == 0:
            histograms["opacity"] = torch.sigmoid(self.gaussians._opacity).clone().detach()
            scales = torch.exp(self.gaussians._scaling).clone().detach()
            histograms["scaling"] = scales
            histograms["anisotropy"] = scales.max(dim=-1)[0] / (scales.min(dim=-1)[0] + 1e-7)
            histograms["sh_dc_mag"] = self.gaussians._features_dc.detach().norm(dim=-1)
        return metrics, histograms

    def render(self, camera):
        """Render ``camera`` without gradients and time it.

        Returns a dict with ``image``, ``depth`` (taken from the render
        package if present, else None) and ``render_ms``. ``render_ms`` is
        wall time bracketed by ``torch.cuda.synchronize()`` calls.
        """
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        with torch.no_grad():
            render_pkg = native_render(camera, self.gaussians, self.pipe, self.background)
        torch.cuda.synchronize()
        render_ms = (time.perf_counter() - t0) * 1000
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None), "render_ms": render_ms}

    def save(self, save_dir, step):
        """Checkpoint via ``Scene.save(step)``.

        ``save_dir`` is not used here; the output location is chosen by
        ``Scene`` (presumably under the dataset's ``model_path`` — confirm).
        """
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Load Gaussians from ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        self.gaussians.load_ply(os.path.join(model_path, 'point_cloud', f'iteration_{iteration}', 'point_cloud.ply'))

    def get_spatial_centers(self):
        """Return the raw (non-detached) Gaussian center parameter ``_xyz``."""
        return self.gaussians._xyz

    def compute_physical_metrics(self, cameras=None):
        """Summarize Gaussian shape statistics.

        Always reports the anisotropy (max/min scale) median and 90th
        percentile, the mean scale and the mean opacity. Adds
        ``sh_energy_ratio`` when higher-order SH coefficients exist, and
        ``billboard_bias_ratio`` when a non-empty ``cameras`` is given.
        """
        metrics = {}
        with torch.no_grad():
            scales = torch.exp(self.gaussians._scaling)
            max_S, _ = torch.max(scales, dim=1)
            min_S, _ = torch.min(scales, dim=1)
            gamma = max_S / (min_S + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales))
            metrics["alpha_mean"] = float(torch.mean(torch.sigmoid(self.gaussians._opacity)))

            dc, rest = self.gaussians._features_dc, self.gaussians._features_rest
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))

            if cameras is not None and len(cameras) > 0:
                # Column 2 of each camera's world_view_transform. Presumably
                # this is the camera viewing axis in world coordinates, given
                # 3DGS's transposed matrix storage — confirm.
                view_dirs = torch.tensor(
                    [c.world_view_transform[:3, 2].tolist() for c in cameras],
                    dtype=torch.float32, device="cuda",
                )
                view_dirs = F.normalize(view_dirs, dim=1)
                rots = F.normalize(self.gaussians._rotation, dim=1)
                w, x, y, z = rots.unbind(dim=-1)
                # Third column of the rotation matrix of quaternion (w, x, y, z),
                # i.e. each Gaussian's rotated local z-axis.
                # NOTE(review): this treats local z as the splat normal; for an
                # unconstrained 3DGS the flattest axis is argmin(scale) — confirm.
                normals = F.normalize(torch.stack([2*(x*z + w*y), 2*(y*z - w*x), 1-2*(x*x + y*y)], dim=-1), dim=1)
                max_cos, _ = torch.max(torch.abs(torch.matmul(normals, view_dirs.T)), dim=1)
                # Fraction of Gaussians whose axis is within acos(0.9) of being
                # (anti)parallel to at least one camera axis.
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic, opacity-weighted Gaussian density at query points.

        Each Gaussian contributes ``opacity * exp(-0.5 * d**2 / s_max**2)``,
        where ``s_max`` is its largest scale. ``cameras`` is accepted for
        interface compatibility and unused. Queries are processed in chunks
        so each (chunk, N) distance matrix holds roughly 30M entries.

        Returns:
            Tensor of shape (V,) on CUDA, one density per query point.
        """
        with torch.no_grad():
            V = query_points.shape[0]
            densities = torch.zeros(V, device="cuda")
            xyz, opacities = self.gaussians._xyz, torch.sigmoid(self.gaussians._opacity).squeeze()
            scales = torch.exp(self.gaussians._scaling)
            sigma_sq = scales.max(dim=1)[0].pow(2)
            N_gaussians = xyz.shape[0]
            chunk_size = max(1, 30_000_000 // (N_gaussians + 1))
            for i in range(0, V, chunk_size):
                end = min(i + chunk_size, V)
                dist_sq = torch.cdist(query_points[i:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[i:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
            return densities