SplatAtlas / methods /wrapper_gaussianfocus.py
KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
18 kB
import os
import sys
import random
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from argparse import ArgumentParser
from core.registry import register_method
from core.base_method import BaseMethod
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../GaussianFocus_offy')))
from utils.loss_utils import l1_loss, ssim
from gaussian_renderer import render as native_render
from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, OptimizationParams
class PatchAttention(nn.Module):
    """Tile-wise attention between a rendered image and its ground truth.

    Both images are cut into ``block_size`` x ``block_size`` tiles. Inside each
    tile a single-channel query is projected from the rendered image, while the
    single-channel key and the ``channel``-wide value are projected from the
    ground-truth image. The attended tiles are stitched back together so the
    output has the same spatial size as ``image``.
    """

    def __init__(self, channel):
        super().__init__()
        # 1x1 convolutions act as per-pixel linear projections.
        self.query_conv = nn.Conv2d(channel, 1, kernel_size=1)
        self.key_conv = nn.Conv2d(channel, 1, kernel_size=1)
        self.value_conv = nn.Conv2d(channel, channel, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, image, gt_image, block_size=64):
        """Return the attended image.

        Args:
            image: rendered image, indexed as (batch, channel, height, width).
            gt_image: ground-truth image, sliced with the same tile windows.
            block_size: side length of the square tiles.

        Returns:
            Tensor with the same height and width as ``image``.
        """
        batch_size, channels, height, width = image.size()
        pixels_per_tile = block_size * block_size

        stitched_rows = []
        for top in range(0, height, block_size):
            row_tiles = []
            for left in range(0, width, block_size):
                tile = image[:, :, top:top + block_size, left:left + block_size]
                gt_tile = gt_image[:, :, top:top + block_size, left:left + block_size]

                # Border tiles may be short; remember the true extent so the
                # zero padding added below can be cropped off again.
                tile_h, tile_w = tile.size(2), tile.size(3)
                if tile_h < block_size or tile_w < block_size:
                    padding = (0, block_size - tile_w, 0, block_size - tile_h)
                    tile = F.pad(tile, padding)
                    gt_tile = F.pad(gt_tile, padding)

                query = self.query_conv(tile).view(batch_size, -1, pixels_per_tile)
                query = query.permute(0, 2, 1)
                key = self.key_conv(gt_tile).view(batch_size, -1, pixels_per_tile)
                value = self.value_conv(gt_tile).view(batch_size, -1, pixels_per_tile)

                # (pixels x 1) @ (1 x pixels): one weight per pixel pair,
                # normalised over the key axis.
                attention = self.softmax(torch.bmm(query, key))
                attended = torch.bmm(value, attention.permute(0, 2, 1))
                attended = attended.view(batch_size, channels, block_size, block_size)

                row_tiles.append(attended[:, :, :tile_h, :tile_w])
            stitched_rows.append(torch.cat(row_tiles, dim=3))

        return torch.cat(stitched_rows, dim=2)
def edge_loss(input_img, target_img):
    """Mean L1 distance between the Sobel edge responses of two images.

    Each channel is filtered independently (depth-wise convolution) with the
    horizontal Sobel kernel and its transpose, using one pixel of zero
    padding so the response keeps the input's spatial size.

    Args:
        input_img: predicted image, indexed as (batch, channel, height, width).
        target_img: reference image with the same layout.

    Returns:
        Scalar tensor: the average of the x- and y-direction L1 errors.
    """
    channels = input_img.size(1)
    sobel = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
    kernel = sobel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    def sobel_responses(img):
        # Match the kernel to the image so float64 / half inputs do not hit a
        # dtype mismatch inside conv2d.
        k = kernel.to(device=img.device, dtype=img.dtype)
        groups = img.size(1)
        response_x = F.conv2d(img, k, padding=1, groups=groups)
        response_y = F.conv2d(img, k.transpose(2, 3), padding=1, groups=groups)
        return response_x, response_y

    input_x, input_y = sobel_responses(input_img)
    target_x, target_y = sobel_responses(target_img)
    loss_x = F.l1_loss(input_x, target_x)
    loss_y = F.l1_loss(input_y, target_y)
    return (loss_x + loss_y) / 2
