# SplatAtlas / methods / wrapper_absgssgf.py
# Uploaded by KCBtheone: "Upload SplatAtlas benchmark pipeline code" (commit 23e73f9, verified; 16.5 kB).
import os
import sys
import math
import random
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from argparse import ArgumentParser
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../absgs_official')))
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render as native_render
from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
def get_sobel_map(img):
    """Return the Sobel gradient magnitude of an image's luma.

    ``img`` is indexed along its first axis as channels 0, 1, 2, which are
    weighted 0.299 / 0.587 / 0.114 (Rec. 601 luma). Presumably ``img`` is
    (3, H, W) — TODO confirm against callers. The result is a 4-D tensor of shape
    (1, 1, H, W) because of the zero-padded 3x3 convolution.

    The 1e-8 added under the square root keeps the result (and the gradient of
    ``sqrt``) finite where both directional derivatives are zero.
    """
    red, green, blue = img[0], img[1], img[2]
    luma = (0.299 * red + 0.587 * green + 0.114 * blue)[None, None]
    kernel_x = torch.tensor(
        [[1, 0, -1],
         [2, 0, -2],
         [1, 0, -1]],
        dtype=torch.float32,
        device=img.device,
    ).view(1, 1, 3, 3)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    kernel_y = kernel_x.transpose(-1, -2).contiguous()
    grad_x = F.conv2d(luma, kernel_x, padding=1)
    grad_y = F.conv2d(luma, kernel_y, padding=1)
    return (grad_x.pow(2) + grad_y.pow(2) + 1e-8).sqrt()
def create_window(window_size, channel, sigma=1.5):
    """Build a normalised 2-D Gaussian kernel for depthwise (grouped) convolution.

    Args:
        window_size: kernel side length. The Gaussian is centred on index
            ``window_size // 2``, so odd sizes give a symmetric kernel.
        channel: number of kernel copies, one per input channel, for use with
            ``F.conv2d(..., groups=channel)``.
        sigma: standard deviation in pixels. The default of 1.5 matches the
            previously hard-coded value, so existing callers are unaffected.

    Returns:
        float32 tensor of shape (channel, 1, window_size, window_size) that does
        not require grad. Each 2-D slice is the outer product of a normalised 1-D
        Gaussian with itself, so it sums to 1 up to rounding.
    """
    center = window_size // 2
    gauss = torch.tensor(
        [math.exp(-(x - center) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)],
        dtype=torch.float32,
    )
    gauss = gauss / gauss.sum()
    kernel_1d = gauss.unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
    # returning the plain tensor is equivalent.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
def ssim_map(img1, img2, window_size=11, channel=None):
    """Return the per-pixel SSIM map between two images.

    Args:
        img1, img2: tensors laid out channel-third-from-last, e.g.
            (batch, channel, H, W), as required by the grouped ``F.conv2d``
            calls below. Intensities are assumed to lie in [0, 1]; the
            stabilising constants C1 = 0.01**2 and C2 = 0.03**2 are chosen for
            that range.
        window_size: side length of the Gaussian averaging window (sigma 1.5).
        channel: number of image channels. ``None`` (the default) infers it from
            ``img1``. The old implicit value of 3 gives the same result for RGB
            input, and passing ``channel`` explicitly still works.

    Returns:
        SSIM value at every pixel. For odd ``window_size`` it has the same shape
        as ``img1``.
    """
    if channel is None:
        channel = img1.shape[-3]
    # Match the input's device and dtype so half/double inputs do not fail in conv2d.
    window = create_window(window_size, channel).to(device=img1.device, dtype=img1.dtype)
    pad = window_size // 2

    def _local_mean(x):
        # Gaussian-weighted local average, computed per channel.
        return F.conv2d(x, window, padding=pad, groups=channel)

    mu1 = _local_mean(img1)
    mu2 = _local_mean(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _local_mean(img1 * img1) - mu1_sq
    sigma2_sq = _local_mean(img2 * img2) - mu2_sq
    sigma12 = _local_mean(img1 * img2) - mu1_mu2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    return ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
@register_method("absgssgf")
class AbsGSSGFWrapper(BaseMethod):
    """Benchmark wrapper around the official AbsGS code, with a structure-guided loss.

    Each training step renders one training view and optimises:
      * an L1 term and a D-SSIM term, both re-weighted per pixel by an edge map
        ``W`` built from multi-scale Sobel differences between render and ground truth;
      * an explicit structural term, the mean of that multi-scale Sobel difference.

    Densification uses AbsGS ``densify_and_prune`` (which takes an extra
    absolute-gradient threshold). Opacity reduction and max-weight pruning run
    when enabled by ``self.opt`` flags.
    """

    # Scene names that the argv resolution override maps to resolution 4 (others get 2).
    _OUTDOOR_360_SCENES = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})
    # (weight, downscale factor) pairs for the multi-scale Sobel structure error.
    _SGF_SCALES = ((1.0, 1.0), (0.5, 0.5), (0.25, 0.25))
    # Upper bound on elements in one (gaussians x cameras) cosine block in
    # compute_physical_metrics, to keep peak GPU memory bounded for large scenes.
    _COS_BLOCK_ELEMS = 50_000_000

    def __init__(self, dataset_config, hyperparams):
        """Parse default AbsGS arguments, apply benchmark overrides and build the scene.

        Args:
            dataset_config: dict with "source_path", "model_path" and optional
                "resolution" (default 1).
            hyperparams: dict; only "track_decoupling" (default False) is read.
        """
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        self.args = self.parser.parse_args([])
        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)
        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)
        self.opt.percent_dense = 0.001
        # lambda_dist is set on the options object but not read anywhere in this class.
        self.opt.lambda_dist = 100.0 if "360" in str(self.args.source_path) else 1000.0
        self.gaussians = GaussianModel(self.dataset.sh_degree)
        self._apply_resolution_override()
        self.scene = Scene(self.dataset, self.gaussians)
        self.gaussians.training_setup(self.opt)
        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.viewpoint_stack = self.scene.getTrainCameras().copy()
        self.last_n_gaussians = len(self.gaussians.get_xyz)

    def _apply_resolution_override(self):
        """Override ``self.dataset.resolution`` from the process command line (INJECTED_RES_FIX).

        Reads ``--scene``, ``--source_path`` (last path component used as the scene
        name) and ``--resolution`` from ``sys.argv``, not from ``dataset_config``.
        A positive integer ``--resolution`` wins. Otherwise a known scene name
        selects 4 for the scenes in ``_OUTDOOR_360_SCENES`` and 2 for any other
        scene. If no scene is found, the resolution from ``dataset_config`` is kept.
        """
        scene, explicit_res = None, None
        for flag, value in zip(sys.argv[:-1], sys.argv[1:]):
            if flag == "--scene":
                scene = value
            elif flag == "--source_path":
                scene = value.rstrip("/").split("/")[-1]
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    pass  # Non-integer value: fall back to the scene-based default.
        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene is not None:
            res = 4 if scene in self._OUTDOOR_360_SCENES else 2
        else:
            res = None
        try:
            if res is not None:
                self.dataset.resolution = res
                print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene, explicit_res, res, __file__))
        except Exception as e:  # Best effort: never abort construction over this override.
            print("[res-fix] FAILED:", e)

    def train_iteration(self, step):
        """Run one optimisation step on a randomly drawn training view.

        Returns:
            ``(metrics, histograms)``. ``metrics`` maps names to Python scalars.
            ``histograms`` maps names to per-Gaussian tensors and is only filled
            when ``step % 1000 == 0``; otherwise it is empty.
        """
        self.gaussians.update_learning_rate(step)
        if step % 1000 == 0:
            self.gaussians.oneupSHdegree()
        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        render_pkg = native_render(viewpoint_cam, self.gaussians, self.pipe, self.background)
        image = render_pkg["render"]
        viewspace_point_tensor = render_pkg["viewspace_points"]
        visibility_filter = render_pkg["visibility_filter"]
        radii = render_pkg["radii"]
        gs_w = render_pkg.get("gs_w")
        if gs_w is None:
            # Renderer reported no per-Gaussian weights: max-weight tracking sees zeros.
            gs_w = torch.zeros_like(radii)
        gt_image = viewpoint_cam.original_image.cuda()

        loss_l1_reweighted, loss_ssim_reweighted, loss_struct, W = self._sgf_losses(image, gt_image)
        # "target" = reweighted L1 + structural term; "parasitic" = reweighted D-SSIM.
        loss_target = loss_l1_reweighted + loss_struct
        loss_parasitic = loss_ssim_reweighted
        loss = loss_target + loss_parasitic

        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio = self._measure_gradient_decoupling(
                loss_target, loss_parasitic, viewspace_point_tensor)
        loss.backward()

        with torch.no_grad():
            self._update_density(step, viewspace_point_tensor, visibility_filter, radii, gs_w)
            if step < self.opt.iterations:
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

        num_gaussians = self.gaussians.get_xyz.shape[0]
        denom = self.gaussians.denom
        denom_mask = denom > 0
        if denom_mask.any():
            grad_reg_mean = float((self.gaussians.xyz_gradient_accum[denom_mask] / denom[denom_mask]).mean())
            grad_abs_mean = float((self.gaussians.xyz_gradient_accum_abs[denom_mask] / denom[denom_mask]).mean())
        else:
            grad_reg_mean = grad_abs_mean = 0.0
        metrics = {
            # "loss_l1" keeps its historical key but holds the full target loss
            # (reweighted L1 + structural term); "loss_ssim" is the reweighted D-SSIM.
            "loss": float(loss), "loss_l1": float(loss_target), "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians), "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim), "parasitic_ratio": float(parasitic_ratio),
            "grad_regular_mean": float(grad_reg_mean), "grad_abs_mean": float(grad_abs_mean),
            "max_weight_mean": float(self.gaussians.max_weight.mean()),
            "sgf_w_mean": float(W.mean()),
            "sgf_loss_struct": float(loss_struct),
            "sgf_loss_l1_reweighted": float(loss_l1_reweighted),
            "sgf_loss_ssim_reweighted": float(loss_ssim_reweighted),
        }
        self.last_n_gaussians = num_gaussians
        histograms = self._collect_histograms() if step % 1000 == 0 else {}
        return metrics, histograms

    def _sgf_losses(self, image, gt_image):
        """Compute the structure-guided loss terms for one rendered view.

        ``image`` and ``gt_image`` are indexed as (channel, H, W) by the code below.

        Returns:
            ``(loss_l1_reweighted, loss_ssim_reweighted, loss_struct, W)``, where
            ``W`` is the detached (H, W) edge weight used to re-weight L1 and D-SSIM.
        """
        height, width = image.shape[1], image.shape[2]
        E_struct = 0.0
        for w_s, s in self._SGF_SCALES:
            if s == 1.0:
                p_s, g_s = image, gt_image
            else:
                p_s = F.interpolate(image.unsqueeze(0), scale_factor=s, mode="bilinear",
                                    align_corners=False).squeeze(0)
                g_s = F.interpolate(gt_image.unsqueeze(0), scale_factor=s, mode="bilinear",
                                    align_corners=False).squeeze(0)
            diff = torch.abs(get_sobel_map(p_s) - get_sobel_map(g_s))
            if s != 1.0:
                # Bring the coarse-scale error back to full resolution before summing.
                diff = F.interpolate(diff, size=(height, width), mode="bilinear", align_corners=False)
            E_struct = E_struct + w_s * diff.squeeze()
        # Normalised structure error, sharpened by ^1.2. W = F_p^0.5 is detached, so
        # the weights steer the photometric terms without receiving gradients.
        F_p = (E_struct / (E_struct.max() + 1e-8)) ** 1.2
        W = F_p.detach() ** 0.5
        l1_map = torch.abs(image - gt_image).mean(dim=0)
        loss_l1_reweighted = ((1.0 - self.opt.lambda_dssim) * l1_map * (1.0 + 0.7 * W)).mean()
        ssim_error_map = 1.0 - ssim_map(image.unsqueeze(0), gt_image.unsqueeze(0)).squeeze(0).mean(dim=0)
        loss_ssim_reweighted = (self.opt.lambda_dssim * ssim_error_map * (1.0 + 1.0 * W)).mean()
        loss_struct = 0.05 * E_struct.mean()
        return loss_l1_reweighted, loss_ssim_reweighted, loss_struct, W

    def _preconditioned_updates(self):
        """Return ``lr * grad / (sqrt(exp_avg_sq) + 1e-8)`` for all param groups as one matrix.

        Each group's first parameter is flattened to (N, -1) and the results are
        concatenated along dim 1. A group without a gradient or without Adam
        second-moment state contributes zeros of its own width. This keeps the
        column layout identical between calls, so two results can be compared
        row by row.
        """
        optimizer = self.gaussians.optimizer
        updates = []
        for group in optimizer.param_groups:
            p = group["params"][0]
            state = optimizer.state.get(p)
            if p.grad is not None and state is not None and "exp_avg_sq" in state:
                u = (group["lr"] / (torch.sqrt(state["exp_avg_sq"]) + 1e-8)) * p.grad
            else:
                u = torch.zeros_like(p)
            updates.append(u.reshape(u.shape[0], -1))
        return torch.cat(updates, dim=1)

    def _measure_gradient_decoupling(self, loss_target, loss_parasitic, viewspace_point_tensor):
        """Compare the preconditioned updates induced by the two loss parts.

        Runs two extra backward passes with the graph retained, so the caller can
        still run the final ``loss.backward()``.

        Returns:
            ``(grad_cos_sim, parasitic_ratio)``: the mean per-row cosine similarity
            between target and parasitic updates, and the mean of
            ``|u_p| / (|u_t| + |u_p|)``. Both are 0.0 when no row has a non-zero
            update under both losses.

        On return, the parameter gradients and the screen-space gradient are cleared.
        """
        optimizer = self.gaussians.optimizer
        optimizer.zero_grad(set_to_none=True)
        loss_target.backward(retain_graph=True)
        u_t = self._preconditioned_updates()
        optimizer.zero_grad(set_to_none=True)
        loss_parasitic.backward(retain_graph=True)
        u_p = self._preconditioned_updates()

        norm_t = torch.norm(u_t, dim=1)
        norm_p = torch.norm(u_p, dim=1)
        valid_mask = (norm_t > 0) & (norm_p > 0)
        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        if valid_mask.any():
            grad_cos_sim = float(F.cosine_similarity(u_t[valid_mask], u_p[valid_mask], dim=1).mean())
            parasitic_ratio = float((norm_p / (norm_t + norm_p + 1e-7)).mean())

        optimizer.zero_grad(set_to_none=True)
        # The probe passes also accumulated into the screen-space gradient, which the
        # optimizer does not own. Without this reset, densification statistics would
        # count this step's gradient twice.
        viewspace_point_tensor.grad = None
        return grad_cos_sim, parasitic_ratio

    def _update_density(self, step, viewspace_point_tensor, visibility_filter, radii, gs_w):
        """Apply densification, the opacity schedules and max-weight pruning for this step.

        Expected to run under ``torch.no_grad()``, after the backward pass.
        """
        gaussians, opt = self.gaussians, self.opt
        gaussians.max_weight[visibility_filter] = torch.max(
            gaussians.max_weight[visibility_filter], gs_w[visibility_filter])
        if step < opt.densify_until_iter:
            gaussians.max_radii2D[visibility_filter] = torch.max(
                gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
            gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
            if step > opt.densify_from_iter and step % opt.densification_interval == 0:
                size_threshold = 20 if step > opt.opacity_reset_interval else None
                gaussians.densify_and_prune(opt.densify_grad_threshold, opt.densify_grad_abs_threshold,
                                            0.005, self.scene.cameras_extent, size_threshold)
            if step % opt.opacity_reduce_interval == 0 and opt.use_reduce:
                gaussians.reduce_opacity()
            if step % opt.opacity_reset_interval == 0 or (
                    self.dataset.white_background and step == opt.densify_from_iter):
                gaussians.reset_opacity()
        if opt.densify_from_iter < step < opt.prune_until_iter and opt.use_prune_weight:
            # The schedule is measured in full passes over the training views. It must
            # use the total view count, not the number of views left in the shuffled
            # stack; that number shrinks to 1/0 and would trigger pruning on
            # consecutive steps with a max_weight gathered from a single view.
            img_num = max(1, len(self.scene.getTrainCameras()))
            if step % img_num == 0 and step % opt.opacity_reset_interval > img_num:
                prune_mask = (gaussians.max_weight < opt.min_weight).squeeze()
                gaussians.prune_points(prune_mask)
                gaussians.max_weight *= 0

    def _collect_histograms(self):
        """Snapshot per-Gaussian opacity, scale, first-two-axes anisotropy and DC colour magnitude."""
        opacity = torch.sigmoid(self.gaussians._opacity).detach()
        scales = torch.exp(self.gaussians._scaling).detach()
        scales_2d = scales[:, :2] if scales.shape[1] >= 2 else scales
        anisotropy = scales_2d.max(dim=-1)[0] / (scales_2d.min(dim=-1)[0] + 1e-7)
        return {
            "opacity": opacity,
            "scaling": scales,
            "anisotropy": anisotropy,
            "sh_dc_mag": self.gaussians._features_dc.detach().norm(dim=-1),
        }

    def render(self, camera):
        """Render ``camera`` without autograd.

        Returns a dict with "image" and "depth"; "depth" is None if the renderer
        does not provide it.
        """
        with torch.no_grad():
            render_pkg = native_render(camera, self.gaussians, self.pipe, self.background)
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save through ``Scene.save(step)``. ``save_dir`` is accepted but not used."""
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Load Gaussians from ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        self.gaussians.load_ply(os.path.join(model_path, 'point_cloud', f'iteration_{iteration}', 'point_cloud.ply'))

    def get_spatial_centers(self):
        """Return the Gaussian centre parameter tensor (not detached)."""
        return self.gaussians._xyz

    def compute_physical_metrics(self, cameras=None):
        """Summarise shape and appearance statistics of the current Gaussians.

        Keys:
          * gamma_median, gamma_90th_percentile: ratio of the larger to the smaller
            of the first two scale axes.
          * scale_mean, alpha_mean.
          * sh_energy_ratio: only when higher-order SH coefficients exist.
          * billboard_bias_ratio: only when ``cameras`` is non-empty.
        """
        metrics = {}
        with torch.no_grad():
            scales = torch.exp(self.gaussians._scaling)
            if scales.dim() > 1 and scales.shape[1] >= 2:
                scales_2d = scales[:, :2]
            else:
                # One scale per Gaussian, stored as (N,) or (N, 1): duplicate it so
                # both "axes" are equal.
                scales_2d = scales.reshape(-1, 1).expand(-1, 2)
            max_S = scales_2d.max(dim=1)[0]
            min_S = scales_2d.min(dim=1)[0]
            gamma = max_S / (min_S + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(torch.sigmoid(self.gaussians._opacity)))
            dc, rest = self.gaussians._features_dc, self.gaussians._features_rest
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))
            if cameras is not None and len(cameras) > 0:
                metrics["billboard_bias_ratio"] = self._billboard_bias_ratio(cameras)
        return metrics

    def _billboard_bias_ratio(self, cameras):
        """Return the fraction of Gaussians whose local z-axis has |cos| > 0.9 with some camera axis.

        The camera axis is column 2 of ``world_view_transform``. That is the camera's
        z-axis in world space if the matrix is stored transposed, as in 3DGS —
        NOTE(review): confirm for this renderer. The Gaussian axis is the third
        column of the rotation matrix built from the normalised (w, x, y, z)
        quaternion. The cosine matrix is built in row blocks so peak memory stays
        bounded by ``_COS_BLOCK_ELEMS``.
        """
        view_dirs = torch.stack([c.world_view_transform[:3, 2] for c in cameras])
        view_dirs = F.normalize(view_dirs.to(device="cuda", dtype=torch.float32), dim=1)
        rots = F.normalize(self.gaussians._rotation, dim=1)
        w, x, y, z = rots.unbind(dim=-1)
        normals = F.normalize(
            torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1), dim=1)
        rows = max(1, self._COS_BLOCK_ELEMS // view_dirs.shape[0])
        max_cos = normals.new_empty(normals.shape[0])
        for start in range(0, normals.shape[0], rows):
            block = torch.matmul(normals[start:start + rows], view_dirs.T)
            max_cos[start:start + rows] = block.abs().amax(dim=1)
        return float((max_cos > 0.90).float().mean())

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic, opacity-weighted Gaussian density at ``query_points``.

        Each Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)``, with
        sigma the larger of its first two scale axes (or its only scale).
        ``query_points`` is presumably (V, 3) world coordinates — TODO confirm. It is
        moved to the Gaussians' device and dtype. ``cameras`` is unused.

        Queries are processed in chunks so that each (chunk, N) distance matrix has
        roughly 30M entries. Returns a (V,) tensor on the Gaussians' device.
        """
        with torch.no_grad():
            xyz = self.gaussians._xyz
            query_points = query_points.to(device=xyz.device, dtype=xyz.dtype)
            V = query_points.shape[0]
            densities = torch.zeros(V, device=xyz.device, dtype=xyz.dtype)
            opacities = torch.sigmoid(self.gaussians._opacity).squeeze()
            scales = torch.exp(self.gaussians._scaling)
            if scales.shape[1] >= 2:
                sigma_sq = scales[:, :2].max(dim=1)[0].pow(2)
            else:
                sigma_sq = scales.squeeze().pow(2)
            N_gaussians = xyz.shape[0]
            chunk_size = max(1, 30_000_000 // (N_gaussians + 1))
            for i in range(0, V, chunk_size):
                end = min(i + chunk_size, V)
                dist_sq = torch.cdist(query_points[i:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[i:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities