SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import torch
from jaxtyping import Float
from torch import nn as nn, Tensor
class SlicedG3RNorm(nn.Module):
    """Max-abs ("G3R") normalization applied to a subset of the feature channels.

    Each selected channel is divided by its largest absolute value over the
    batch dimension (plus ``eps``); channels outside the slice are passed
    through unchanged.
    """

    def __init__(self, num_features, input_slice, eps=1e-8):
        """
        Args:
            num_features (int): Total number of features (channels) in the input.
            input_slice (slice): Channel range to normalize; everything outside it is left as is.
            eps (float): Small constant added to the per-channel maximum to prevent division by zero.
        """
        super().__init__()
        self.eps = eps
        self.input_slice = input_slice
        # Index a throwaway row to find out how many channels the slice selects.
        probe = torch.zeros(1, num_features)
        self.slice_size = probe[:, input_slice].shape[-1]

    def forward(self, x: Float[Tensor, "B C"]):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features
        Returns:
            Tensor: New tensor of the same shape with only the selected channels normalized
        """
        selected = x[:, self.input_slice]
        # Per-channel peak magnitude over the batch; detached so that no
        # gradient flows through the max operation.
        peak, _ = selected.abs().max(0, keepdim=True)
        scale = peak.detach() + self.eps
        # Work on a copy so the caller's tensor is never modified in place.
        out = x.clone()
        out[:, self.input_slice] = selected / scale
        return out
class SlicedBatchNorm1d(nn.Module):
    """Batch normalization applied to a subset of the feature channels.

    The selected channels are routed through an ``nn.BatchNorm1d``; all other
    channels are passed through unchanged.
    """

    def __init__(self, num_features, input_slice, eps=1e-8, affine=False, track_running_stats=True):
        """
        Args:
            num_features (int): Total number of features (channels) in the input.
            input_slice (slice): Channel range to normalize; everything outside it is left as is.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether the batch norm has learnable scale and bias.
            track_running_stats (bool): Forwarded to ``nn.BatchNorm1d``.
        """
        super().__init__()
        self.input_slice = input_slice
        self.eps = eps
        # Index a throwaway row to find out how many channels the slice selects.
        probe = torch.zeros(1, num_features)
        self.slice_size = probe[:, input_slice].shape[-1]
        # A single BatchNorm1d sized to the selected channels.
        self.slice_norm = nn.BatchNorm1d(
            self.slice_size,
            eps=eps,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features
        Returns:
            Tensor: New tensor of the same shape with only the selected channels normalized
        """
        # Two-value unpack doubles as a rank check: anything but 2-D raises.
        _batch, _channels = x.shape
        normalized = self.slice_norm(x[:, self.input_slice])
        # Write the result into a copy so the caller's tensor stays untouched.
        result = x.clone()
        result[:, self.input_slice] = normalized
        return result
class CustomGroupNorm(nn.Module):
    """Layer normalization applied independently to contiguous channel groups.

    The channel axis of a ``(B, C)`` input is split into consecutive groups of
    the given (possibly unequal) sizes; each group is normalized by its own
    ``nn.LayerNorm`` and the results are concatenated back in order.
    """

    def __init__(self, group_sizes, eps=1e-8, affine=True):
        """
        Args:
            group_sizes (list[int]): Channel count of each group. Must sum to the total input channels.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether to include learnable scale and bias per group.
        """
        super().__init__()
        # Materialize once so any iterable of ints (tuple, generator, ...) works.
        self.group_sizes = list(group_sizes)
        self.total_channels = sum(self.group_sizes)
        self.eps = eps
        # One LayerNorm per group, each normalizing over that group's channels.
        self.group_norms = nn.ModuleList([
            nn.LayerNorm([size], eps=eps, elementwise_affine=affine)
            for size in self.group_sizes
        ])

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) with C == sum(group_sizes)
        Returns:
            Tensor: Same shape, group-wise normalized
        Raises:
            ValueError: If ``x`` is not 2-D.
            AssertionError: If the channel count does not match ``group_sizes``.
        """
        # Each LayerNorm normalizes the last axis, so the channel axis must be
        # the last one: only 2-D (B, C) inputs are meaningful here.
        if x.ndim != 2:
            raise ValueError(
                f"CustomGroupNorm expects a 2-D input of shape (B, C), got shape {tuple(x.shape)}"
            )
        C = x.shape[1]
        assert C == self.total_channels, (
            f"Input has {C} channels, expected {self.total_channels} from group sizes {self.group_sizes}"
        )
        # Split input into channel groups and normalize each with its own module.
        splits = torch.split(x, self.group_sizes, dim=1)
        normed = [norm(group) for norm, group in zip(self.group_norms, splits)]
        return torch.cat(normed, dim=1)
class AdamState:
    """Plain container bundling the three pieces of Adam-style running state."""

    def __init__(self, m, v, t):
        # m: first moment vector, v: second moment vector, t: time step
        self.m, self.v, self.t = m, v, t
def slice_length(s, dim):
    """Return how many elements slice ``s`` selects from an axis of length ``dim``.

    Args:
        s (slice): The slice to measure. ``None`` fields take Python's usual
            defaults; a step of ``0`` is treated as ``1``.
        dim (int): Length of the axis being sliced (must be non-negative).
    Returns:
        int: Number of selected elements, identical to ``len(range(dim)[s])``
        for any non-zero step.
    """
    # slice.indices() applies Python's own normalization and clamping rules,
    # including for negative steps, where an omitted stop means "run past
    # index 0" and must not be wrapped around by adding ``dim``.
    normalized = slice(s.start, s.stop, s.step or 1)
    return len(range(*normalized.indices(dim)))
@torch.compile(dynamic=True)
def _adam_smooth_unmasked(m, v, t, chunk, beta1, beta2, eps) -> Tensor:
    """Fused moment update + bias-corrected output for the unmasked path.

    Updates the running moments ``m`` and ``v`` IN PLACE from ``chunk`` and
    returns the bias-corrected, Adam-normalized value as a new tensor.
    ``t`` is only read here; it is not incremented by this function.

    Args:
        m: First-moment tensor, overwritten in place.
        v: Second-moment tensor, overwritten in place.
        t: Step counts; reshaped to ``(t.shape[0], 1, ..., 1)`` so one count
            broadcasts over each entry along the first axis of ``m``.
        chunk: New input values used to update the moments.
        beta1: Decay rate of the first moment.
        beta2: Decay rate of the second moment.
        eps: Constant added to the denominator.
    Returns:
        Tensor: ``(m / (1 - beta1**t)) / (sqrt(v) / sqrt(1 - beta2**t) + eps)``
        evaluated with the updated moments.
    """
    # m <- m + (1 - beta1) * (chunk - m)
    m.lerp_(chunk, 1 - beta1)
    # v <- beta2 * v + (1 - beta2) * chunk**2
    v.mul_(beta2).addcmul_(chunk, chunk, value=1 - beta2)
    # Give t as many dims as m so it broadcasts over the trailing axes.
    t_bc = t.reshape(t.shape[0], *([1] * (m.ndim - 1)))
    bias1 = 1 - beta1 ** t_bc
    bias2_sqrt = (1 - beta2 ** t_bc).sqrt_()
    # v.sqrt() allocates a new tensor, so the in-place ops that follow do not touch v.
    denom = v.sqrt().div_(bias2_sqrt).add_(eps)
    # m.div() allocates a new tensor, so m itself keeps the uncorrected moment.
    return m.div(bias1).div_(denom)
@torch.compile(dynamic=True)
def _adam_smooth_masked(m, v, t, sel, chunk, beta1, beta2, eps) -> Tensor:
    """Fused moment update + bias-corrected output for the masked path.

    Only the entries of ``m``, ``v`` and ``t`` picked by ``sel`` are updated
    (written back into the given tensors); the step count of those entries is
    incremented here, inside the kernel.

    Args:
        m: First-moment tensor; entries at ``sel`` are overwritten.
        v: Second-moment tensor; entries at ``sel`` are overwritten.
        t: Step counts; entries at ``sel`` are incremented by one.
        sel: Index applied along the first axis of ``m``, ``v`` and ``t``
            (presumably a boolean mask or index tensor; confirm with caller).
        chunk: New input values for the selected entries; must be compatible
            with the shape of ``m[sel]``.
        beta1: Decay rate of the first moment.
        beta2: Decay rate of the second moment.
        eps: Constant added to the denominator.
    Returns:
        Tensor: ``m_hat / (sqrt(v_hat) + eps)`` for the selected entries only,
        where ``m_hat`` and ``v_hat`` are the bias-corrected moments.
    """
    # Gather the selected rows, update them, then scatter the result back into m.
    m_sel = m[sel].lerp_(chunk, 1 - beta1)
    m[sel] = m_sel
    # Same gather / update / scatter pattern for the second moment.
    v_sel = v[sel].mul_(beta2).addcmul_(chunk, chunk, value=1 - beta2)
    v[sel] = v_sel
    # Advance the step count of the selected entries only.
    t[sel] += 1
    # Reshape to (n_selected, 1, ..., 1) so the counts broadcast over trailing axes.
    t_sel = t[sel].reshape(-1, *([1] * (m.ndim - 1)))
    m_hat = m_sel / (1 - beta1 ** t_sel)
    v_hat = v_sel / (1 - beta2 ** t_sel)
    return m_hat / (torch.sqrt(v_hat) + eps)
class AdamInputSmoothing(nn.Module):
def __init__(self, beta1=0.9, beta2=0.999, eps=1e-15, input_slice: slice | None = None,
shape: tuple | None = None,
device=None):
"""
Implements Adam-like smoothing for input vectors.
Args:
beta1 (float): Exponential decay rate for the first moment estimates.
beta2 (float): Exponential decay rate for the second moment estimates.
eps (float): Small constant to prevent division by zero.
input_slice (slice, optional): If provided, only apply smoothing to this slice of the input.
"""
super().__init__()
self.beta1 = beta1
self.beta2 = beta2
self.eps = eps
self.input_slice: slice | None = input_slice
if self.input_slice is not None:
assert isinstance(self.input_slice, slice), "input_slice must be a slice or None"
# Initialize first and second moment vectors
if shape is None:
self.reset()
else:
self.initialize(shape,
device=device)
def forward(self, x: Tensor) -> Tensor:
"""
Apply Adam-like smoothing to the input.
Args:
x (Tensor): Input tensor of shape (..., input_dim)
Returns:
Tensor: Smoothed tensor of same shape as input
"""
# Select the relevant slice of the input
chunk = x[..., self.input_slice] if self.input_slice is not None else x
# Initialize internal state if needed
if self.is_reset():
self.initialize(chunk.shape, device=chunk.device)
chunk_detached = chunk.detach()
if self.sel is None:
# Increment step first (matches PyTorch Adam convention)
self.t += 1
# Fused moment update + bias-corrected output (compiled kernel)
chunk_smoothed = _adam_smooth_unmasked(self.m, self.v, self.t, chunk_detached,
self.beta1, self.beta2, self.eps)
else:
# Fused masked update (compiled kernel)
chunk_smoothed = _adam_smooth_masked(self.m, self.v, self.t, self.sel, chunk_detached,
self.beta1, self.beta2, self.eps)
# Replace in original tensor
if self.input_slice is not None:
output_shape = slice_length(self.input_slice, x.shape[-1])
if output_shape == x.shape[-1]:
x_out = chunk_smoothed
else:
# only replace a slice, so we need to clone to avoid modifying input
x_out = x.clone()
x_out[..., self.input_slice] = chunk_smoothed
else:
# we overwrite the whole tensor, no need to clone
x_out = chunk_smoothed
return x_out
def reset(self):
"""Reset the internal state."""
self.m = torch.tensor(0, dtype=torch.float32)
self.v = torch.tensor(0, dtype=torch.float32)
self.t = torch.tensor(0, dtype=torch.int64)
self.sel = None
def initialize(self, shape, device) -> None:
"""Initialize the internal state with zeros for the given number of elements and input dimension."""
self.m = torch.zeros(shape, dtype=torch.float32, device=device)
self.v = torch.zeros(shape, dtype=torch.float32, device=device)
self.t = torch.zeros(shape[0], dtype=torch.int64, device=device)
self.sel = None
def update_state(self, adam_state: AdamState) -> None:
"""Update the internal state with provided values."""
m, v, t = adam_state.m, adam_state.v, adam_state.t
self.m = m
self.v = v
self.t = t
self.sel = None
def prune(self, prune_mask: Tensor) -> None:
"""Prune the internal state to only keep entries at the specified indices."""
assert not self.is_reset(), (
"Cannot prune state that has not been initialized. Call forward() at least once first."
)
sel = torch.where(~prune_mask)[0]
self.m = self.m[sel]
self.v = self.v[sel]
self.t = self.t[sel]
if self.sel is not None:
self.sel = self.sel[sel]
def zero_out(self, zero_t=False) -> None:
"""Zero out the moments. Called when resetting gaussians opacities."""
assert not self.is_reset(), (
"Cannot extend state that has not been initialized. Call forward() at least once first."
)
self.m = torch.zeros_like(self.m)
self.v = torch.zeros_like(self.v)
if zero_t:
self.t = torch.zeros_like(self.t)
def replace(self, from_indices: Tensor, dest_indices: Tensor, zero_t=False) -> None:
"""Replace the internal state to duplicate entries at the specified indices."""
assert not self.is_reset(), (
"Cannot extend state that has not been initialized. Call forward() at least once first."
)
self.m[dest_indices] = self.m[from_indices]
self.v[dest_indices] = self.v[from_indices]
if zero_t:
self.t[dest_indices] = 0
else:
self.t[dest_indices] = self.t[from_indices]
def clone(self, clone_mask: Tensor, zero_t=False) -> None:
"""Clone the internal state to duplicate entries at the specified indices."""
assert not self.is_reset(), (
"Cannot extend state that has not been initialized. Call forward() at least once first."
)
num_new_rows = clone_mask.sum()
new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
if zero_t:
new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
else:
sel = torch.where(clone_mask)[0]
new_t = self.t[sel]
self.m = torch.cat([self.m, new_zeros], dim=0)
self.v = torch.cat([self.v, new_zeros], dim=0)
self.t = torch.cat([self.t, new_t], dim=0)
def add(self, nr_new: int) -> None:
"""Add new entries to the internal state."""
assert not self.is_reset(), (
"Cannot extend state that has not been initialized. Call forward() at least once first."
)
new_zeros = torch.zeros((nr_new, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
new_t = torch.zeros((nr_new, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
self.m = torch.cat([self.m, new_zeros], dim=0)
self.v = torch.cat([self.v, new_zeros], dim=0)
self.t = torch.cat([self.t, new_t], dim=0)
def split(self, split_mask: Tensor, N: int, zero_t=False) -> None:
"""Split the internal state to duplicate entries at the specified indices."""
assert not self.is_reset(), (
"Cannot extend state that has not been initialized. Call forward() at least once first."
)
# Count how many new rows we need
num_new_rows = split_mask.sum() * N
# Handle t depending on zero_t flag
if zero_t:
new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
else:
# Only t needs to copy repeated original values
sel = torch.where(split_mask)[0]
new_t = self.t[sel].repeat_interleave(N, dim=0)
rest_sel = torch.where(~split_mask)[0]
# Preallocate zeros directly for m and v
new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
self.m = torch.cat([self.m[rest_sel], new_zeros], dim=0)
self.v = torch.cat([self.v[rest_sel], new_zeros], dim=0)
self.t = torch.cat([self.t[rest_sel], new_t], dim=0)
def get_state(self) -> AdamState:
"""Get the current internal state."""
return AdamState(self.m, self.v, self.t)
def subgroups_view(self, slices: dict[str, slice]) -> dict[str, "AdamInputSmoothing"]:
"""
Create lightweight subgroups that share memory with the main tensor states.
Args:
slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.:
{"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)}
Returns:
dict[str, AdamInputSmoothing]: Submodules that share self.m and self.v tensors.
"""
if not hasattr(self, "m") or self.m.ndim == 0:
raise RuntimeError("Cannot create subgroups before the first forward() call.")
subgroups = {}
for name, slc in slices.items():
sub = AdamInputSmoothing(
beta1=self.beta1,
beta2=self.beta2,
eps=self.eps,
input_slice=None
)
# share the same memory (not copy)
sub.m = self.m[..., slc]
sub.v = self.v[..., slc]
sub.t = self.t # shared time step
subgroups[name] = sub
return subgroups
def aggregate_from_subgroups(self, subgroups: dict[str, "AdamInputSmoothing"], slices: dict[str, slice]) -> None:
"""
Aggregate states from subgroups back into the main module.
Args:
subgroups (dict[str, AdamInputSmoothing]): Submodules created via subgroups_view.
slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.:
{"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)}
"""
if not hasattr(self, "m") or self.is_reset():
raise RuntimeError("Cannot aggregate states before the first forward() call.")
# Adjust stats shape
first_m_val = next(iter(subgroups.values())).m
if self.m.shape[:-1] != first_m_val.shape[:-1]:
self.m = torch.zeros((*first_m_val.shape[:-1], self.m.shape[-1]), dtype=first_m_val.dtype,
device=first_m_val.device)
self.v = torch.zeros((*first_m_val.shape[:-1], self.v.shape[-1]), dtype=first_m_val.dtype,
device=first_m_val.device)
for name, slc in slices.items():
sub = subgroups[name]
self.m[..., slc] = sub.m
self.v[..., slc] = sub.v
# Assume time step is the same across all subgroups
self.t = next(iter(subgroups.values())).t
def is_reset(self) -> bool:
"""Check if the internal state is reset."""
assert self.m.shape == self.v.shape, "First and second moment vectors must have the same shape."
return bool(self.m.ndim == 0 and self.v.ndim == 0 and self.t == 0)