def frequency_loss(input_img, target_img):
    """Mean L1 distance between first-order finite differences of two images.

    Differences are taken along the last two axes (height and width), so the
    function accepts both (channel, height, width) and
    (batch, channel, height, width) tensors. Indexing with a fixed number of
    leading ``:`` would difference along the channel axis for batched input.

    Args:
        input_img: predicted image; height and width are the last two axes.
        target_img: reference image with the same shape.

    Returns:
        Scalar tensor: the average of the width-direction and
        height-direction L1 errors.
    """
    # Neighbouring-pixel differences along the width axis (last axis).
    grad_x_input = input_img[..., :, 1:] - input_img[..., :, :-1]
    grad_x_target = target_img[..., :, 1:] - target_img[..., :, :-1]
    # Neighbouring-pixel differences along the height axis (second-to-last).
    grad_y_input = input_img[..., 1:, :] - input_img[..., :-1, :]
    grad_y_target = target_img[..., 1:, :] - target_img[..., :-1, :]

    loss_x = F.l1_loss(grad_x_input, grad_x_target)
    loss_y = F.l1_loss(grad_y_input, grad_y_target)
    return (loss_x + loss_y) / 2
class CombinedLoss(nn.Module):
    """Weighted sum of the edge loss and the frequency loss."""

    def __init__(self):
        super().__init__()

    def forward(self, finals, gt_image, beta=1, eta=1):
        """Return ``beta * edge_loss + eta * frequency_loss``.

        Args:
            finals: predicted image; a 3-D tensor gets a leading batch axis.
            gt_image: reference image; a 3-D tensor gets a leading batch axis.
            beta: weight of the edge term.
            eta: weight of the frequency term.
        """
        # Both terms are evaluated on batched tensors.
        if finals.dim() == 3:
            finals = finals[None]
        if gt_image.dim() == 3:
            gt_image = gt_image[None]

        edge_term = edge_loss(finals, gt_image)
        frequency_term = frequency_loss(finals, gt_image)
        return beta * edge_term + eta * frequency_term
