| import os |
| import sys |
| import math |
| import random |
| import torch |
| import numpy as np |
| import torch.nn.functional as F |
| from argparse import ArgumentParser |
|
|
| from core.registry import register_method |
| from core.base_method import BaseMethod |
|
|
| sys.path.insert(0, '/root/autodl-tmp/BAGS_offy') |
| from utils.loss_utils import l1_loss, ssim |
| from gaussian_renderer import render as native_render |
| from scene import Scene, GaussianModel |
| from arguments import ModelParams, PipelineParams, OptimizationParams |
|
|
def tv_loss(grids):
    """Return the mean squared total variation of a batch of 4-D grids.

    Dim 2 is treated as the vertical ("h") axis and dim 3 as the horizontal
    ("w") axis. Squared differences between neighbouring entries are summed
    per axis, then divided by the number of neighbour pairs per batch item.
    The result is ``2 * (h_term + w_term) / batch``.

    An axis of length 1 has no neighbour pairs. It contributes 0 instead of
    the NaN a ``0 / 0`` division would produce.

    Args:
        grids: rank-4 tensor (the slicing below requires exactly 4 dims).

    Returns:
        0-dim tensor with the TV penalty.
    """
    batch, dim1, height, width = grids.shape
    h_diff = grids[:, :, 1:, :] - grids[:, :, :-1, :]
    w_diff = grids[:, :, :, 1:] - grids[:, :, :, :-1]
    # Number of neighbour pairs per batch item along each axis.
    h_count = dim1 * (height - 1) * width
    w_count = dim1 * height * (width - 1)
    h_tv = h_diff.pow(2).sum()
    w_tv = w_diff.pow(2).sum()
    # An empty direction has zero pairs and zero variation; avoid 0/0 -> NaN.
    return 2 * (h_tv / max(h_count, 1) + w_tv / max(w_count, 1)) / batch
|
|
def get_emb(sin_inp):
    """Interleave the sine and cosine of ``sin_inp`` along its last axis.

    A last axis ``[a, b, ...]`` becomes ``[sin a, cos a, sin b, cos b, ...]``,
    so the last axis doubles in length; leading axes are unchanged.
    """
    paired = torch.stack([torch.sin(sin_inp), torch.cos(sin_inp)], dim=-1)
    return paired.flatten(start_dim=-2)
|
|
def get_2d_emb(batch_size, x, y, out_ch, device):
    """Build a 2-D sinusoidal positional encoding on ``device``.

    Args:
        batch_size: number of copies stacked along a new leading axis.
        x: extent of the first grid axis; its indices are mapped to ``[0, 2*pi)``.
        y: extent of the second grid axis; mapped the same way.
        out_ch: requested channel count. The result has
            ``2 * (2 * ceil(out_ch / 4))`` channels, which equals ``out_ch``
            when ``out_ch`` is a multiple of 4.
        device: device every intermediate and the result are created on.

    Returns:
        Tensor of shape ``(batch_size, x, y, C)``. The first ``C / 2``
        channels encode the ``x`` index and the remaining ones the ``y`` index.
        Each encoding is interleaved sin/cos pairs over geometrically spaced
        frequencies.
    """
    half = int(np.ceil(out_ch / 4) * 2)
    # Keep every intermediate on ``device`` so no host<->device copies happen
    # on each call (this function runs once per training step).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, half, 2, device=device).float() / half))
    pos_x = torch.arange(x, device=device).float() * 2 * np.pi / x
    pos_y = torch.arange(y, device=device).float() * 2 * np.pi / y
    sin_inp_x = torch.einsum("i,j->ij", pos_x, inv_freq)
    sin_inp_y = torch.einsum("i,j->ij", pos_y, inv_freq)
    emb_x = get_emb(sin_inp_x).unsqueeze(1)  # (x, 1, half): broadcast over y
    emb_y = get_emb(sin_inp_y)               # (y, half): broadcast over x
    emb = torch.zeros((x, y, half * 2), device=device)
    emb[:, :, :half] = emb_x
    emb[:, :, half:2 * half] = emb_y
    return emb[None, :, :, :].repeat(batch_size, 1, 1, 1)
