import random
from typing import Optional

import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import colormaps


class CameraOptModule(torch.nn.Module):
    """Camera pose optimization module."""

    def __init__(self, n: int):
        super().__init__()
        # Delta positions (3) + delta rotations in 6D representation (6).
        self.embeds = torch.nn.Embedding(n, 9)
        # The identity rotation in 6D representation.
        self.register_buffer("identity", torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def zero_init(self):
        torch.nn.init.zeros_(self.embeds.weight)

    def random_init(self, std: float):
        torch.nn.init.normal_(self.embeds.weight, std=std)

    def forward(self, camtoworlds: Tensor, embed_ids: Tensor) -> Tensor:
        """Adjust camera pose based on deltas.

        Args:
            camtoworlds: (..., 4, 4)
            embed_ids: (...,)

        Returns:
            updated camtoworlds: (..., 4, 4)
        """
        assert camtoworlds.shape[:-2] == embed_ids.shape
        batch_shape = camtoworlds.shape[:-2]
        pose_deltas = self.embeds(embed_ids)  # (..., 9)
        dx, drot = pose_deltas[..., :3], pose_deltas[..., 3:]
        rot = rotation_6d_to_matrix(
            drot + self.identity.expand(*batch_shape, -1)
        )  # (..., 3, 3)
        transform = torch.eye(4, device=pose_deltas.device).repeat((*batch_shape, 1, 1))
        transform[..., :3, :3] = rot
        transform[..., :3, 3] = dx
        return torch.matmul(camtoworlds, transform)
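

# A minimal usage sketch of CameraOptModule; the sizes are illustrative. With
# zero-initialized embeddings the adjustment is an exact no-op, which the
# assertion below verifies.
def _demo_camera_opt() -> None:
    pose_adjust = CameraOptModule(n=8)
    pose_adjust.zero_init()
    camtoworlds = torch.eye(4).expand(8, 4, 4)  # (8, 4, 4) initial poses
    image_ids = torch.arange(8)  # one embedding per image
    adjusted = pose_adjust(camtoworlds, image_ids)
    assert torch.allclose(adjusted, camtoworlds)

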
class AppearanceOptModule(torch.nn.Module):
    """Appearance optimization module."""

    def __init__(
        self,
        n: int,
        feature_dim: int,
        embed_dim: int = 16,
        sh_degree: int = 3,
        mlp_width: int = 64,
        mlp_depth: int = 2,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.sh_degree = sh_degree
        self.embeds = torch.nn.Embedding(n, embed_dim)
        layers = []
        layers.append(
            torch.nn.Linear(embed_dim + feature_dim + (sh_degree + 1) ** 2, mlp_width)
        )
        layers.append(torch.nn.ReLU(inplace=True))
        for _ in range(mlp_depth - 1):
            layers.append(torch.nn.Linear(mlp_width, mlp_width))
            layers.append(torch.nn.ReLU(inplace=True))
        layers.append(torch.nn.Linear(mlp_width, 3))
        self.color_head = torch.nn.Sequential(*layers)

    def forward(
        self, features: Tensor, embed_ids: Tensor, dirs: Tensor, sh_degree: int
    ) -> Tensor:
        """Adjust appearance based on embeddings.

        Args:
            features: (N, feature_dim)
            embed_ids: (C,) or None, in which case zero embeddings are used
            dirs: (C, N, 3)
            sh_degree: degree of the SH bases to evaluate (at most self.sh_degree)

        Returns:
            colors: (C, N, 3)
        """
        from gsplat.cuda._torch_impl import _eval_sh_bases_fast

        C, N = dirs.shape[:2]
        # Camera embeddings
        if embed_ids is None:
            embeds = torch.zeros(C, self.embed_dim, device=features.device)
        else:
            embeds = self.embeds(embed_ids)  # (C, embed_dim)
        embeds = embeds[:, None, :].expand(-1, N, -1)  # (C, N, embed_dim)
        # Gaussian features
        features = features[None, :, :].expand(C, -1, -1)  # (C, N, feature_dim)
        # View directions, encoded with SH bases up to the requested degree
        dirs = F.normalize(dirs, dim=-1)  # (C, N, 3)
        num_bases_to_use = (sh_degree + 1) ** 2
        num_bases = (self.sh_degree + 1) ** 2
        sh_bases = torch.zeros(C, N, num_bases, device=features.device)
        sh_bases[:, :, :num_bases_to_use] = _eval_sh_bases_fast(num_bases_to_use, dirs)
        # Predict colors
        if self.embed_dim > 0:
            h = torch.cat([embeds, features, sh_bases], dim=-1)
        else:
            h = torch.cat([features, sh_bases], dim=-1)
        colors = self.color_head(h)
        return colors
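

# A minimal usage sketch of AppearanceOptModule; the feature dimension and
# tensor sizes below are illustrative. Requires gsplat for the SH basis
# evaluation inside forward().
def _demo_appearance_opt() -> None:
    module = AppearanceOptModule(n=10, feature_dim=32)
    features = torch.randn(500, 32)  # (N, feature_dim) per-Gaussian features
    embed_ids = torch.arange(2)  # (C,) one appearance embedding per image
    dirs = torch.randn(2, 500, 3)  # (C, N, 3) camera-to-Gaussian directions
    colors = module(features, embed_ids, dirs, sh_degree=3)
    assert colors.shape == (2, 500, 3)

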
def rotation_6d_to_matrix(d6: Tensor) -> Tensor:
    """
    Converts the 6D rotation representation by Zhou et al. [1] to a rotation
    matrix using the Gram-Schmidt orthogonalization described in Section B of
    [1]. Adapted from pytorch3d.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
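

# A small self-check sketch for rotation_6d_to_matrix: the "identity" 6D vector
# maps to the identity matrix, and any input yields an orthonormal matrix with
# determinant +1.
def _demo_rotation_6d() -> None:
    identity_6d = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
    assert torch.allclose(rotation_6d_to_matrix(identity_6d), torch.eye(3))
    R = rotation_6d_to_matrix(torch.randn(6))
    assert torch.allclose(R @ R.T, torch.eye(3), atol=1e-5)
    assert torch.allclose(torch.linalg.det(R), torch.tensor(1.0), atol=1e-5)

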
def knn(x: Tensor, K: int = 4) -> Tensor:
    """Distances from each point in x to its K nearest neighbors (within x).

    Note: since the query points are the fitted points themselves, column 0 is
    each point's zero distance to itself.
    """
    x_np = x.cpu().numpy()
    model = NearestNeighbors(n_neighbors=K, metric="euclidean").fit(x_np)
    distances, _ = model.kneighbors(x_np)
    return torch.from_numpy(distances).to(x)
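

# A usage sketch for knn in the style of common 3D Gaussian Splatting
# initialization (illustrative, not necessarily this codebase's exact recipe):
# derive per-point scales from the mean nearest-neighbor distance, dropping the
# self-distance in column 0.
def _demo_knn_scale_init() -> None:
    points = torch.rand(1000, 3)
    dist = knn(points, K=4)[:, 1:]  # (1000, 3) neighbor distances, self excluded
    mean_dist = dist.mean(dim=-1)  # (1000,)
    log_scales = torch.log(mean_dist)[:, None].repeat(1, 3)  # isotropic init
    assert log_scales.shape == (1000, 3)

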
def rgb_to_sh(rgb: Tensor) -> Tensor:
    """Convert RGB in [0, 1] to the DC (degree-0) SH coefficient."""
    C0 = 0.28209479177387814  # zeroth SH basis Y_0^0 = 1 / (2 * sqrt(pi))
    return (rgb - 0.5) / C0
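

# For reference, a hypothetical inverse of rgb_to_sh (not part of the original
# module): evaluate only the DC term and shift back to [0, 1].
def _sh_to_rgb(sh: Tensor) -> Tensor:
    C0 = 0.28209479177387814  # zeroth SH basis Y_0^0 = 1 / (2 * sqrt(pi))
    return sh * C0 + 0.5

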
def set_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def colormap(img, cmap="jet"):
    """Render a (H, W) scalar image to a (3, H', W') color tensor with a colorbar.

    Note: the output size is determined by the rendered matplotlib figure, not
    by the input size.
    """
    H, W = img.shape[:2]
    dpi = 300
    fig, ax = plt.subplots(1, figsize=(W / dpi, H / dpi), dpi=dpi)
    im = ax.imshow(img, cmap=cmap)
    ax.set_axis_off()
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.canvas.draw()
    # tostring_rgb() was removed in Matplotlib 3.10; buffer_rgba() is available
    # across versions, so take the RGB channels from it instead.
    data = np.asarray(fig.canvas.buffer_rgba())[..., :3]
    img = torch.from_numpy(data.copy()).float().permute(2, 0, 1)
    plt.close(fig)
    return img
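

# A usage sketch for colormap (illustrative): visualize a scalar error map and
# convert the result to a uint8 HWC image for saving.
def _demo_colormap() -> None:
    err = np.random.rand(120, 160)  # (H, W) scalar map
    img = colormap(err, cmap="jet")  # (3, H', W') float tensor in [0, 255]
    img_u8 = img.byte().permute(1, 2, 0).numpy()  # (H', W', 3) uint8
    assert img_u8.shape[-1] == 3

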
def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
    """Convert a single-channel image to a color image.

    Args:
        img (torch.Tensor): (..., 1) float32 single-channel image with values
            in [0, 1].
        colormap (str): Colormap for img.

    Returns:
        (..., 3) colored img with colors in [0, 1].
    """
    img = torch.nan_to_num(img, 0)
    if colormap == "gray":
        # Note: the gray path assumes a (H, W, 1) input.
        return img.repeat(1, 1, 3)
    img_long = (img * 255).long()
    img_long_min = torch.min(img_long)
    img_long_max = torch.max(img_long)
    assert img_long_min >= 0, f"the min value is {img_long_min}"
    assert img_long_max <= 255, f"the max value is {img_long_max}"
    return torch.tensor(
        colormaps[colormap].colors,
        device=img.device,
    )[img_long[..., 0]]
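

# A usage sketch for apply_float_colormap (illustrative): colorize a normalized
# scalar map, e.g. an accumulation or uncertainty image.
def _demo_float_colormap() -> None:
    heat = torch.rand(32, 32, 1)  # values in [0, 1]
    colored = apply_float_colormap(heat, colormap="turbo")
    assert colored.shape == (32, 32, 3)

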
def apply_depth_colormap(
    depth: torch.Tensor,
    acc: Optional[torch.Tensor] = None,
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
) -> torch.Tensor:
    """Converts a depth image to color for easier analysis.

    Args:
        depth (torch.Tensor): (..., 1) float32 depth.
        acc (torch.Tensor | None): (..., 1) optional accumulation mask.
        near_plane: Closest depth to consider. If None, use min image value.
        far_plane: Furthest depth to consider. If None, use max image value.

    Returns:
        (..., 3) colored depth image with colors in [0, 1].
    """
    # Explicit None checks so that a caller-provided 0.0 is respected
    # (`near_plane or ...` would silently discard it).
    near_plane = float(torch.min(depth)) if near_plane is None else near_plane
    far_plane = float(torch.max(depth)) if far_plane is None else far_plane
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0.0, 1.0)
    img = apply_float_colormap(depth, colormap="turbo")
    if acc is not None:
        # Blend toward white where accumulation (opacity) is low.
        img = img * acc + (1.0 - acc)
    return img
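

# A usage sketch for apply_depth_colormap (illustrative): colorize a rendered
# depth map with known near/far planes, fading low-opacity pixels to white.
def _demo_depth_colormap() -> None:
    depth = torch.rand(64, 64, 1) * 5.0
    acc = torch.rand(64, 64, 1)
    img = apply_depth_colormap(depth, acc=acc, near_plane=0.0, far_plane=5.0)
    assert img.shape == (64, 64, 3)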