@register_method("gaussianfocus")
class GaussianFocusWrapper(BaseMethod):
    """Benchmark adapter around the GaussianFocus training code.

    The constructor builds the native argument objects, scene, Gaussian model
    and the auxiliary patch-attention module; ``train_iteration`` performs one
    optimisation step and returns scalar metrics plus optional histograms.
    The remaining methods expose rendering, checkpointing and analysis hooks.
    """

    # Scenes that default to resolution factor 4 when no explicit
    # --resolution is given on the command line; every other scene gets 2.
    _OUTDOOR_360 = frozenset({"bicycle", "flowers", "garden", "stump", "treehill"})
    # Above this many Gaussians the group-wise STI statistics are skipped to
    # avoid running out of GPU memory.
    _STI_MAX_GAUSSIANS = 2_000_000

    def __init__(self, dataset_config, hyperparams):
        """Set up the scene, model, optimisers and loss modules.

        Args:
            dataset_config: mapping with ``source_path`` and ``model_path``
                and an optional ``resolution`` (default 1).
            hyperparams: mapping; ``track_decoupling`` (default False) turns
                on the periodic gradient-decoupling diagnostics.
        """
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        # Parse an empty argv so every native option starts at its default.
        self.args = self.parser.parse_args([])
        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)
        self.dataset.kernel_size = 0.1
        self.dataset.ray_jitter = False
        self.dataset.resample_gt_image = False
        self.dataset.sample_more_highres = False
        self.gaussians = GaussianModel(self.dataset.sh_degree)

        # INJECTED_RES_FIX begin
        scene_name, explicit_res, res = self._resolution_from_argv(sys.argv)
        try:
            # NOTE(review): the log line is assumed to be emitted only when an
            # override is applied — confirm against the original layout.
            if res is not None:
                self.dataset.resolution = res
                print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene_name, explicit_res, res, __file__))
        except Exception as err:
            # Best effort: a failed override must not abort construction.
            print("[res-fix] FAILED:", err)
        # INJECTED_RES_FIX end

        self.scene = Scene(self.dataset, self.gaussians)
        self.gaussians.training_setup(self.opt)
        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
        self.viewpoint_stack = self.scene.getTrainCameras().copy()
        self.gaussians.compute_3D_filter(cameras=self.scene.getTrainCameras())
        self.last_n_gaussians = len(self.gaussians.get_xyz)
        self.patch_attention = PatchAttention(channel=3).to("cuda")
        self.attn_optimizer = torch.optim.Adam(self.patch_attention.parameters(), lr=1e-3)
        self.combined_loss_fn = CombinedLoss()

    @classmethod
    def _resolution_from_argv(cls, argv):
        """Derive a resolution override from command-line style arguments.

        Scans consecutive ``flag value`` pairs for ``--scene``,
        ``--source_path`` (its last path component becomes the scene name)
        and ``--resolution``. Later occurrences win.

        Returns:
            ``(scene_name, explicit_res, res)`` where ``res`` is the positive
            explicit resolution if one was given, else 4 or 2 depending on
            the scene name, else ``None`` when nothing could be determined.
        """
        scene_name = None
        explicit_res = None
        for flag, value in zip(argv, argv[1:]):
            if flag == "--scene":
                scene_name = value
            elif flag == "--source_path":
                scene_name = value.rstrip("/").split("/")[-1]
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    # Non-numeric value: keep whatever was parsed before.
                    pass

        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene_name is not None:
            res = 4 if scene_name in cls._OUTDOOR_360 else 2
        else:
            res = None
        return scene_name, explicit_res, res

    def _probe_gradients(self, loss, param_groups):
        """Backpropagate ``loss`` alone and snapshot the resulting gradients.

        The Gaussian optimiser's gradients are cleared first. Parameters that
        receive no gradient are represented by zeros so that snapshots taken
        for different losses always line up element for element.

        Returns:
            Dict mapping each group name to a list of gradient copies, in the
            same order as the group's parameters.
        """
        self.gaussians.optimizer.zero_grad(set_to_none=True)
        if getattr(loss, "requires_grad", False):
            # The graph is reused by later backward passes.
            loss.backward(retain_graph=True)
        return {
            name: [p.grad.clone() if p.grad is not None else torch.zeros_like(p) for p in params]
            for name, params in param_groups.items()
        }

    def _measure_decoupling(self, loss_target, loss_x):
        """Compare the gradients of the photometric and auxiliary losses.

        Returns:
            ``(grad_cos_sim, parasitic_ratio, stats)``: the mean per-Gaussian
            cosine similarity of the two position gradients, the ratio of
            their mean norms, and a dict with one ``sti_<group>`` value per
            parameter group (0.0 when skipped or undefined).
        """
        param_groups = {
            "spatial": [self.gaussians._xyz],
            "geometry": [self.gaussians._scaling, self.gaussians._rotation],
            "opacity": [self.gaussians._opacity],
            "appearance": [self.gaussians._features_dc, self.gaussians._features_rest],
        }
        skip_groups = self.gaussians.get_xyz.shape[0] > self._STI_MAX_GAUSSIANS
        # When the group statistics are skipped only the positions are needed.
        probed = {"spatial": param_groups["spatial"]} if skip_groups else param_groups

        # One backward pass per loss serves both the position statistics and
        # the group-wise statistics.
        target = self._probe_gradients(loss_target, probed)
        parasitic = self._probe_gradients(loss_x, probed)

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        xyz_target = target["spatial"][0]
        xyz_parasitic = parasitic["spatial"][0]
        norm_target = torch.norm(xyz_target, dim=1)
        norm_parasitic = torch.norm(xyz_parasitic, dim=1)
        valid = (norm_target > 0) & (norm_parasitic > 0)
        # NOTE(review): both statistics are assumed to be guarded by the
        # validity check — confirm against the original layout.
        if valid.any():
            grad_cos_sim = float(F.cosine_similarity(xyz_target[valid], xyz_parasitic[valid], dim=1).mean())
            parasitic_ratio = float(norm_parasitic.mean() / (norm_target.mean() + 1e-7))

        stats = {}
        for name in param_groups:
            sti = 0.0
            if not skip_groups:
                g_target = torch.cat([g.reshape(-1) for g in target[name]])
                g_parasitic = torch.cat([g.reshape(-1) for g in parasitic[name]])
                t_norm = g_target.norm()
                p_norm = g_parasitic.norm()
                if t_norm > 0 and p_norm > 0:
                    # 1D dot product is memory-cheap; cosine_similarity on
                    # unsqueeze(0) allocates 2D intermediates that OOM at N>2M.
                    cos = float((g_target @ g_parasitic) / (t_norm * p_norm + 1e-7))
                    r = float(p_norm / (t_norm + p_norm + 1e-7))
                    # Only opposing gradients (negative cosine) count.
                    sti = r * max(0.0, -cos)
            stats[f"sti_{name}"] = sti
        return grad_cos_sim, parasitic_ratio, stats

    def train_iteration(self, step):
        """Run one optimisation step.

        Args:
            step: current iteration number.

        Returns:
            ``(metrics, histograms)``: a dict of Python scalars and a dict of
            tensors; the latter is populated every 1000 steps only.
        """
        self.gaussians.update_learning_rate(step)
        if step % 1000 == 0:
            self.gaussians.oneupSHdegree()

        if not self.viewpoint_stack:
            self.viewpoint_stack = self.scene.getTrainCameras().copy()
            # NOTE(review): the 3D filter refresh is assumed to happen only
            # when the camera stack is refilled — confirm.
            self.gaussians.compute_3D_filter(cameras=self.scene.getTrainCameras())
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        render_pkg = native_render(viewpoint_cam, self.gaussians, self.pipe, self.background, kernel_size=self.dataset.kernel_size)
        image = render_pkg["render"]
        viewspace_point_tensor = render_pkg["viewspace_points"]
        visibility_filter = render_pkg["visibility_filter"]
        radii = render_pkg["radii"]
        gt_image = viewpoint_cam.original_image.cuda()

        # Every 50th step the auxiliary loss sees the attention-modulated image.
        patch_attn_active = 0
        finals = image
        if step % 50 == 0:
            patch_attn_active = 1
            output = self.patch_attention(image.unsqueeze(0), gt_image.unsqueeze(0))
            finals = torch.squeeze(output * image, 0)

        Ll1 = l1_loss(image, gt_image)
        ssim_value = ssim(image, gt_image)
        loss_target = (1.0 - self.opt.lambda_dssim) * Ll1 + self.opt.lambda_dssim * (1.0 - ssim_value)
        loss_x = self.combined_loss_fn(finals, gt_image)
        loss = loss_target + loss_x

        grad_cos_sim = 0.0
        parasitic_ratio = 0.0
        stats = {
            "sti_spatial": 0.0,
            "sti_geometry": 0.0,
            "sti_opacity": 0.0,
            "sti_appearance": 0.0
        }
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio, stats = self._measure_decoupling(loss_target, loss_x)
            self.gaussians.optimizer.zero_grad(set_to_none=True)
            # The diagnostic backward passes also accumulate into the retained
            # screen-space gradient, which the optimiser's zero_grad does not
            # touch. Clear it so the densification statistics below see only
            # the gradient of the real training loss.
            # NOTE(review): presumes add_densification_stats reads
            # viewspace_point_tensor.grad — verify against GaussianModel.
            if viewspace_point_tensor.grad is not None:
                viewspace_point_tensor.grad = None
        self.attn_optimizer.zero_grad()
        loss.backward()

        with torch.no_grad():
            if step < self.opt.densify_until_iter:
                self.gaussians.max_radii2D[visibility_filter] = torch.max(self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
                if step > self.opt.densify_from_iter and step % self.opt.densification_interval == 0:
                    size_threshold = 0.02 if step > 1000 else None
                    self.gaussians.densify_and_prune(self.opt.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold, step)
                    self.gaussians.compute_3D_filter(cameras=self.scene.getTrainCameras())
                if step % self.opt.opacity_reset_interval == 0 or (self.dataset.white_background and step == self.opt.densify_from_iter):
                    self.gaussians.reset_opacity()
            if step % 100 == 0 and step > self.opt.densify_until_iter:
                # Skipped during the final 100 iterations.
                if step < self.opt.iterations - 100:
                    self.gaussians.compute_3D_filter(cameras=self.scene.getTrainCameras())
            self.gaussians.optimizer.step()
            self.gaussians.optimizer.zero_grad(set_to_none=True)
            if patch_attn_active:
                self.attn_optimizer.step()

        num_gaussians = self.gaussians.get_xyz.shape[0]
        metrics = {
            "loss": float(loss),
            # Report the individual terms under the names that describe them.
            "loss_l1": float(Ll1),
            "loss_ssim": float(1.0 - ssim_value),
            "loss_edge_freq": float(loss_x),
            "patch_attn_active": int(patch_attn_active),
            "num_gaussians": int(num_gaussians),
            "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim),
            "parasitic_ratio": float(parasitic_ratio)
        }
        metrics.update(stats)
        self.last_n_gaussians = num_gaussians

        histograms = {}
        if step % 1000 == 0:
            histograms["opacity"] = torch.sigmoid(self.gaussians._opacity).clone().detach()
            scales = torch.exp(self.gaussians._scaling).clone().detach()
            histograms["scaling"] = scales
            scales_2d = scales[:, :2] if scales.shape[1] >= 2 else scales
            # Ratio of the larger to the smaller of the first two scale axes.
            gamma = scales_2d.max(dim=-1)[0] / (scales_2d.min(dim=-1)[0] + 1e-7)
            histograms["anisotropy"] = gamma
            histograms["sh_dc_mag"] = self.gaussians._features_dc.detach().norm(dim=-1)
        return metrics, histograms

    def render(self, camera):
        """Render ``camera`` without tracking gradients.

        Returns:
            Dict with ``image`` and ``depth`` (``None`` when the renderer does
            not provide a depth entry).
        """
        with torch.no_grad():
            render_pkg = native_render(camera, self.gaussians, self.pipe, self.background, kernel_size=self.dataset.kernel_size)
            return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save the scene for iteration ``step``; ``save_dir`` is not used."""
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Load Gaussians from ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        self.gaussians.load_ply(os.path.join(model_path, 'point_cloud', f'iteration_{iteration}', 'point_cloud.ply'))

    def get_spatial_centers(self):
        """Return the Gaussian position parameter tensor."""
        return self.gaussians._xyz

    def compute_physical_metrics(self, cameras=None):
        """Summarise the shape, opacity and appearance of the Gaussians.

        Args:
            cameras: optional cameras; when given, the share of Gaussians
                whose normal is nearly parallel to some view axis is reported
                as ``billboard_bias_ratio``.

        Returns:
            Dict of Python floats.
        """
        metrics = {}
        with torch.no_grad():
            raw_scales = self.gaussians._scaling
            scales = torch.exp(raw_scales)
            scales_2d = scales[:, :2] if scales.dim() > 1 and scales.shape[1] >= 2 else scales.unsqueeze(-1).expand(-1, 2)
            max_S, _ = torch.max(scales_2d, dim=1)
            min_S, _ = torch.min(scales_2d, dim=1)
            gamma = max_S / (min_S + 1e-7)
            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(torch.sigmoid(self.gaussians._opacity)))

            dc, rest = self.gaussians._features_dc, self.gaussians._features_rest
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))

            if cameras is not None and len(cameras) > 0:
                view_dirs = []
                for c in cameras:
                    view_dirs.append(c.world_view_transform[:3, 2].tolist())
                view_dirs = F.normalize(torch.tensor(view_dirs, dtype=torch.float32, device="cuda"), dim=1)
                rots = F.normalize(self.gaussians._rotation.clone(), dim=1)
                w, x, y, z = rots.unbind(dim=-1)
                # Third column of the rotation matrix of the unit quaternion
                # (w, x, y, z).
                normals = F.normalize(torch.stack([2*(x*z + w*y), 2*(y*z - w*x), 1-2*(x*x + y*y)], dim=-1), dim=1)
                max_cos, _ = torch.max(torch.abs(torch.matmul(normals, view_dirs.T)), dim=1)
                metrics["billboard_bias_ratio"] = float((max_cos > 0.90).float().mean())
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an opacity-weighted isotropic density at ``query_points``.

        Each Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)``,
        where ``sigma`` is the larger of its first two scale axes. Queries
        are processed in chunks to bound the size of the distance matrix.
        ``cameras`` is not used.

        Returns:
            1-D tensor with one density value per query point.
        """
        with torch.no_grad():
            V = query_points.shape[0]
            densities = torch.zeros(V, device="cuda")
            xyz, opacities = self.gaussians._xyz, torch.sigmoid(self.gaussians._opacity).squeeze()
            scales = torch.exp(self.gaussians._scaling)
            sigma_sq = (scales[:, :2].max(dim=1)[0].pow(2)) if scales.shape[1] >= 2 else scales.squeeze().pow(2)
            N_gaussians = xyz.shape[0]
            # Keep each (chunk x N) distance matrix near 30M entries.
            chunk_size = max(1, 30_000_000 // (N_gaussians + 1))
            for i in range(0, V, chunk_size):
                end = min(i + chunk_size, V)
                dist_sq = torch.cdist(query_points[i:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[i:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities