| |
| |
| |
| |
|
|
| import math |
| import os |
|
|
| import cv2 |
| import matplotlib |
| import numpy as np |
| import torch |
| from diff_gaussian_rasterization import ( |
| GaussianRasterizationSettings, |
| GaussianRasterizer, |
| ) |
| from einops import rearrange |
| from plyfile import PlyData, PlyElement |
| from torch import nn |
|
|
| from collections import OrderedDict |
| import videoio |
|
|
@torch.no_grad()
def get_turntable_cameras(
    hfov=50,
    num_views=8,
    w=384,
    h=384,
    radius=2.7,
    elevation=20,
    up_vector=np.array([0, 0, 1]),
):
    """Build a ring of cameras that orbit the origin and look at it.

    Args:
        hfov: horizontal field of view in degrees; sets the focal length.
        num_views: number of cameras, spread evenly over a full turn that
            starts at azimuth 270 degrees.
        w, h: image width and height in pixels.
        radius: distance of every camera from the origin.
        elevation: elevation angle in degrees, shared by all cameras.
        up_vector: world-space up direction used to orient the cameras.

    Returns:
        ``(w, h, num_views, fxfycxcy, c2ws)`` where ``fxfycxcy`` is a
        ``(num_views, 4)`` array of identical intrinsics and ``c2ws`` is a
        ``(num_views, 4, 4)`` array of camera-to-world matrices whose columns
        are (right, -up, forward, position).
    """
    # Square pixels: one focal length derived from the horizontal FOV.
    focal = w / (2 * np.tan(np.deg2rad(hfov) / 2.0))
    intrinsics_row = np.array([focal, focal, w / 2.0, h / 2.0])
    fxfycxcy = intrinsics_row.reshape(1, 4).repeat(num_views, axis=0)

    azimuths_deg = np.linspace(270, 630, num_views, endpoint=False)
    elevations_deg = np.ones_like(azimuths_deg) * elevation

    poses = []
    for elev_deg, azim_deg in zip(elevations_deg, azimuths_deg):
        elev_rad = np.deg2rad(elev_deg)
        azim_rad = np.deg2rad(azim_deg)

        # Spherical -> Cartesian position on the orbit.
        height_z = radius * np.sin(elev_rad)
        ground_radius = radius * np.cos(elev_rad)
        position = np.array(
            [
                ground_radius * np.cos(azim_rad),
                ground_radius * np.sin(azim_rad),
                height_z,
            ]
        )

        # Orthonormal frame with the viewing axis pointing at the origin.
        forward = -position / np.linalg.norm(position)
        right = np.cross(forward, up_vector)
        right = right / np.linalg.norm(right)
        up = np.cross(right, forward)
        up = up / np.linalg.norm(up)

        pose = np.eye(4)
        pose[:3, :4] = np.concatenate(
            (np.stack((right, -up, forward), axis=1), position[:, None]), axis=1
        )
        poses.append(pose)

    return w, h, num_views, fxfycxcy, np.stack(poses, axis=0)
|
|
def imageseq2video(images, filename, fps=24):
    """Write an image sequence to ``filename`` through ``videoio.videosave``.

    ``uint8`` input is rescaled to ``float32`` by dividing by 255 before it is
    handed over; any other dtype is passed through unchanged.
    """
    frames = images
    if frames.dtype == np.uint8:
        frames = frames.astype(np.float32) / 255.0

    videoio.videosave(filename, frames, lossless=True, preset="veryfast", fps=fps)
|
|
|
|
| |
def strip_lowerdiag(L):
    """Pack six entries of each 3x3 matrix in ``L`` into a ``(N, 6)`` tensor.

    Despite the name, the entries read are (0,0), (0,1), (0,2), (1,1), (1,2)
    and (2,2), in that order. The result is always ``torch.float`` and lives
    on the same device as ``L``.
    """
    row_idx = [0, 0, 0, 1, 1, 2]
    col_idx = [0, 1, 2, 1, 2, 2]
    # One gather instead of six per-column assignments.
    return L[:, row_idx, col_idx].to(dtype=torch.float)
|
|
|
|
def strip_symmetric(sym):
    """Return the six-value packing of each 3x3 matrix in ``sym``.

    Thin alias for ``strip_lowerdiag``.
    """
    packed = strip_lowerdiag(sym)
    return packed
|
|
|
|
def build_rotation(r):
    """Convert a batch of quaternions into 3x3 rotation matrices.

    Args:
        r: ``(N, 4)`` tensor; column 0 is the scalar part and columns 1-3 the
            vector part. Rows are normalised here, so they need not be unit
            length (an all-zero row divides by zero).

    Returns:
        ``(N, 3, 3)`` tensor on ``r``'s device.
    """
    length = torch.sqrt(
        r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
    )
    unit = r / length[:, None]
    w, x, y, z = unit[:, 0], unit[:, 1], unit[:, 2], unit[:, 3]

    entries = (
        (0, 0, 1 - 2 * (y * y + z * z)),
        (0, 1, 2 * (x * y - w * z)),
        (0, 2, 2 * (x * z + w * y)),
        (1, 0, 2 * (x * y + w * z)),
        (1, 1, 1 - 2 * (x * x + z * z)),
        (1, 2, 2 * (y * z - w * x)),
        (2, 0, 2 * (x * z - w * y)),
        (2, 1, 2 * (y * z + w * x)),
        (2, 2, 1 - 2 * (x * x + y * y)),
    )

    rot = torch.zeros((unit.size(0), 3, 3), device=r.device)
    for row, col, value in entries:
        rot[:, row, col] = value
    return rot
