# SplatAtlas/methods/wrapper_bags.py
# Part of the SplatAtlas benchmark pipeline code (upload revision 23e73f9).
import os
import sys
import math
import random
import torch
import numpy as np
import torch.nn.functional as F
from argparse import ArgumentParser
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.insert(0, '/root/autodl-tmp/BAGS_offy')
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render as native_render
from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
def tv_loss(grids):
    """Return the squared total-variation penalty of a batch of 2-D grids.

    Squared differences between neighbouring entries along dim 2 and along
    dim 3 are summed, each sum is divided by the number of such differences
    in a single batch item, and the result is scaled by ``2 / grids.shape[0]``.

    Args:
        grids: Tensor indexed as ``[batch, channel, row, column]``.

    Returns:
        A scalar tensor holding the penalty.
    """
    batch = grids.shape[0]

    # Neighbour differences along the two trailing axes.
    row_diff = grids[:, :, 1:, :] - grids[:, :, :-1, :]
    col_diff = grids[:, :, :, 1:] - grids[:, :, :, :-1]

    # Number of differences per batch item (the batch axis is excluded on purpose;
    # the batch size is divided out once at the end).
    row_count = row_diff.shape[1] * row_diff.shape[2] * row_diff.shape[3]
    col_count = col_diff.shape[1] * col_diff.shape[2] * col_diff.shape[3]

    row_energy = torch.pow(row_diff, 2).sum()
    col_energy = torch.pow(col_diff, 2).sum()
    return 2 * (row_energy / row_count + col_energy / col_count) / batch
def get_emb(sin_inp):
    """Interleave the sine and cosine of ``sin_inp`` along its last axis.

    Args:
        sin_inp: Tensor of phase values with at least one dimension.

    Returns:
        Tensor whose last axis is twice as long as that of ``sin_inp`` and is
        laid out as ``[sin(p0), cos(p0), sin(p1), cos(p1), ...]``.
    """
    sin_cos_pairs = torch.stack([torch.sin(sin_inp), torch.cos(sin_inp)], dim=-1)
    # Merge the trailing (phase, sin/cos) axes into a single interleaved axis.
    return sin_cos_pairs.flatten(start_dim=-2, end_dim=-1)
def get_2d_emb(batch_size, x, y, out_ch, device):
    """Build a 2-D sinusoidal positional encoding directly on ``device``.

    Positions along each axis are mapped to angles in ``[0, 2*pi)`` and encoded
    with interleaved sine/cosine features at geometrically spaced frequencies.
    The first half of the channel axis depends only on the first spatial index,
    the second half only on the second spatial index.

    Args:
        batch_size: Number of identical copies stacked along the leading axis.
        x: Size of the first spatial axis.
        y: Size of the second spatial axis.
        out_ch: Requested channel count; the returned channel count is
            ``4 * ceil(out_ch / 4)``.
        device: Device on which every intermediate and the result are created.

    Returns:
        Tensor of shape ``(batch_size, x, y, 4 * ceil(out_ch / 4))``.
    """
    # Channels contributed by each axis (an even number, so sin/cos pair up).
    out_ch = int(np.ceil(out_ch / 4) * 2)

    # All tensors are created on `device`. Previously the frequencies lived on
    # the CPU and `.type(inv_freq.type())` cast the positions back to a CPU
    # FloatTensor, so the encoding was computed on the host and copied over on
    # every call.
    exponents = torch.arange(0, out_ch, 2, dtype=torch.float32, device=device) / out_ch
    inv_freq = 1.0 / (10000 ** exponents)
    pos_x = torch.arange(x, dtype=torch.float32, device=device) * 2 * np.pi / x
    pos_y = torch.arange(y, dtype=torch.float32, device=device) * 2 * np.pi / y

    sin_inp_x = torch.einsum("i,j->ij", pos_x, inv_freq)
    sin_inp_y = torch.einsum("i,j->ij", pos_y, inv_freq)
    emb_x = get_emb(sin_inp_x).unsqueeze(1)  # (x, 1, out_ch): broadcast over the second axis
    emb_y = get_emb(sin_inp_y)               # (y, out_ch): broadcast over the first axis

    emb = torch.zeros((x, y, out_ch * 2), device=device)
    emb[:, :, :out_ch] = emb_x
    emb[:, :, out_ch:2 * out_ch] = emb_y
    return emb[None, :, :, :].repeat(batch_size, 1, 1, 1)
