from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None):
    """
    Process network output to extract GS (Gaussian splatting) parameters and
    density values. The density can be view-dependent when represented as SH
    coefficients, in which case conf_dim > 1.

    Args:
        out: Network output tensor (B, C, H, W)
        activation: Activation type for 3D points
        conf_activation: Activation type for confidence/density values
        conf_dim: Number of trailing channels holding confidence values (default: 1)

    Returns:
        Tuple of (3D points tensor, confidence tensor)
    """
    # Move channels last: (B, C, H, W) -> (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)

    # Split the feature map into point channels and trailing confidence channels.
    conf_dim = 1 if conf_dim is None else conf_dim
    xyz = fmap[:, :, :, :-conf_dim]
    conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:]

    if activation == "norm_exp":
        # Keep the direction of xyz, but map its norm through expm1
        # (exponential depth encoding). The clamp guards against zero-norm vectors.
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif activation == "norm":
        # Clamp mirrors the norm_exp branch to avoid division by zero.
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    elif activation == "exp":
        pts3d = torch.exp(xyz)
    elif activation == "relu":
        pts3d = F.relu(xyz)
    elif activation == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif activation == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown activation: {activation}")

    if conf_activation == "expp1":
        # Maps confidence to (1, inf), keeping it strictly above 1.
        conf_out = 1 + conf.exp()
    elif conf_activation == "expp0":
        conf_out = conf.exp()
    elif conf_activation == "sigmoid":
        conf_out = torch.sigmoid(conf)
    elif conf_activation == "linear":
        conf_out = conf
    else:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")

    return pts3d, conf_out
|
|
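# Usage sketch for activate_head_gs (hypothetical shapes: a head emitting
# 3 xyz channels plus 1 confidence channel):
#
#   out = torch.randn(2, 4, 64, 64)                        # (B, C, H, W)
#   pts3d, conf = activate_head_gs(out, activation="norm_exp")
#   # pts3d: (2, 64, 64, 3), conf: (2, 64, 64)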
|
|
class Permute(nn.Module):
    """nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage."""

    dims: Tuple[int, ...]

    def __init__(self, dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.permute(*self.dims)
|
|
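# Usage sketch for Permute (layer sizes are illustrative): lets a
# channels-last op live inside nn.Sequential without a functional wrapper:
#
#   head = nn.Sequential(
#       Permute((0, 2, 3, 1)),   # (B, C, H, W) -> (B, H, W, C)
#       nn.Linear(256, 32),      # acts on the trailing channel dim
#       Permute((0, 3, 1, 2)),   # back to (B, C, H, W)
#   )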
|
|
def position_grid_to_embed(
    pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) to sinusoidal embeddings (H, W, embed_dim).

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings
        omega_0: Base frequency for the sinusoidal embedding

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Embed x and y coordinates separately, each with half the channels.
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # (H*W, D/2)
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # (H*W, D/2)

    # Concatenate and reshape back to the grid layout.
    emb = torch.cat([emb_x, emb_y], dim=-1)

    return emb.view(H, W, embed_dim)
|
|
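# Usage sketch for position_grid_to_embed (illustrative sizes):
#
#   grid = create_uv_grid(64, 48, dtype=torch.float32)   # (48, 64, 2)
#   pos_embed = position_grid_to_embed(grid, embed_dim=128)
#   # pos_embed: (48, 64, 128)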
|
|
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sinusoidal positional embedding from a grid of positions.

    Args:
        embed_dim: The embedding dimension (must be even).
        pos: The positions to generate the embedding from.
        omega_0: Base frequency controlling the range of wavelengths.

    Returns:
        Tensor of shape (len(pos), embed_dim) with the positional embedding.
    """
    assert embed_dim % 2 == 0
    # Frequencies spaced geometrically from 1 down toward 1 / omega_0.
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2) outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()
|
|
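# Usage sketch for make_sincos_pos_embed (illustrative values):
#
#   pos = torch.tensor([0.0, 0.5, 1.0])
#   emb = make_sincos_pos_embed(8, pos)   # (3, 8): 4 sine + 4 cosine channels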
|
|
def create_uv_grid(
    width: int,
    height: int,
    aspect_ratio: Union[float, None] = None,
    dtype: Union[torch.dtype, None] = None,
    device: Union[torch.device, None] = None,
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The grid spans horizontally and vertically according to an aspect ratio,
    ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
    corner is at (x_span, y_span), normalized by the diagonal of the plane.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor of UV coordinates.
    """
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Normalize the spans by the half-diagonal so the full-span corner has unit norm.
    diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag_factor
    span_y = 1.0 / diag_factor

    # Shrink each span by (n - 1) / n so samples land at pixel centers.
    left_x = -span_x * (width - 1) / width
    right_x = span_x * (width - 1) / width
    top_y = -span_y * (height - 1) / height
    bottom_y = span_y * (height - 1) / height

    x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
    y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)

    # With indexing="xy", the resulting grids have shape (height, width).
    uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
    uv_grid = torch.stack((uu, vv), dim=-1)

    return uv_grid
|
|
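# Usage sketch for create_uv_grid (illustrative sizes):
#
#   uv = create_uv_grid(64, 48)    # (48, 64, 2)
#   # The extreme corner direction has (almost) unit norm:
#   # uv.reshape(-1, 2).norm(dim=-1).max() -> just under 1.0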
|
|
def custom_interpolate(
    x: torch.Tensor,
    size: Union[Tuple[int, int], None] = None,
    scale_factor: Union[float, None] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Safe interpolation that avoids INT_MAX overflow in torch.nn.functional.interpolate.

    If the number of output elements exceeds a conservative threshold, the input
    is split along the batch dimension, interpolated chunk by chunk, and
    re-concatenated.
    """
    if size is None:
        assert scale_factor is not None, "Either size or scale_factor must be provided."
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    # Conservative element budget (1.5 * 2^30), safely below the int32 limit.
    INT_MAX = 1610612736
    total = size[0] * size[1] * x.shape[0] * x.shape[1]

    if total > INT_MAX:
        # Split along the batch dimension so each chunk stays under the budget.
        chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
        outs = [
            nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners)
            for c in chunks
        ]
        return torch.cat(outs, dim=0).contiguous()

    return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
|
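# Usage sketch for custom_interpolate (illustrative shapes): a drop-in
# replacement for F.interpolate when outputs may be very large:
#
#   feats = torch.randn(8, 128, 37, 37)
#   up = custom_interpolate(feats, size=(518, 518), mode="bilinear", align_corners=False)
#   # up: (8, 128, 518, 518)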
|