|
|
|
|
def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` for every Gaussian.

    Args:
        s: ``(N, 3)`` per-axis scales (columns 0-2 are read).
        r: ``(N, 4)`` quaternions, converted with ``build_rotation``.

    Returns:
        ``(N, 3, 3)`` tensor: each rotation with its columns scaled by ``s``.
    """
    scale_mat = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    rotation = build_rotation(r)

    for axis in range(3):
        scale_mat[:, axis, axis] = s[:, axis]

    return rotation @ scale_mat
|
|
|
|
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
    """Build packed 3D covariances from scales and quaternions.

    With ``M = R @ diag(scaling_modifier * scaling)`` the covariance is
    ``M @ M^T``; it is returned in the six-value packing produced by
    ``strip_symmetric``.
    """
    factor = build_scaling_rotation(scaling_modifier * scaling, rotation)
    covariance = factor @ factor.transpose(1, 2)
    return strip_symmetric(covariance)
|
|
|
|
| |
# Numeric coefficients of the hard-coded spherical-harmonics polynomials
# evaluated in eval_sh. C0 is also the scale used by RGB2SH / SH2RGB.

# Degree 0 (single constant term); numerically 1 / (2 * sqrt(pi)).
C0 = 0.28209479177387814
# Degree 1 (three terms share this magnitude); numerically sqrt(3 / (4 * pi)).
C1 = 0.4886025119029199
# Degree 2: one coefficient per term, paired with sh[..., 4] .. sh[..., 8].
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396,
]
# Degree 3: one coefficient per term, paired with sh[..., 9] .. sh[..., 15].
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435,
]
# Degree 4: one coefficient per term, paired with sh[..., 16] .. sh[..., 24].
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]
|
|
|
|
def eval_sh(deg, sh, dirs):
    """Evaluate spherical harmonics at unit directions.

    Uses hard-coded SH polynomials built only from arithmetic and ``...``
    indexing. Any number of leading batch dimensions is allowed.

    Args:
        deg: SH degree to evaluate; degrees 0 through 4 are supported.
        sh: SH coefficients ``[..., C, K]`` with ``K >= (deg + 1) ** 2``.
        dirs: unit directions ``[..., 3]``.

    Returns:
        ``[..., C]`` evaluated values.
    """
    assert 0 <= deg <= 4
    needed = (deg + 1) ** 2
    assert sh.shape[-1] >= needed

    # Degree 0: constant term.
    out = C0 * sh[..., 0]
    if deg <= 0:
        return out

    # Degree 1: linear in the direction components.
    x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
    out = out - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]
    if deg <= 1:
        return out

    # Degree 2: quadratic monomials, reused by the higher degrees.
    xx, yy, zz = x * x, y * y, z * z
    xy, yz, xz = x * y, y * z, x * z
    out = (
        out
        + C2[0] * xy * sh[..., 4]
        + C2[1] * yz * sh[..., 5]
        + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6]
        + C2[3] * xz * sh[..., 7]
        + C2[4] * (xx - yy) * sh[..., 8]
    )
    if deg <= 2:
        return out

    # Degree 3.
    out = (
        out
        + C3[0] * y * (3 * xx - yy) * sh[..., 9]
        + C3[1] * xy * z * sh[..., 10]
        + C3[2] * y * (4 * zz - xx - yy) * sh[..., 11]
        + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12]
        + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13]
        + C3[5] * z * (xx - yy) * sh[..., 14]
        + C3[6] * x * (xx - 3 * yy) * sh[..., 15]
    )
    if deg <= 3:
        return out

    # Degree 4.
    out = (
        out
        + C4[0] * xy * (xx - yy) * sh[..., 16]
        + C4[1] * yz * (3 * xx - yy) * sh[..., 17]
        + C4[2] * xy * (7 * zz - 1) * sh[..., 18]
        + C4[3] * yz * (7 * zz - 3) * sh[..., 19]
        + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20]
        + C4[5] * xz * (7 * zz - 3) * sh[..., 21]
        + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22]
        + C4[7] * xz * (xx - 3 * yy) * sh[..., 23]
        + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]
    )
    return out
|
|
|
|
def RGB2SH(rgb):
    """Map colour values to degree-0 SH coefficients: ``(rgb - 0.5) / C0``."""
    centered = rgb - 0.5
    return centered / C0
|
|
|
|
def SH2RGB(sh):
    """Map degree-0 SH coefficients back to colour: ``sh * C0 + 0.5``."""
    scaled = sh * C0
    return scaled + 0.5
|
|
|
|
def create_video(image_folder, output_video_file, framerate=30):
    """Encode every ``.png`` in ``image_folder`` into an mp4v video.

    Frames are taken in lexicographic filename order. The video size comes
    from the first frame.

    Args:
        image_folder: directory that holds the ``.png`` frames.
        output_video_file: path of the video to write.
        framerate: frames per second of the output.

    Raises:
        ValueError: if the folder has no ``.png`` file or a frame cannot be
            decoded.
    """
    frame_names = sorted(
        name for name in os.listdir(image_folder) if name.endswith(".png")
    )
    if not frame_names:
        raise ValueError(f"No .png frames found in {image_folder!r}")

    def read_frame(name):
        frame = cv2.imread(os.path.join(image_folder, name))
        if frame is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise ValueError(f"Could not read frame {name!r} in {image_folder!r}")
        return frame

    height, width = read_frame(frame_names[0]).shape[:2]

    video = cv2.VideoWriter(
        output_video_file, cv2.VideoWriter_fourcc(*"mp4v"), framerate, (width, height)
    )
    try:
        for name in frame_names:
            video.write(read_frame(name))
    finally:
        # Always finalise the container, even when a frame fails to load.
        # No HighGUI window is opened here, so there is nothing else to close.
        video.release()
|
|
|
|
class Camera(nn.Module):
    """Pinhole camera holding the matrices handed to the Gaussian rasterizer.

    Derived attributes are computed once in ``__init__``: ``W2C`` (inverse of
    ``C2W``), ``tanfovX`` / ``tanfovY``, the transposed view and projection
    matrices, their product ``full_proj_transform``, and ``camera_center``
    (the translation column of ``C2W``).
    """

    def __init__(self, C2W, fxfycxcy, h, w):
        """
        C2W: 4x4 camera-to-world matrix; opencv convention
        fxfycxcy: 4 values (fx, fy, cx, cy); ``fx`` must be a tensor because
            its device is used for the projection matrix
        h, w: image height and width in pixels
        """
        super().__init__()
        self.C2W = C2W.clone().float()
        self.W2C = self.C2W.inverse()
        self.h = h
        self.w = w

        # Fixed clipping planes.
        self.znear = 0.01
        self.zfar = 100.0

        fx, fy, cx, cy = (fxfycxcy[k] for k in range(4))
        self.tanfovX = w / (2 * fx)
        self.tanfovY = h / (2 * fy)

        # Matrices are stored transposed relative to how they are built.
        self.world_view_transform = self.W2C.transpose(0, 1)
        self.projection_matrix = self._build_projection(
            self.w, self.h, fx, fy, cx, cy, self.znear, self.zfar
        ).transpose(0, 1)
        self.full_proj_transform = torch.bmm(
            self.world_view_transform.unsqueeze(0),
            self.projection_matrix.unsqueeze(0),
        ).squeeze(0)
        self.camera_center = self.C2W[:3, 3]

    @staticmethod
    def _build_projection(W, H, fx, fy, cx, cy, znear, zfar):
        """Return the (untransposed) 4x4 projection matrix on ``fx``'s device."""
        proj = torch.zeros(4, 4, device=fx.device)
        proj[0, 0] = 2 * fx / W
        proj[1, 1] = 2 * fy / H
        proj[0, 2] = 2 * (cx / W) - 1
        proj[1, 2] = 2 * (cy / H) - 1
        proj[2, 2] = -(zfar + znear) / (zfar - znear)
        proj[3, 2] = 1.0
        proj[2, 3] = -(2 * zfar * znear) / (zfar - znear)
        return proj
|
|
|
|
| |
class GaussianModel:
    """A set of 3D Gaussians stored as raw (pre-activation) tensors.

    Raw tensors live in ``_xyz``, ``_features_dc``, ``_features_rest``,
    ``_scaling``, ``_rotation`` and ``_opacity``. The ``get_*`` properties
    apply the activations bound in ``setup_functions``. Filtering methods
    modify the model in place and return ``self`` so calls can be chained.
    """

    # SH degree that viewer-compatible PLY files are padded to.
    _VIEWER_SH_DEGREE = 3

    def setup_functions(self):
        """Bind the activations that map raw tensors to usable values."""
        self.scaling_activation = torch.exp
        self.inv_scaling_activation = torch.log
        self.rotation_activation = torch.nn.functional.normalize
        self.opacity_activation = torch.sigmoid
        self.covariance_activation = build_covariance_from_scaling_rotation

    def __init__(self, sh_degree: int, scaling_modifier=None):
        """Create an empty model.

        Args:
            sh_degree: SH degree of the colour features; with degree 0 there
                are no "rest" features and ``_features_rest`` is ``None``.
            scaling_modifier: optional factor multiplied onto the activated
                scales in ``get_scaling``.
        """
        self.sh_degree = sh_degree
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0) if self.sh_degree > 0 else None
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.setup_functions()

        self.scaling_modifier = scaling_modifier

    def empty(self):
        """Reset to the empty state, keeping ``sh_degree`` and the modifier."""
        self.__init__(self.sh_degree, self.scaling_modifier)

    def set_data(self, xyz, features, scaling, rotation, opacity):
        """
        xyz : torch.tensor of shape (N, 3)
        features : torch.tensor of shape (N, (self.sh_degree + 1) ** 2, 3)
        scaling : torch.tensor of shape (N, 3)
        rotation : torch.tensor of shape (N, 4)
        opacity : torch.tensor of shape (N, 1)

        The first feature row becomes ``_features_dc``; the remaining rows
        become ``_features_rest`` when ``sh_degree > 0``. Returns ``self``.
        """
        self._xyz = xyz
        self._features_dc = features[:, :1, :].contiguous()
        if self.sh_degree > 0:
            self._features_rest = features[:, 1:, :].contiguous()
        else:
            self._features_rest = None
        self._scaling = scaling
        self._rotation = rotation
        self._opacity = opacity
        return self

    def to(self, device):
        """Move every stored tensor to ``device`` and return ``self``."""
        self._xyz = self._xyz.to(device)
        self._features_dc = self._features_dc.to(device)
        if self.sh_degree > 0:
            self._features_rest = self._features_rest.to(device)
        self._scaling = self._scaling.to(device)
        self._rotation = self._rotation.to(device)
        self._opacity = self._opacity.to(device)
        return self

    def filter(self, valid_mask):
        """Keep only the Gaussians selected by ``valid_mask``; return ``self``."""
        self._xyz = self._xyz[valid_mask]
        self._features_dc = self._features_dc[valid_mask]
        if self.sh_degree > 0:
            self._features_rest = self._features_rest[valid_mask]
        self._scaling = self._scaling[valid_mask]
        self._rotation = self._rotation[valid_mask]
        self._opacity = self._opacity[valid_mask]
        return self

    def crop(self, crop_bbx=(-1, 1, -1, 1, -1, 1)):
        """Drop Gaussians outside an axis-aligned box (bounds are inclusive).

        Args:
            crop_bbx: ``(x_min, x_max, y_min, y_max, z_min, z_max)``.
        """
        x_min, x_max, y_min, y_max, z_min, z_max = crop_bbx
        xyz = self._xyz
        valid_mask = (
            (xyz[:, 0] >= x_min)
            & (xyz[:, 0] <= x_max)
            & (xyz[:, 1] >= y_min)
            & (xyz[:, 1] <= y_max)
            & (xyz[:, 2] >= z_min)
            & (xyz[:, 2] <= z_max)
        )
        return self.filter(valid_mask)

    def crop_by_xyz(self, floater_thres=0.75):
        """Drop Gaussians that are beyond the threshold on both x and y and
        below ``-floater_thres`` on z.
        """
        xyz = self._xyz
        x_out = (xyz[:, 0] < -floater_thres) | (xyz[:, 0] > floater_thres)
        y_out = (xyz[:, 1] < -floater_thres) | (xyz[:, 1] > floater_thres)
        invalid_mask = x_out & y_out & (xyz[:, 2] < -floater_thres)
        return self.filter(~invalid_mask)

    def prune(self, opacity_thres=0.05):
        """Keep Gaussians whose activated opacity exceeds ``opacity_thres``."""
        opacity = self.get_opacity.squeeze(1)
        return self.filter(opacity > opacity_thres)

    def prune_by_scaling(self, scaling_thres=0.1):
        """Keep Gaussians whose largest activated scale is below
        ``scaling_thres`` or whose z coordinate is positive.
        """
        small_enough = self.get_scaling.max(dim=1).values < scaling_thres
        above_ground = self._xyz[:, 2] > 0
        return self.filter(small_enough | above_ground)

    def prune_by_nearfar(self, cam_origins, nearfar_percent=(0.01, 0.99)):
        """Drop Gaussians unusually near to or far from any camera.

        For each camera the point-to-camera distances are computed and the
        two quantiles in ``nearfar_percent`` are taken over all Gaussians. A
        Gaussian is removed if, for at least one camera, its distance lies
        below the lower or above the upper quantile.

        Args:
            cam_origins: tensor of camera positions, one row per camera.
            nearfar_percent: ``(low, high)`` quantiles with
                ``0 <= low < high <= 1``.
        """
        assert len(nearfar_percent) == 2
        assert nearfar_percent[0] < nearfar_percent[1]
        assert nearfar_percent[0] >= 0 and nearfar_percent[1] <= 1

        device = self._xyz.device
        # (num_gaussians, num_cameras) distance table.
        dists = torch.cdist(self._xyz[None], cam_origins[None].to(device))[0]
        # torch.quantile needs q in the same dtype as its input.
        quantiles = torch.tensor(nearfar_percent, dtype=dists.dtype, device=device)
        dists_percentile = torch.quantile(dists, quantiles, dim=0)

        reject_mask = (dists < dists_percentile[0:1, :]) | (
            dists > dists_percentile[1:2, :]
        )
        return self.filter(~reject_mask.any(dim=1))

    def apply_all_filters(
        self,
        opacity_thres=0.05,
        scaling_thres=None,
        floater_thres=None,
        crop_bbx=(-1, 1, -1, 1, -1, 1),
        cam_origins=None,
        nearfar_percent=(0.005, 1.0),
    ):
        """Run the pruning steps in a fixed order and return ``self``.

        Opacity pruning always runs. Each other step is skipped when its
        controlling argument is ``None``.
        """
        self.prune(opacity_thres)
        if scaling_thres is not None:
            self.prune_by_scaling(scaling_thres)
        if floater_thres is not None:
            self.crop_by_xyz(floater_thres)
        if crop_bbx is not None:
            self.crop(crop_bbx)
        if cam_origins is not None:
            self.prune_by_nearfar(cam_origins, nearfar_percent)
        return self

    def shrink_bbx(self, drop_ratio=0.05):
        """Crop to the per-axis ``[drop_ratio, 1 - drop_ratio]`` quantile box."""
        xyz = self._xyz
        xyz_min, xyz_max = torch.quantile(
            xyz,
            torch.tensor(
                [drop_ratio, 1 - drop_ratio], dtype=xyz.dtype, device=xyz.device
            ),
            dim=0,
        )
        xyz_min = xyz_min.detach().cpu().numpy()
        xyz_max = xyz_max.detach().cpu().numpy()
        crop_bbx = [
            xyz_min[0],
            xyz_max[0],
            xyz_min[1],
            xyz_max[1],
            xyz_min[2],
            xyz_max[2],
        ]
        print(f"Shrinking bbx to {crop_bbx}")
        return self.crop(crop_bbx)

    @staticmethod
    def _print_range(label, tensor):
        """Print the shape, minimum and maximum of ``tensor`` after ``label``."""
        print(f"{label}: {tensor.shape}, {tensor.min().item()}, {tensor.max().item()}")

    def report_stats(self):
        """Print shape/min/max of every raw tensor and every activated value."""
        self._print_range("xyz", self._xyz)
        self._print_range("features_dc", self._features_dc)
        if self.sh_degree > 0:
            self._print_range("features_rest", self._features_rest)
        self._print_range("scaling", self._scaling)
        self._print_range("rotation", self._rotation)
        self._print_range("opacity", self._opacity)

        self._print_range("after activation, xyz", self.get_xyz)
        self._print_range("after activation, features", self.get_features)
        self._print_range("after activation, scaling", self.get_scaling)
        self._print_range("after activation, rotation", self.get_rotation)
        self._print_range("after activation, opacity", self.get_opacity)
        # Built once; the covariance is the most expensive value reported.
        self._print_range("after activation, covariance", self.get_covariance())

    @property
    def get_scaling(self):
        """Activated scales, multiplied by ``scaling_modifier`` when it is set."""
        scales = self.scaling_activation(self._scaling)
        if self.scaling_modifier is not None:
            return scales * self.scaling_modifier
        return scales

    @property
    def get_rotation(self):
        """Rotations passed through the rotation activation (normalisation)."""
        return self.rotation_activation(self._rotation)

    @property
    def get_xyz(self):
        """Gaussian centres (stored without an activation)."""
        return self._xyz

    @property
    def get_features(self):
        """All SH features: DC followed by the rest, concatenated along dim 1."""
        if self.sh_degree > 0:
            return torch.cat((self._features_dc, self._features_rest), dim=1)
        return self._features_dc

    @property
    def get_opacity(self):
        """Opacities after the sigmoid activation."""
        return self.opacity_activation(self._opacity)

    def get_covariance(self, scaling_modifier=1):
        """Packed covariances from the activated scales and raw rotations."""
        return self.covariance_activation(
            self.get_scaling, scaling_modifier, self._rotation
        )

    def construct_dtypes(self, use_fp16=False, enable_gs_viewer=True):
        """Return the structured-dtype field list of the PLY vertex element.

        Field order: x, y, z, red, green, blue, ``f_dc_*``, ``f_rest_*``,
        opacity, ``scale_*``, ``rot_*``. Colours are ``u1``; everything else
        is ``f2`` when ``use_fp16`` is set and ``f4`` otherwise.

        With ``enable_gs_viewer`` the number of ``f_rest_*`` fields is padded
        to the degree-3 count, matching the padding ``save_ply`` applies to
        the data. The degree <= 3 check is enforced for fp32 output only, as
        it was before fp16 output honoured the flag.
        """
        real = "f2" if use_fp16 else "f4"
        fields = [
            ("x", real),
            ("y", real),
            ("z", real),
            ("red", "u1"),
            ("green", "u1"),
            ("blue", "u1"),
        ]

        num_dc = self._features_dc.shape[1] * self._features_dc.shape[2]
        fields.extend((f"f_dc_{i}", real) for i in range(num_dc))

        num_rest = 0
        if self.sh_degree > 0:
            num_rest = self._features_rest.shape[1] * self._features_rest.shape[2]
        if enable_gs_viewer:
            if not use_fp16:
                assert self.sh_degree <= 3, "GS viewer only supports SH up to degree 3"
            viewer_rest = ((self._VIEWER_SH_DEGREE + 1) ** 2 - 1) * 3
            num_rest = max(num_rest, viewer_rest)
        fields.extend((f"f_rest_{i}", real) for i in range(num_rest))

        fields.append(("opacity", real))
        fields.extend((f"scale_{i}", real) for i in range(self._scaling.shape[1]))
        fields.extend((f"rot_{i}", real) for i in range(self._rotation.shape[1]))
        return fields

    def _rest_features_for_export(self, num_points, enable_gs_viewer):
        """Return the flattened non-DC SH features for PLY export, or ``None``.

        The layout is channel-major, i.e. all coefficients of channel 0, then
        channel 1, ... -- the layout ``load_ply`` reshapes back to
        ``(N, 3, K)``. Viewer padding is therefore applied per channel before
        flattening; padding the flattened array would shift the channel
        boundaries for SH degrees 1 and 2.
        """
        rest = None
        if self.sh_degree > 0:
            # (N, K, C) -> (N, C, K)
            rest = self._features_rest.detach().transpose(1, 2).contiguous()
            rest = rest.cpu().numpy()

        if enable_gs_viewer:
            viewer_k = (self._VIEWER_SH_DEGREE + 1) ** 2 - 1
            if rest is None:
                rest = np.zeros((num_points, 3, viewer_k), dtype=np.float32)
            elif rest.shape[2] < viewer_k:
                padded = np.zeros(
                    (rest.shape[0], rest.shape[1], viewer_k), dtype=np.float32
                )
                padded[:, :, : rest.shape[2]] = rest
                rest = padded

        if rest is None:
            return None
        return rest.reshape(rest.shape[0], rest.shape[1] * rest.shape[2])

    def save_ply(
        self,
        path,
        use_fp16=False,
        enable_gs_viewer=True,
        color_code=False,
        filter_mask=None,
    ):
        """Write the model to a PLY file.

        Args:
            path: output file; its parent directory is created if needed.
            use_fp16: store floating-point fields as ``f2`` instead of ``f4``.
            enable_gs_viewer: pad the non-DC SH features with zeros up to SH
                degree 3.
            color_code: colour points with the viridis colormap by index
                instead of deriving RGB from the DC features.
            filter_mask: optional index/mask selecting the rows to write.

        Scales are written in log space; when ``scaling_modifier`` is set it
        is baked into the stored values.
        """
        out_dir = os.path.dirname(path)
        if out_dir:
            # dirname is "" for a bare filename and makedirs("") raises.
            os.makedirs(out_dir, exist_ok=True)

        xyz = self._xyz.detach().cpu().numpy()
        num_points = xyz.shape[0]
        f_dc = (
            self._features_dc.detach()
            .transpose(1, 2)
            .flatten(start_dim=1)
            .contiguous()
            .cpu()
            .numpy()
        )
        if color_code:
            index = np.linspace(0, 1, num_points)
            rgb = matplotlib.colormaps["viridis"](index)[..., :3]
            rgb = (rgb * 255.0).clip(0.0, 255.0).astype(np.uint8)
        else:
            rgb = (SH2RGB(f_dc) * 255.0).clip(0.0, 255.0).astype(np.uint8)

        opacities = self._opacity.detach().cpu().numpy()
        if self.scaling_modifier is not None:
            scale = self.inv_scaling_activation(self.get_scaling).detach().cpu().numpy()
        else:
            scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()

        columns = [xyz, rgb, f_dc]
        f_rest = self._rest_features_for_export(num_points, enable_gs_viewer)
        if f_rest is not None:
            columns.append(f_rest)
        columns.extend((opacities, scale, rotation))
        attributes = np.concatenate(columns, axis=1)

        if filter_mask is not None:
            attributes = attributes[filter_mask]

        dtype_full = self.construct_dtypes(use_fp16, enable_gs_viewer)
        elements = np.empty(attributes.shape[0], dtype=dtype_full)
        elements[:] = list(map(tuple, attributes))
        PlyData([PlyElement.describe(elements, "vertex")]).write(path)

    def load_ply(self, path):
        """Load raw tensors from a PLY file (tensors end up on the CPU).

        The file must provide x/y/z, opacity, ``f_dc_0..2``, ``scale_*`` and
        ``rot*`` properties. When ``sh_degree > 0`` it must also provide
        exactly ``3 * (sh_degree + 1) ** 2 - 3`` ``f_rest_*`` properties, so a
        file padded for the viewer can only be read back with the degree it
        was padded to. With ``sh_degree == 0`` the ``f_rest_*`` properties are
        ignored.
        """
        vertex = PlyData.read(path).elements[0]

        def numbered(prefix):
            # Property names with this prefix, ordered by their numeric suffix.
            names = [p.name for p in vertex.properties if p.name.startswith(prefix)]
            return sorted(names, key=lambda name: int(name.split("_")[-1]))

        def table(names, num_rows):
            out = np.zeros((num_rows, len(names)))
            for idx, attr_name in enumerate(names):
                out[:, idx] = np.asarray(vertex[attr_name])
            return out

        xyz = np.stack(
            (
                np.asarray(vertex["x"]),
                np.asarray(vertex["y"]),
                np.asarray(vertex["z"]),
            ),
            axis=1,
        )
        num_points = xyz.shape[0]
        opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

        features_dc = np.zeros((num_points, 3, 1))
        for channel in range(3):
            features_dc[:, channel, 0] = np.asarray(vertex[f"f_dc_{channel}"])

        if self.sh_degree > 0:
            extra_f_names = numbered("f_rest_")
            assert len(extra_f_names) == 3 * (self.sh_degree + 1) ** 2 - 3
            # Stored channel-major: reshape to (N, 3, K).
            features_extra = table(extra_f_names, num_points).reshape(
                (num_points, 3, (self.sh_degree + 1) ** 2 - 1)
            )

        scales = table(numbered("scale_"), num_points)
        rots = table(numbered("rot"), num_points)

        self._xyz = torch.from_numpy(xyz.astype(np.float32))
        self._features_dc = (
            torch.from_numpy(features_dc.astype(np.float32))
            .transpose(1, 2)
            .contiguous()
        )
        if self.sh_degree > 0:
            self._features_rest = (
                torch.from_numpy(features_extra.astype(np.float32))
                .transpose(1, 2)
                .contiguous()
            )
        self._opacity = torch.from_numpy(
            np.copy(opacities).astype(np.float32)
        ).contiguous()
        self._scaling = torch.from_numpy(scales.astype(np.float32)).contiguous()
        self._rotation = torch.from_numpy(rots.astype(np.float32)).contiguous()
|
|
|
|
def render_opencv_cam(
    pc: GaussianModel,
    height: int,
    width: int,
    C2W: torch.Tensor,
    fxfycxcy: torch.Tensor,
    bg_color=(1.0, 1.0, 1.0),
    scaling_modifier=1.0,
):
    """Render a Gaussian model from a single camera.

    Args:
        pc: model to render; its activated tensors are passed to the
            rasterizer.
        height, width: output image size in pixels.
        C2W: 4x4 camera-to-world matrix (opencv convention, see ``Camera``).
        fxfycxcy: tensor with the four intrinsics (fx, fy, cx, cy).
        bg_color: background colour as three floats; it is converted to a
            float32 tensor on ``C2W``'s device here.
        scaling_modifier: forwarded to the rasterizer as ``scale_modifier``.

    Returns:
        Dict with ``"render"`` (the rasterized image), ``"viewspace_points"``
        (the placeholder passed as ``means2D``), ``"radii"`` and
        ``"visibility_filter"`` (``radii > 0``).
    """
    means3D = pc.get_xyz

    # Placeholder passed as means2D and returned as "viewspace_points".
    # It is zero-initialised and created on the model's own device, so models
    # on a non-default GPU do not end up with tensors on two devices.
    screenspace_points = torch.zeros_like(
        means3D, dtype=means3D.dtype, requires_grad=True
    )

    viewpoint_camera = Camera(C2W=C2W, fxfycxcy=fxfycxcy, h=height, w=width)

    bg_color = torch.tensor(list(bg_color), dtype=torch.float32, device=C2W.device)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.h),
        image_width=int(viewpoint_camera.w),
        tanfovx=viewpoint_camera.tanfovX,
        tanfovy=viewpoint_camera.tanfovY,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    # Colours come from the SH features and covariances from scales and
    # rotations, so no precomputed colours or covariances are supplied.
    rendered_image, radii = rasterizer(
        means3D=means3D,
        means2D=screenspace_points,
        shs=pc.get_features,
        colors_precomp=None,
        opacities=pc.get_opacity,
        scales=pc.get_scaling,
        rotations=pc.get_rotation,
        cov3D_precomp=None,
    )

    return {
        "render": rendered_image,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
    }
|
|
|
|
class DeferredGaussianRender(torch.autograd.Function):
    """Autograd function that renders every view without keeping a graph.

    ``forward`` renders under ``no_grad`` and saves only the inputs.
    ``backward`` re-renders one view at a time with gradients enabled and
    accumulates the gradients of the five Gaussian tensors.
    """

    @staticmethod
    def forward(
        ctx,
        xyz,
        features,
        scaling,
        rotation,
        opacity,
        height,
        width,
        C2W,
        fxfycxcy,
        scaling_modifier=None,
    ):
        """
        xyz: [b, n_gaussians, 3]
        features: [b, n_gaussians, (sh_degree+1)^2, 3]
        scaling: [b, n_gaussians, 3]
        rotation: [b, n_gaussians, 4]
        opacity: [b, n_gaussians, 1]

        height: int
        width: int
        C2W: [b, v, 4, 4]
        fxfycxcy: [b, v, 4]

        output: [b, v, 3, height, width]
        """
        ctx.scaling_modifier = scaling_modifier

        # The SH degree follows from the number of coefficients per Gaussian.
        sh_degree = int(math.sqrt(features.shape[-2])) - 1

        # One reusable container; set_data swaps in each batch item.
        model = GaussianModel(sh_degree, scaling_modifier)

        with torch.no_grad():
            num_scenes, num_views = C2W.size(0), C2W.size(1)
            images = []
            for scene in range(num_scenes):
                pc = model.set_data(
                    xyz[scene],
                    features[scene],
                    scaling[scene],
                    rotation[scene],
                    opacity[scene],
                )
                for view in range(num_views):
                    result = render_opencv_cam(
                        pc, height, width, C2W[scene, view], fxfycxcy[scene, view]
                    )
                    images.append(result["render"])
            renders = torch.stack(images, dim=0).reshape(
                num_scenes, num_views, 3, height, width
            )

        renders = renders.requires_grad_()

        # Everything backward needs to repeat the rendering.
        ctx.save_for_backward(xyz, features, scaling, rotation, opacity, C2W, fxfycxcy)
        ctx.rendering_size = (height, width)
        ctx.sh_degree = sh_degree

        return renders

    @staticmethod
    def backward(ctx, grad_output):
        """Re-render each view with grad enabled and back-propagate through it.

        Returns gradients for xyz, features, scaling, rotation and opacity,
        followed by ``None`` for the five remaining forward arguments.
        """
        xyz, features, scaling, rotation, opacity, C2W, fxfycxcy = ctx.saved_tensors
        height, width = ctx.rendering_size

        # Fresh leaf copies so that .grad accumulates across all views.
        names = ("xyz", "features", "scaling", "rotation", "opacity")
        tensors = (xyz, features, scaling, rotation, opacity)
        leaves = {
            name: tensor.detach().requires_grad_()
            for name, tensor in zip(names, tensors)
        }

        model = GaussianModel(ctx.sh_degree, ctx.scaling_modifier)

        with torch.enable_grad():
            num_scenes, num_views = C2W.size(0), C2W.size(1)
            for scene in range(num_scenes):
                for view in range(num_views):
                    # Rebuilt per view: each backward call frees its own graph.
                    pc = model.set_data(
                        **{name: leaf[scene] for name, leaf in leaves.items()}
                    )
                    image = render_opencv_cam(
                        pc, height, width, C2W[scene, view], fxfycxcy[scene, view]
                    )["render"]
                    image.backward(grad_output[scene, view])

        grads = tuple(leaf.grad for leaf in leaves.values())
        return grads + (None, None, None, None, None)
|
|
|
|
| |
# Functional entry point: call with the same positional arguments as
# DeferredGaussianRender.forward, minus ``ctx``.
deferred_gaussian_render = DeferredGaussianRender.apply
|
|
@torch.no_grad()
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def render_turntable(pc: GaussianModel, rendering_resolution=384, num_views=8):
    """Render ``pc`` from a ring of cameras at elevation 0.

    Returns:
        ``uint8`` array of shape ``(h, num_views * w, 3)``: the square
        ``rendering_resolution`` views placed side by side, left to right.
    """
    width, height, view_count, intrinsics, poses = get_turntable_cameras(
        h=rendering_resolution,
        w=rendering_resolution,
        num_views=num_views,
        elevation=0,
    )

    device = pc._xyz.device
    intrinsics = torch.from_numpy(intrinsics).float().to(device)
    poses = torch.from_numpy(poses).float().to(device)

    frames = torch.zeros(
        view_count, 3, height, width, dtype=torch.float32, device=device
    )
    for idx in range(view_count):
        rendered = render_opencv_cam(pc, height, width, poses[idx], intrinsics[idx])
        frames[idx] = rendered["render"]
    torch.cuda.empty_cache()

    # Scale to 8-bit and lay the views out as one horizontal strip.
    strip = frames.detach().cpu().numpy()
    strip = (strip * 255).clip(0, 255).astype(np.uint8)
    return rearrange(strip, "v c h w -> h (v w) c")
|
|
|
|
# Debug entry point: download a sample point cloud and camera trajectory,
# render every frame to a PNG and assemble the PNGs into a video.
if __name__ == "__main__":
    import json
    import subprocess

    from PIL import Image
    from tqdm import tqdm

    out_dir = "/mnt/localssd/debug-3dgs"
    os.makedirs(out_dir, exist_ok=True)

    downloads = {
        "point_cloud.ply": "https://phidias.s3.us-west-2.amazonaws.com/kaiz/neural-capture/eval-3dgs-lowres/AWS_test_set/results/1.fashion_boots_rubber_boots__short__Feb_21__2023_at_5_19_25_PM_yf/point_cloud/iteration_30000_fg/point_cloud.ply",
        "opencv_cameras_traj_norm.json": "https://neural-capture.s3.us-west-2.amazonaws.com/data/AWS_test_set/preprocessed/1.fashion_boots_rubber_boots__short__Feb_21__2023_at_5_19_25_PM_yf/opencv_cameras_traj_norm.json",
    }
    for file_name, url in downloads.items():
        # Argument list (no shell) and check=True: a failed download stops the
        # script here instead of surfacing later as a parse error.
        subprocess.run(["wget", url, "-O", f"{out_dir}/{file_name}"], check=True)

    device = "cuda:0"

    pc = GaussianModel(sh_degree=3)
    pc.load_ply(f"{out_dir}/point_cloud.ply")
    pc = pc.to(device)

    pc.report_stats()

    with open(f"{out_dir}/opencv_cameras_traj_norm.json", "r") as f:
        cam_traj = json.load(f)

    frames = cam_traj["frames"]
    for i, cam in tqdm(
        enumerate(frames), total=len(frames), desc="Rendering progress"
    ):
        w2c = np.array(cam["w2c"])
        c2w = np.linalg.inv(w2c)
        c2w = torch.from_numpy(c2w.astype(np.float32)).to(device)

        fx = cam["fx"]
        fy = cam["fy"]
        # NOTE(review): hard-coded principal-point offsets; their origin is
        # not visible in this file -- confirm before reusing on other data.
        cx = cam["cx"] - 5
        cy = cam["cy"] + 4
        fxfycxcy = torch.tensor([fx, fy, cx, cy], dtype=torch.float32, device=device)

        h = cam["h"]
        w = cam["w"]

        im = render_opencv_cam(pc, h, w, c2w, fxfycxcy, bg_color=[0.0, 0.0, 0.0])[
            "render"
        ]
        im = im.detach().cpu().numpy().transpose(1, 2, 0)
        # Clip before the cast: out-of-range values would wrap around in uint8.
        im = (im * 255).clip(0, 255).astype(np.uint8)
        Image.fromarray(im).save(f"{out_dir}/render_{i:08d}.png")

    create_video(out_dir, f"{out_dir}/render.mp4", framerate=30)
    print(f"Saved {out_dir}/render.mp4")
|
|