|
|
@register_method("bags")
class BAGSWrapper(BaseMethod):
    """BaseMethod adapter around the BAGS Gaussian-splatting training loop.

    Training uses a three-phase camera schedule:

    - the ``scale=4.0`` camera set until ``upsample_iter[0]``;
    - the ``scale=2.0`` set until ``upsample_iter[1]``;
    - the ``scale=1.0`` set afterwards.

    Once the global step exceeds ``opt.ms_steps``, the counter used for the
    learning-rate, SH-degree, densification and opacity-reset schedules
    (``adj_iter``) restarts from zero.
    """

    def __init__(self, dataset_config, hyperparams):
        """Build the scene, the Gaussian model, its optimizer and camera sets.

        Args:
            dataset_config: mapping with ``source_path`` and ``model_path`` and an
                optional ``resolution`` (default 1). The resolution may be
                overridden from ``sys.argv``; see ``_apply_resolution_override``.
            hyperparams: mapping; only ``track_decoupling`` (default False) is read.
        """
        # Parse an empty command line so every BAGS option starts at its default.
        self.parser = ArgumentParser()
        self.lp = ModelParams(self.parser)
        self.op = OptimizationParams(self.parser)
        self.pp = PipelineParams(self.parser)
        self.args = self.parser.parse_args([])

        self.args.source_path = dataset_config["source_path"]
        self.args.model_path = dataset_config["model_path"]
        self.args.eval = True
        self.args.resolution = dataset_config.get("resolution", 1)
        self.track_decoupling = hyperparams.get("track_decoupling", False)

        self.dataset = self.lp.extract(self.args)
        self.opt = self.op.extract(self.args)
        self.pipe = self.pp.extract(self.args)

        # Fixed schedule: total steps, the step after which adj_iter restarts,
        # and the steps at which the training cameras switch scale.
        self.opt.iterations = 60000
        self.opt.ms_steps = 6000
        self.upsample_iter = [3000, 6000]

        self._apply_resolution_override()

        self.scene = Scene(self.dataset)

        # Derive the image shape from the resolution the scene was actually
        # loaded with. self.args.resolution is stale whenever the command-line
        # override above replaced self.dataset.resolution.
        resolution = self.dataset.resolution
        inp_shape = [
            len(self.scene.getTrainCameras()),
            int(np.round(self.scene.orig_h / resolution)),
            int(np.round(self.scene.orig_w / resolution)),
        ]

        self.gaussians = GaussianModel(
            self.dataset.sh_degree, inp_shape,
            ks1=self.dataset.kernel_size1, ks2=self.dataset.kernel_size2,
            ks3=self.dataset.kernel_size3, ks_ss=self.dataset.kernel_size_ss,
            not_use_rgbd=self.opt.not_use_rgbd, not_use_pe=self.opt.not_use_pe,
        )
        self.scene.load_gaussian(self.gaussians)
        self.gaussians.training_setup(self.opt)

        bg_color = [1, 1, 1] if self.dataset.white_background else [0, 0, 0]
        self.background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        self.trainCameras_scale4 = self.scene.getTrainCameras(scale=4.0).copy()
        self.trainCameras_scale2 = self.scene.getTrainCameras(scale=2.0).copy()
        self.trainCameras_scale1 = self.scene.getTrainCameras(scale=1.0).copy()
        self.viewpoint_stack = self.trainCameras_scale4.copy()
        self.gaussians.compute_3D_filter(cameras=self.trainCameras_scale4)

        self.unfold1 = self._make_unfold(self.dataset.kernel_size1)
        self.unfold2 = self._make_unfold(self.dataset.kernel_size2)
        self.unfold3 = self._make_unfold(self.dataset.kernel_size3)
        if self.dataset.kernel_size_ss != self.dataset.kernel_size3:
            # A distinct kernel size for the post-ms_steps phase enables the
            # separate ``mlp_rgb_ss`` branch with its own unfold.
            self.opt.use_another_mlp = True
            self.unfold_ss = self._make_unfold(self.dataset.kernel_size_ss)
        else:
            self.unfold_ss = self.unfold3

        self.last_n_gaussians = self.gaussians.get_xyz.shape[0]

    @staticmethod
    def _make_unfold(kernel_size):
        """Return a CUDA ``Unfold`` extracting ``kernel_size`` x ``kernel_size`` patches.

        Padding is ``kernel_size // 2``, which keeps H x W for odd kernel sizes.
        """
        return torch.nn.Unfold(kernel_size=(kernel_size, kernel_size), padding=kernel_size // 2).cuda()

    def _apply_resolution_override(self):
        """Override ``self.dataset.resolution`` from the process command line.

        Precedence:

        1. A positive ``--resolution N`` wins.
        2. Otherwise a scene name picks 4 for the scenes in ``outdoor_360`` and
           2 for any other scene. The name comes from ``--scene`` or from the
           last path component of ``--source_path``.
        3. With neither flag, the configured resolution is kept.

        Failures while applying the value are reported, not raised.
        """
        scene, explicit_res = None, None
        argv = sys.argv
        for flag, value in zip(argv[:-1], argv[1:]):
            if flag == "--scene":
                scene = value
            elif flag == "--source_path":
                scene = value.rstrip("/").split("/")[-1]
            elif flag == "--resolution":
                try:
                    explicit_res = int(value)
                except ValueError:
                    pass  # non-integer value: fall back to the scene heuristic
        outdoor_360 = {"bicycle", "flowers", "garden", "stump", "treehill"}
        if explicit_res is not None and explicit_res > 0:
            res = explicit_res
        elif scene is not None:
            res = 4 if scene in outdoor_360 else 2
        else:
            res = None
        try:
            if res is not None:
                self.dataset.resolution = res
            print("[res-fix] scene=%s explicit=%s -> res=%s (%s)" % (scene, explicit_res, res, __file__))
        except Exception as e:
            print("[res-fix] FAILED:", e)

    def _cameras_for(self, ori_iter):
        """Return the training-camera list of the resolution phase containing ``ori_iter``."""
        if self.upsample_iter[0] <= ori_iter < self.upsample_iter[1]:
            return self.trainCameras_scale2
        if ori_iter >= self.upsample_iter[1]:
            return self.trainCameras_scale1
        return self.trainCameras_scale4

    def train_iteration(self, step):
        """Run one optimization step at global iteration ``step``.

        Returns:
            ``(metrics, histograms)``.

            ``metrics`` holds loss values, Gaussian counts, peak VRAM and the
            gradient-decoupling scalars. When decoupling is tracked (every 100
            steps if enabled) it also holds ``sti_<group>`` scores.

            ``histograms`` is filled only every 1000 steps.
        """
        ori_iter = step
        adj_iter = step - self.opt.ms_steps if ori_iter > self.opt.ms_steps else step

        # Entering a new resolution phase: restart the camera queue and
        # recompute the 3D filter for the new camera set.
        if ori_iter in (self.upsample_iter[0], self.upsample_iter[1]):
            phase_cams = self._cameras_for(ori_iter)
            self.viewpoint_stack = phase_cams.copy()
            self.gaussians.compute_3D_filter(cameras=phase_cams)

        self.gaussians.update_learning_rate(adj_iter)
        if adj_iter % 1000 == 0:
            self.gaussians.oneupSHdegree()

        if not self.viewpoint_stack:
            self.viewpoint_stack = self._cameras_for(ori_iter).copy()
        viewpoint_cam = self.viewpoint_stack.pop(random.randint(0, len(self.viewpoint_stack) - 1))

        subpixel_offset = None
        if self.dataset.ray_jitter:
            subpixel_offset = torch.rand(
                (int(viewpoint_cam.image_height), int(viewpoint_cam.image_width), 2),
                dtype=torch.float32, device="cuda",
            ) - 0.5

        render_pkg = native_render(
            viewpoint_cam, self.gaussians, self.pipe, self.background,
            kernel_size=self.dataset.kernel_size, subpixel_offset=subpixel_offset,
        )
        image = render_pkg["render"]
        depth = render_pkg["depth"]
        viewspace_point_tensor = render_pkg["viewspace_points"]
        visibility_filter = render_pkg["visibility_filter"]
        radii = render_pkg["radii"]

        gt_image = viewpoint_cam.original_image.cuda()
        if adj_iter > 250:
            loss_target, loss_parasitic = self._blur_losses(image, depth, gt_image, ori_iter, adj_iter)
        else:
            Ll1 = l1_loss(image, gt_image)
            loss_target = (1.0 - self.opt.lambda_dssim) * Ll1
            loss_parasitic = self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
        loss = loss_target + loss_parasitic

        grad_cos_sim, parasitic_ratio, metrics = 0.0, 0.0, {}
        if self.track_decoupling and step % 100 == 0:
            grad_cos_sim, parasitic_ratio, metrics = self._decoupling_metrics(loss_target, loss_parasitic)
            # The probes leave parameter .grad untouched. retain_grad() hooks
            # may still have fired during them, depending on the PyTorch
            # version. Drop any such value so the densification input below
            # reflects only the real backward pass.
            viewspace_point_tensor.grad = None
        loss.backward()

        with torch.no_grad():
            self._densify_and_filter(ori_iter, adj_iter, viewspace_point_tensor, visibility_filter, radii)
            if ori_iter < self.opt.iterations:
                self.gaussians.optimizer.step()
                self.gaussians.optimizer.zero_grad(set_to_none=True)

        num_gaussians = self.gaussians.get_xyz.shape[0]
        # "loss_l1" holds the (1 - lambda_dssim)-weighted L1 term. "loss_ssim"
        # holds the SSIM term plus any enabled regularizers.
        metrics.update({
            "loss": float(loss), "loss_l1": float(loss_target), "loss_ssim": float(loss_parasitic),
            "num_gaussians": int(num_gaussians), "delta_N": int(num_gaussians - self.last_n_gaussians),
            "peak_vram_GB": float(torch.cuda.max_memory_allocated() / (1024 ** 3)),
            "grad_cos_sim": float(grad_cos_sim), "parasitic_ratio": float(parasitic_ratio),
        })
        self.last_n_gaussians = num_gaussians

        histograms = self._histograms() if step % 1000 == 0 else {}
        return metrics, histograms

    def _blur_kernel_branch(self, ori_iter, adj_iter):
        """Pick ``(mlp, unfold, kernel_size, mlp_iteration)`` for the current phase.

        The first two phases follow the camera schedule in ``upsample_iter``.
        After ``opt.ms_steps`` the separate ``mlp_rgb_ss`` is used when
        ``opt.use_another_mlp`` is set; it is driven by ``adj_iter``.
        """
        g, d = self.gaussians, self.dataset
        if ori_iter < self.upsample_iter[0]:
            return g.mlp_rgb_ms, self.unfold1, d.kernel_size1, ori_iter
        if ori_iter < self.upsample_iter[1]:
            return g.mlp_rgb_ms, self.unfold2, d.kernel_size2, ori_iter
        if ori_iter > self.opt.ms_steps and self.opt.use_another_mlp:
            return g.mlp_rgb_ss, self.unfold_ss, d.kernel_size_ss, adj_iter
        return g.mlp_rgb_ms, self.unfold3, d.kernel_size3, ori_iter

    def _blur_losses(self, image, depth, gt_image, ori_iter, adj_iter):
        """Return ``(loss_target, loss_parasitic)`` for the predicted-blur branch.

        The kernel MLP predicts per-pixel kernel weights and a blend mask.
        The rendered image is re-blurred with those kernels, mixed back with
        the sharp render through the mask, and compared with the ground truth.
        """
        opt = self.opt
        rgb = image.unsqueeze(0)
        # Depth shifted to start at 0 and divided by its (epsilon-guarded) max.
        norm_depth = depth.unsqueeze(0) - depth.min()
        norm_depth = norm_depth / (norm_depth.max() + 1e-7)
        height, width = rgb.shape[-2], rgb.shape[-1]
        pos_enc = get_2d_emb(1, height, width, 16, torch.device("cuda"))

        mlp, unfold, ksize, mlp_iter = self._blur_kernel_branch(ori_iter, adj_iter)
        # The MLP sees a detached copy of the render.
        kernel_weights, mask = mlp(0, pos_enc, torch.cat([rgb, norm_depth], 1).detach(), mlp_iter)
        patches = unfold(rgb).view(1, 3, ksize ** 2, height, width)

        # Weighted sum over each pixel's ksize x ksize neighbourhood.
        # Presumably kernel_weights is (1, ksize**2, H, W) — TODO confirm.
        blurred = torch.sum(patches * kernel_weights.unsqueeze(1), 2)[0]
        mask = mask[0]
        blur_image = mask * blurred + (1 - mask) * image

        def _zero():
            return torch.tensor(0.0, device="cuda")

        depthloss = opt.depth_loss_alpha * tv_loss(norm_depth) if opt.use_depth_loss else _zero()
        maskloss = opt.mask_loss_alpha * mask.mean() if opt.use_mask_loss else _zero()
        tvloss = opt.rgbtv_loss_alpha * tv_loss(rgb) if opt.use_rgbtv_loss else _zero()

        Ll1 = l1_loss(blur_image, gt_image)
        loss_target = (1.0 - opt.lambda_dssim) * Ll1
        loss_parasitic = opt.lambda_dssim * (1.0 - ssim(blur_image, gt_image)) + tvloss + maskloss + depthloss
        return loss_target, loss_parasitic

    @staticmethod
    def _probe_grads(loss, params):
        """Return d(loss)/dp for each p in ``params``.

        The result is aligned with ``params``; an entry is None where p is
        frozen or unreachable from ``loss``. ``torch.autograd.grad`` returns
        the gradients instead of accumulating them into leaf ``.grad``
        attributes. Probing therefore cannot pollute tensors that the Gaussian
        optimizer's ``zero_grad`` would not clear.
        """
        grads = [None] * len(params)
        live = [i for i, p in enumerate(params) if p.requires_grad]
        if live and loss.requires_grad:
            found = torch.autograd.grad(loss, [params[i] for i in live], retain_graph=True, allow_unused=True)
            for i, grad in zip(live, found):
                grads[i] = grad
        return grads

    @staticmethod
    def _concat_grads(grads):
        """Flatten and concatenate the non-None gradients; None if there are none."""
        parts = [grad.reshape(-1) for grad in grads if grad is not None]
        return torch.cat(parts) if parts else None

    def _decoupling_metrics(self, loss_target, loss_parasitic):
        """Compare the gradients of the target and parasitic loss terms.

        Returns:
            ``(grad_cos_sim, parasitic_ratio, metrics)``.

            ``grad_cos_sim`` is the mean per-Gaussian cosine similarity of the
            ``_xyz`` gradients. ``parasitic_ratio`` is the ratio of their mean
            norms. Both are computed only over rows where both gradients are
            non-zero, and are 0.0 when no such row exists.

            ``metrics`` holds one ``sti_<group>`` score per parameter group,
            equal to ``r * max(0, -cos)`` with ``r = |g_p| / (|g_t| + |g_p|)``.
        """
        g = self.gaussians
        param_groups_map = {
            "spatial": [g._xyz],
            "geometry": [g._scaling, g._rotation],
            "opacity": [g._opacity],
            "appearance": [g._features_dc, g._features_rest],
        }
        flat_params = [p for params in param_groups_map.values() for p in params]
        grads_t = self._probe_grads(loss_target, flat_params)
        grads_p = self._probe_grads(loss_parasitic, flat_params)

        # "spatial" is listed first, so index 0 is _xyz. An unreachable
        # parameter counts as a zero gradient.
        xyz = g._xyz
        xyz_t = grads_t[0] if grads_t[0] is not None else torch.zeros_like(xyz)
        xyz_p = grads_p[0] if grads_p[0] is not None else torch.zeros_like(xyz)
        norm_t = torch.norm(xyz_t, dim=1)
        norm_p = torch.norm(xyz_p, dim=1)
        grad_cos_sim, parasitic_ratio = 0.0, 0.0
        valid = (norm_t > 0) & (norm_p > 0)
        if valid.any():
            grad_cos_sim = float(F.cosine_similarity(xyz_t[valid], xyz_p[valid], dim=1).mean())
            parasitic_ratio = float(norm_p.mean() / (norm_t.mean() + 1e-7))

        metrics = {}
        start = 0
        for group_name, params in param_groups_map.items():
            stop = start + len(params)
            gt = self._concat_grads(grads_t[start:stop])
            gp = self._concat_grads(grads_p[start:stop])
            start = stop
            if gt is not None and gp is not None and gt.norm() > 0 and gp.norm() > 0:
                cos = float(F.cosine_similarity(gt.unsqueeze(0), gp.unsqueeze(0)))
                r = float(gp.norm() / (gt.norm() + gp.norm() + 1e-7))
                ti = r * max(0.0, -cos)
            else:
                ti = 0.0
            metrics[f"sti_{group_name}"] = ti
        return grad_cos_sim, parasitic_ratio, metrics

    def _densify_and_filter(self, ori_iter, adj_iter, viewspace_point_tensor, visibility_filter, radii):
        """Update densification stats, densify/prune, reset opacity and refresh the 3D filter.

        Expected to be called under ``torch.no_grad()``.
        """
        opt = self.opt
        if adj_iter < opt.densify_until_iter:
            # Track the largest screen-space radius each visible Gaussian has reached.
            self.gaussians.max_radii2D[visibility_filter] = torch.max(
                self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
            )
            self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

            if adj_iter > opt.densify_from_iter and adj_iter % opt.densification_interval == 0:
                size_threshold = 20 if adj_iter > opt.opacity_reset_interval else None
                if ori_iter <= opt.ms_steps:
                    dgt = opt.init_dgt
                    min_op = opt.init_opacity if opt.init_opacity >= 0 else opt.min_opacity
                else:
                    dgt = opt.densify_grad_threshold
                    min_op = opt.min_opacity
                self.gaussians.densify_and_prune(dgt, min_op, self.scene.cameras_extent, size_threshold)
                self.gaussians.compute_3D_filter(cameras=self._cameras_for(ori_iter))

            if adj_iter % opt.opacity_reset_interval == 0 or (
                self.dataset.white_background and adj_iter == opt.densify_from_iter
            ):
                self.gaussians.reset_opacity()

        # After densification stops, refresh the 3D filter every 100 steps
        # until 100 steps before the end of the adj_iter schedule.
        if adj_iter % 100 == 0 and adj_iter > opt.densify_until_iter:
            if adj_iter < (opt.iterations - opt.ms_steps) - 100:
                self.gaussians.compute_3D_filter(cameras=self._cameras_for(ori_iter))

    def _histograms(self):
        """Return detached per-Gaussian tensors for histogram logging."""
        g = self.gaussians
        scales = torch.exp(g._scaling).clone().detach()
        scales_2d = scales[:, :2] if scales.shape[1] >= 2 else scales
        return {
            "opacity": torch.sigmoid(g._opacity).clone().detach(),
            "scaling": scales,
            # Ratio of the larger to the smaller of the first two scale axes.
            "anisotropy": scales_2d.max(dim=-1)[0] / (scales_2d.min(dim=-1)[0] + 1e-7),
            "sh_dc_mag": g._features_dc.detach().norm(dim=-1),
        }

    def render(self, camera):
        """Render ``camera`` without gradients.

        Returns:
            ``{"image": ..., "depth": ...}``; depth is None if the renderer
            does not produce one.
        """
        with torch.no_grad():
            # Same kernel_size as train_iteration, so evaluation matches training.
            render_pkg = native_render(
                camera, self.gaussians, self.pipe, self.background, kernel_size=self.dataset.kernel_size
            )
        return {"image": render_pkg["render"], "depth": render_pkg.get("depth", None)}

    def save(self, save_dir, step):
        """Save via ``Scene.save(step)``; ``save_dir`` is not used by this wrapper."""
        self.scene.save(step)

    def load(self, model_path, iteration):
        """Load Gaussians from ``<model_path>/point_cloud/iteration_<iteration>/point_cloud.ply``."""
        self.gaussians.load_ply(os.path.join(model_path, 'point_cloud', f'iteration_{iteration}', 'point_cloud.ply'))

    def get_spatial_centers(self):
        """Return the Gaussian centre parameter tensor (``_xyz``), not a copy."""
        return self.gaussians._xyz

    def compute_physical_metrics(self, cameras=None):
        """Summarise Gaussian shape, opacity and SH statistics.

        Keys:
            ``gamma_median``, ``gamma_90th_percentile``: larger-to-smaller ratio
                of the first two scale axes.
            ``scale_mean``, ``alpha_mean``.
            ``sh_energy_ratio``: present when higher-order SH features exist.
            ``billboard_bias_ratio``: present when ``cameras`` is non-empty.
                It is the fraction of Gaussians whose rotated third local axis
                has ``|cos| > 0.9`` with the ``world_view_transform[:3, 2]``
                direction of some camera (presumably its viewing axis).
        """
        metrics = {}
        with torch.no_grad():
            scales = torch.exp(self.gaussians._scaling)
            if scales.dim() > 1 and scales.shape[1] >= 2:
                scales_2d = scales[:, :2]
            else:
                scales_2d = scales.unsqueeze(-1).expand(-1, 2)

            max_S, _ = torch.max(scales_2d, dim=1)
            min_S, _ = torch.min(scales_2d, dim=1)
            gamma = max_S / (min_S + 1e-7)

            metrics["gamma_median"] = float(torch.median(gamma))
            metrics["gamma_90th_percentile"] = float(torch.quantile(gamma, 0.90))
            metrics["scale_mean"] = float(torch.mean(scales_2d))
            metrics["alpha_mean"] = float(torch.mean(torch.sigmoid(self.gaussians._opacity)))

            dc, rest = self.gaussians._features_dc, self.gaussians._features_rest
            if rest is not None and rest.shape[1] > 0:
                metrics["sh_energy_ratio"] = float(rest.norm(dim=-1).mean() / (dc.norm(dim=-1).mean() + 1e-7))

            if cameras is not None and len(cameras) > 0:
                view_dirs = F.normalize(
                    torch.tensor(
                        [c.world_view_transform[:3, 2].tolist() for c in cameras],
                        dtype=torch.float32, device="cuda",
                    ),
                    dim=1,
                )
                rots = F.normalize(self.gaussians._rotation.clone(), dim=1)
                w, x, y, z = rots.unbind(dim=-1)
                # Third column of the rotation matrix of the (w, x, y, z) quaternion.
                normals = F.normalize(
                    torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x), 1 - 2 * (x * x + y * y)], dim=-1),
                    dim=1,
                )
                # Chunk over Gaussians so the full (N_gaussians, N_cameras)
                # cosine matrix is never materialised at once.
                n_total = normals.shape[0]
                chunk = max(1, 30_000_000 // (len(cameras) + 1))
                hits = 0
                for start in range(0, n_total, chunk):
                    max_cos, _ = torch.max(torch.abs(torch.matmul(normals[start:start + chunk], view_dirs.T)), dim=1)
                    hits += int((max_cos > 0.90).sum())
                metrics["billboard_bias_ratio"] = float(hits / max(n_total, 1))
        return metrics

    def evaluate_spatial_field(self, query_points: torch.Tensor, cameras=None) -> torch.Tensor:
        """Evaluate an opacity-weighted isotropic Gaussian density at each query point.

        Each Gaussian contributes ``opacity * exp(-0.5 * d^2 / sigma^2)``. Here
        ``sigma`` is the larger of its first two scale axes. Queries are
        processed in chunks to bound the size of the distance matrix.
        ``query_points`` is moved to the Gaussians' device and dtype first.
        ``cameras`` is accepted for interface compatibility and unused.

        Returns:
            1-D tensor with one density per query point.
        """
        with torch.no_grad():
            xyz = self.gaussians._xyz
            query_points = query_points.to(device=xyz.device, dtype=xyz.dtype)
            V = query_points.shape[0]
            densities = torch.zeros(V, device=xyz.device)
            opacities = torch.sigmoid(self.gaussians._opacity).squeeze()
            scales = torch.exp(self.gaussians._scaling)
            sigma_sq = (scales[:, :2].max(dim=1)[0].pow(2)) if scales.shape[1] >= 2 else scales.squeeze().pow(2)

            N_gaussians = xyz.shape[0]
            # About 30M distance entries per chunk.
            chunk_size = max(1, 30_000_000 // (N_gaussians + 1))
            for i in range(0, V, chunk_size):
                end = min(i + chunk_size, V)
                dist_sq = torch.cdist(query_points[i:end], xyz, p=2).pow(2)
                weights = torch.exp(-0.5 * dist_sq / (sigma_sq.unsqueeze(0) + 1e-7))
                densities[i:end] = torch.sum(weights * opacities.unsqueeze(0), dim=1)
        return densities
|
|