@register_method("bags")
class BAGSWrapper(BaseMethod):
    """Expose the BAGS training loop through the benchmark's ``BaseMethod`` API.

    Training is coarse-to-fine: cameras at scale 4 are used first, scale 2
    from ``upsample_iter[0]`` and scale 1 from ``upsample_iter[1]``.  After a
    short warm-up the rendered image is filtered by a learned per-pixel kernel
    and blended with a learned mask before being compared with the ground
    truth image.
    """

    # Scene names that the command-line resolution heuristic trains at 1/4
    # resolution; every other named scene falls back to 1/2.
    _OUTDOOR_360 = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})

    def __init__(self, dataset_config, hyperparams):
        """Load the scene, build the Gaussian model and prepare training state.

        Args:
            dataset_config: Mapping providing ``"source_path"`` and
                ``"model_path"`` and, optionally, ``"resolution"`` (default 1).
            hyperparams: Mapping; ``"track_decoupling"`` (default ``False``)
                enables the periodic gradient-conflict diagnostics in
                ``train_iteration``.
        """
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        # Parse an empty argv so that every option starts from its declared default.
        self.args = self.parser.parse_args([])
        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)
        self.opt.iterations = 60000
        self.opt.ms_steps = 6000
        self.upsample_iter = [3000, 6000]

        self._apply_resolution_override(sys.argv)

        self.scene = Scene(self.dataset)
        # Use the resolution the dataset was actually configured with: the
        # override above changes self.dataset.resolution but not self.args.
        # NOTE(review): presumably GaussianModel expects the per-image size to
        # match the loaded images -- confirm against the BAGS model code.
        effective_res = getattr(self.dataset, "resolution", self.args.resolution)
        inp_shape = [
            len(self.scene.getTrainCameras()),
            int(np.round(self.scene.orig_h / effective_res)),
            int(np.round(self.scene.orig_w / effective_res)),
        ]
        self.gaussians = GaussianModel(
            self.dataset.sh_degree, inp_shape,
            ks1=self.dataset.kernel_size1, ks2=self.dataset.kernel_size2,
            ks3=self.dataset.kernel_size3, ks_ss=self.dataset.kernel_size_ss,
            not_use_rgbd=self.opt.not_use_rgbd, not_use_pe=self.opt.not_use_pe,
        )
        self.scene.load_gaussian(self.gaussians)
        self.gaussians.training_setup(self.opt)

        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        self.trainCameras_scale4 = self.scene.getTrainCameras(scale=4.0).copy()
        self.trainCameras_scale2 = self.scene.getTrainCameras(scale=2.0).copy()
        self.trainCameras_scale1 = self.scene.getTrainCameras(scale=1.0).copy()
        # Training starts on the coarsest cameras.
        self.viewpoint_stack = self.trainCameras_scale4.copy()
        self.gaussians.compute_3D_filter(cameras=self.trainCameras_scale4)

        self.unfold1 = self._make_unfold(self.dataset.kernel_size1)
        self.unfold2 = self._make_unfold(self.dataset.kernel_size2)
        self.unfold3 = self._make_unfold(self.dataset.kernel_size3)
        if self.dataset.kernel_size_ss != self.dataset.kernel_size3:
            # A differently sized single-scale kernel needs its own network.
            self.opt.use_another_mlp = True
            self.unfold_ss = self._make_unfold(self.dataset.kernel_size_ss)
        else:
            self.unfold_ss = self.unfold3
        self.last_n_gaussians = self.gaussians.get_xyz.shape[0]

    @staticmethod
    def _make_unfold(kernel_size):
        """Create a CUDA ``Unfold`` extracting ``kernel_size`` x ``kernel_size`` patches."""
        return torch.nn.Unfold(
            kernel_size=(kernel_size, kernel_size), padding=kernel_size // 2
        ).cuda()

    def _apply_resolution_override(self, argv):
        """Override ``self.dataset.resolution`` from launcher arguments (best effort).

        A positive ``--resolution`` wins.  Otherwise, when a scene name is
        known from ``--scene`` or the last path component of ``--source_path``,
        the resolution is 4 for names in ``_OUTDOOR_360`` and 2 for any other
        name.  With neither, the dataset resolution is left untouched.

        Args:
            argv: Command-line tokens to scan as ``flag value`` pairs.
        """
        scene_name, explicit_res = None, None
        for flag, value in zip(argv, argv[1:]):
            if flag == "--scene":
                scene_name = value
            elif flag == "--source_path":
                scene_name = value.rstrip("/").split("/")[-1]
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    # Not an integer: ignore it and fall back to the heuristic.
                    pass

        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene_name is not None:
            res = 4 if scene_name in self._OUTDOOR_360 else 2
        else:
            res = None

        # The override is deliberately best effort: a failure is reported but
        # must not abort construction.
        try:
            if res is not None:
                self.dataset.resolution = res
            print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene_name, explicit_res, res, __file__))
        except Exception as err:
            print("[res-fix] FAILED:", err)

    def _cameras_for_iter(self, ori_iter):
        """Return the training cameras of the resolution stage active at ``ori_iter``."""
        first_upsample, second_upsample = self.upsample_iter
        if ori_iter >= second_upsample:
            return self.trainCameras_scale1
        if ori_iter >= first_upsample:
            return self.trainCameras_scale2
        return self.trainCameras_scale4

    def _compute_losses(self, image, depth, gt_image, ori_iter, adj_iter):
        """Return ``(loss_target, loss_parasitic)`` for one rendered view.

        ``loss_target`` is the weighted L1 term.  ``loss_parasitic`` is the
        weighted D-SSIM term plus, once the blur model is active
        (``adj_iter > 250``), the optional depth-TV, mask and RGB-TV
        regularisers.
        """
        lambda_dssim = self.opt.lambda_dssim
        if adj_iter <= 250:
            # Warm-up: compare the raw render with the ground truth.
            loss_target = (1.0 - lambda_dssim) * l1_loss(image, gt_image)
            loss_parasitic = lambda_dssim * (1.0 - ssim(image, gt_image))
            return loss_target, loss_parasitic

        shuffle_rgb = image.unsqueeze(0)
        # Depth is shifted and scaled into [0, 1) before being fed to the network.
        shuffle_depth = depth.unsqueeze(0) - depth.min()
        shuffle_depth = shuffle_depth / (shuffle_depth.max() + 1e-7)
        height, width = shuffle_rgb.shape[-2], shuffle_rgb.shape[-1]
        pos_enc = get_2d_emb(1, height, width, 16, image.device)
        # The kernel network sees the render as a constant input (no gradient).
        mlp_input = torch.cat([shuffle_rgb, shuffle_depth], 1).detach()

        # Pick network, patch extractor and kernel size for the current stage.
        first_upsample, second_upsample = self.upsample_iter
        if ori_iter < first_upsample:
            mlp, mlp_step = self.gaussians.mlp_rgb_ms, ori_iter
            unfold, kernel_size = self.unfold1, self.dataset.kernel_size1
        elif ori_iter < second_upsample:
            mlp, mlp_step = self.gaussians.mlp_rgb_ms, ori_iter
            unfold, kernel_size = self.unfold2, self.dataset.kernel_size2
        elif ori_iter > self.opt.ms_steps and self.opt.use_another_mlp:
            mlp, mlp_step = self.gaussians.mlp_rgb_ss, adj_iter
            unfold, kernel_size = self.unfold_ss, self.dataset.kernel_size_ss
        else:
            mlp, mlp_step = self.gaussians.mlp_rgb_ms, ori_iter
            unfold, kernel_size = self.unfold3, self.dataset.kernel_size3

        kernel_weights, mask = mlp(0, pos_enc, mlp_input, mlp_step)
        patches = unfold(shuffle_rgb).view(1, 3, kernel_size ** 2, height, width)

        # Weighted sum over each pixel's neighbourhood, then blend with the
        # sharp render according to the predicted mask.
        rgb = torch.sum(patches * kernel_weights.unsqueeze(1), 2)[0]
        mask = mask[0]
        blur_image = mask * rgb + (1 - mask) * image

        zero = torch.tensor(0.0, device=image.device)
        depthloss = self.opt.depth_loss_alpha * tv_loss(shuffle_depth) if self.opt.use_depth_loss else zero
        maskloss = self.opt.mask_loss_alpha * mask.mean() if self.opt.use_mask_loss else zero
        tvloss = self.opt.rgbtv_loss_alpha * tv_loss(shuffle_rgb) if self.opt.use_rgbtv_loss else zero

        loss_target = (1.0 - lambda_dssim) * l1_loss(blur_image, gt_image)
        loss_parasitic = lambda_dssim * (1.0 - ssim(blur_image, gt_image)) + tvloss + maskloss + depthloss
        return loss_target, loss_parasitic

    def _param_groups(self):
        """Return the Gaussian parameters grouped for the per-group diagnostics."""
        g = self.gaussians
        return {
            "spatial": [g._xyz],
            "geometry": [g._scaling, g._rotation],
            "opacity": [g._opacity],
            "appearance": [g._features_dc, g._features_rest],
        }

    def _backward_and_snapshot(self, loss):
        """Backpropagate ``loss`` alone and copy the resulting gradients.

        Returns:
            ``(xyz_grad, group_grads)`` where ``xyz_grad`` has the shape of the
            position parameter and ``group_grads`` maps each parameter group to
            one flat gradient vector.  Parameters that received no gradient
            contribute zeros, so vectors from different losses stay aligned.
        """
        self.gaussians.optimizer.zero_grad(set_to_none=True)
        loss.backward(retain_graph=True)

        def grad_or_zeros(param):
            return param.grad.clone() if param.grad is not None else torch.zeros_like(param)

        xyz_grad = grad_or_zeros(self.gaussians._xyz)
        group_grads = {
            name: torch.cat([grad_or_zeros(p).reshape(-1) for p in params])
            for name, params in self._param_groups().items()
        }
        return xyz_grad, group_grads

    def _measure_gradient_conflict(self, loss_target, loss_parasitic):
        """Compare the gradients of the two loss parts.

        Returns:
            ``(grad_cos_sim, parasitic_ratio, sti_metrics)``.  The first two are
            computed on the position gradients of Gaussians that receive a
            non-zero gradient from both losses (both are 0.0 when there are
            none).  ``sti_metrics`` maps ``"sti_<group>"`` to
            ``r * max(0, -cos)`` with ``r = |g_par| / (|g_tgt| + |g_par|)``.
        """
        # One backward pass per loss is enough: every gradient needed below is
        # snapshotted from the same pass.
        xyz_target, groups_target = self._backward_and_snapshot(loss_target)
        xyz_parasitic, groups_parasitic = self._backward_and_snapshot(loss_parasitic)

        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        target_norm = torch.norm(xyz_target, dim=1)
        parasitic_norm = torch.norm(xyz_parasitic, dim=1)
        both_active = (target_norm > 0) & (parasitic_norm > 0)
        if both_active.any():
            grad_cos_sim = float(
                F.cosine_similarity(xyz_target[both_active], xyz_parasitic[both_active], dim=1).mean()
            )
            parasitic_ratio = float(parasitic_norm.mean() / (target_norm.mean() + 1e-7))

        sti_metrics = {}
        for group_name, g_target in groups_target.items():
            g_parasitic = groups_parasitic[group_name]
            if g_target.norm() > 0 and g_parasitic.norm() > 0:
                cos = float(F.cosine_similarity(g_target.unsqueeze(0), g_parasitic.unsqueeze(0)))
                share = float(g_parasitic.norm() / (g_target.norm() + g_parasitic.norm() + 1e-7))
                # Only opposing directions (negative cosine) count as interference.
                sti = share * max(0.0, -cos)
            else:
                sti = 0.0
            sti_metrics[f"sti_{group_name}"] = sti
        return grad_cos_sim, parasitic_ratio, sti_metrics

    def train_iteration(self, step):
        """Run one optimisation step and return its diagnostics.

        Args:
            step: Global iteration index.  Iterations after ``opt.ms_steps``
                use the shifted index ``step - opt.ms_steps`` for the learning
                rate, SH-degree and densification schedules.

        Returns:
            ``(metrics, histograms)``: a dict of Python scalars and a dict of
            tensors; the latter is only filled when ``step % 1000 == 0``.
        """
        ori_iter = step
        adj_iter = step

        # Entering a finer stage: restart the view queue and refresh the 3D filter.
        if ori_iter in self.upsample_iter:
            stage_cameras = self._cameras_for_iter(ori_iter)
            self.viewpoint_stack = stage_cameras.copy()
            self.gaussians.compute_3D_filter(cameras=stage_cameras)
        if ori_iter > self.opt.ms_steps:
            adj_iter = step - self.opt.ms_steps
        self.gaussians.update_learning_rate(adj_iter)
        if adj_iter % 1000 == 0:
            self.gaussians.oneupSHdegree()

        # Views are drawn without replacement; refill once the queue is empty.
        if not self.viewpoint_stack:
            self.viewpoint_stack = self._cameras_for_iter(ori_iter).copy()
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        subpixel_offset = None
        if self.dataset.ray_jitter:
            jitter_shape = (int(viewpoint_cam.image_height), int(viewpoint_cam.image_width), 2)
            subpixel_offset = torch.rand(jitter_shape, dtype=torch.float32, device="cuda") - 0.5

        render_pkg = native_render(
            viewpoint_cam, self.gaussians, self.pipe, self.background,
            kernel_size=self.dataset.kernel_size, subpixel_offset=subpixel_offset,
        )
        image = render_pkg["render"]
        depth = render_pkg["depth"]
        viewspace_point_tensor = render_pkg["viewspace_points"]
        visibility_filter = render_pkg["visibility_filter"]
        radii = render_pkg["radii"]
        gt_image = viewpoint_cam.original_image.cuda()

        loss_target, loss_parasitic = self._compute_losses(image, depth, gt_image, ori_iter, adj_iter)
        loss = loss_target + loss_parasitic

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        metrics = {}
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio, sti_metrics = self._measure_gradient_conflict(
                loss_target, loss_parasitic
            )
            metrics.update(sti_metrics)
            self.gaussians.optimizer.zero_grad(set_to_none=True)
            # The screen-space tensor is not an optimizer parameter, so
            # zero_grad() leaves its retained gradient in place and the
            # diagnostic passes would be summed into it.
            # NOTE(review): presumably add_densification_stats reads this
            # .grad (as in upstream 3DGS) -- confirm.
            if viewspace_point_tensor.grad is not None:
                viewspace_point_tensor.grad = None
        loss.backward()

        with torch.no_grad():
            if adj_iter < self.opt.densify_until_iter:
                self.gaussians.max_radii2D[visibility_filter] = torch.max(
                    self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
                )
                self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if adj_iter > self.opt.densify_from_iter and adj_iter % self.opt.densification_interval == 0:
                    size_threshold = 20 if adj_iter > self.opt.opacity_reset_interval else None
                    if ori_iter <= self.opt.ms_steps:
                        # Multi-scale phase uses its own thresholds.
                        dgt = self.opt.init_dgt
                        min_op = self.opt.init_opacity if self.opt.init_opacity >= 0 else self.opt.min_opacity
                    else:
                        dgt = self.opt.densify_grad_threshold
                        min_op = self.opt.min_opacity
                    self.gaussians.densify_and_prune(dgt, min_op, self.scene.cameras_extent, size_threshold)
                    # The set of Gaussians changed, so the filter is recomputed.
                    self.gaussians.compute_3D_filter(cameras=self._cameras_for_iter(ori_iter))

                # NOTE(review): kept inside the densification window, as in the
                # upstream training loop -- confirm against the original layout.
                if adj_iter % self.opt.opacity_reset_interval == 0 or (
                    self.dataset.white_background and adj_iter == self.opt.densify_from_iter
                ):
                    self.gaussians.reset_opacity()

            if adj_iter % 100 == 0 and adj_iter > self.opt.densify_until_iter:
                # Skip the refresh during the last 100 adjusted iterations.
                if adj_iter < (self.opt.iterations - self.opt.ms_steps) - 100:
                    self.gaussians.compute_3D_filter(cameras=self._cameras_for_iter(ori_iter))

            if ori_iter < self.opt.iterations:
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

        num_gaussians = self.gaussians.get_xyz.shape[0]
        metrics.update({
            "loss": float(loss),
            "loss_l1": float(loss_target),
            "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians),
            "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim),
            "parasitic_ratio": float(parasitic_ratio),
        })
        self.last_n_gaussians = num_gaussians

        histograms = {}
        if step % 1000 == 0:
            histograms["opacity"] = torch.sigmoid(self.gaussians._opacity).clone().detach()
            scales = torch.exp(self.gaussians._scaling).clone().detach()
            histograms["scaling"] = scales
            scales_2d = scales[:, :2] if scales.shape[1] >= 2 else scales
            # Ratio of the larger to the smaller of the first two scale axes.
            histograms["anisotropy"] = scales_2d.max(dim=-1)[0] / (scales_2d.min(dim=-1)[0] + 1e-7)
            histograms["sh_dc_mag"] = self.gaussians._features_dc.detach().norm(dim=-1)
        return metrics, histograms

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns:
            Dict with ``"image"`` and ``"depth"`` (``None`` when the renderer
            returns no depth).
        """
        with torch.no_grad():
            # Use the same kernel size as the training renders.
            render_pkg = native_render(
                camera, self.gaussians, self.pipe, self.background,
                kernel_size=self.dataset.kernel_size,
            )
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save the scene for iteration ``step``.

        ``save_dir`` is accepted for interface compatibility but not used; the
        destination is decided by ``Scene.save``.
        """
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Load Gaussians from ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        ply_path = os.path.join(model_path, 'point_cloud', f'iteration_{iteration}', 'point_cloud.ply')
        self.gaussians.load_ply(ply_path)

    def get_spatial_centers(self):
        """Return the raw position parameter of the Gaussians."""
        return self.gaussians._xyz

    def compute_physical_metrics(self, cameras=None):
        """Summarise scale, opacity, SH and orientation statistics.

        Args:
            cameras: Optional sequence of cameras exposing
                ``world_view_transform``; when given and non-empty, the
                ``"billboard_bias_ratio"`` entry is added.

        Returns:
            Dict of Python floats.
        """
        metrics = {}
        with torch.no_grad():
            scales = torch.exp(self.gaussians._scaling)
            if scales.dim() > 1 and scales.shape[1] >= 2:
                scales_2d = scales[:, :2]
            else:
                scales_2d = scales.unsqueeze(-1).expand(-1, 2)
            max_scale, _ = torch.max(scales_2d, dim=1)
            min_scale, _ = torch.min(scales_2d, dim=1)
            gamma = max_scale / (min_scale + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(torch.sigmoid(self.gaussians._opacity)))

            dc, rest = self.gaussians._features_dc, self.gaussians._features_rest
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))

            if cameras is not None and len(cameras) > 0:
                rots = F.normalize(self.gaussians._rotation.clone(), dim=1)
                view_dirs = torch.tensor(
                    [c.world_view_transform[:3, 2].tolist() for c in cameras],
                    dtype=torch.float32, device=rots.device,
                )
                view_dirs = F.normalize(view_dirs, dim=1)
                # Third column of the rotation matrix of the (w, x, y, z) quaternion.
                w, x, y, z = rots.unbind(dim=-1)
                normals = F.normalize(
                    torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1),
                    dim=1,
                )
                # Best alignment (up to sign) of each normal with any camera axis.
                max_cos, _ = torch.max(torch.abs(torch.matmul(normals, view_dirs.T)), dim=1)
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an isotropic, opacity-weighted density at ``query_points``.

        Each Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)`` where
        ``sigma`` is the larger of its first two scales.

        Args:
            query_points: Tensor of shape ``(V, 3)`` on the Gaussians' device.
            cameras: Unused.

        Returns:
            Tensor of shape ``(V,)`` with the summed density per query point.
        """
        with torch.no_grad():
            xyz = self.gaussians._xyz
            opacities = torch.sigmoid(self.gaussians._opacity).reshape(-1)
            scales = torch.exp(self.gaussians._scaling)
            if scales.shape[1] >= 2:
                sigma_sq = scales[:, :2].max(dim=1)[0].pow(2)
            else:
                sigma_sq = scales.squeeze().pow(2)

            num_queries = query_points.shape[0]
            densities = torch.zeros(num_queries, device=xyz.device)
            # Bound the size of the (chunk x N) distance matrix.
            chunk_size = max(1, 30_000_000 // (xyz.shape[0] + 1))
            for start in range(0, num_queries, chunk_size):
                end = min(start + chunk_size, num_queries)
                dist_sq = torch.cdist(query_points[start:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[start:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities