from dataclasses import dataclass

import torch
from jaxtyping import Float
from torch import Tensor, nn


class SlicedG3RNorm(nn.Module):
    def __init__(self, num_features, input_slice, eps=1e-8):
        """
        Apply G3R normalization to a slice of the input features.

        Divide each selected channel by its maximum absolute value over the
        batch dimension. Channels outside the slice pass through unchanged.

        Args:
            num_features (int): Total number of features (channels).
            input_slice (slice): Channels (along dim 1) to normalize.
            eps (float): Small constant added to the per-channel maximum to
                prevent division by zero.
        """
        super().__init__()
        self.input_slice = input_slice
        # Slice a dummy tensor so the slice is validated and its length is
        # resolved with exactly the semantics torch indexing uses in forward().
        dummy = torch.zeros(1, num_features)
        chunk = dummy[:, input_slice]
        self.slice_size = chunk.shape[-1]
        self.eps = eps

    def forward(self, x: Float[Tensor, "B C"]):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features.

        Returns:
            Tensor: Same shape as ``x``; only the channels selected by
            ``input_slice`` are normalized.
        """
        # Split input into the slice to normalize and the rest
        chunk = x[:, self.input_slice]
        # Per-channel max |value| over the batch (dim 0).
        # Detached to avoid backpropagating through the max operation.
        max_val_per_channel = chunk.abs().max(0, keepdim=True)[0].detach() + self.eps
        # Clone so the caller's tensor is not modified in place, then write
        # the normalized slice back.
        x = x.clone()
        x[:, self.input_slice] = chunk / max_val_per_channel
        return x


class SlicedBatchNorm1d(nn.Module):
    def __init__(self, num_features, input_slice, eps=1e-8, affine=False, track_running_stats=True):
        """
        Apply ``nn.BatchNorm1d`` to a slice of the input features.

        Args:
            num_features (int): Total number of features (channels).
            input_slice (slice): Channels (along dim 1) to normalize.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether the BatchNorm has learnable scale and bias.
            track_running_stats (bool): Passed through to ``nn.BatchNorm1d``.
        """
        super().__init__()
        self.input_slice = input_slice
        # Resolve the slice length with torch indexing semantics.
        dummy = torch.zeros(1, num_features)
        chunk = dummy[:, input_slice]
        self.slice_size = chunk.shape[-1]
        self.eps = eps
        # One BatchNorm1d covering all channels of the slice.
        self.slice_norm = nn.BatchNorm1d(
            self.slice_size, eps=eps, affine=affine, track_running_stats=track_running_stats
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features.

        Returns:
            Tensor: Same shape, only the channels selected by ``input_slice``
            normalized.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected input of shape (B, C), got {tuple(x.shape)}")
        # Split input into the slice to normalize and the rest
        chunk = x[:, self.input_slice]
        chunk = self.slice_norm(chunk)
        # Clone so the caller's tensor is not modified in place.
        x = x.clone()
        x[:, self.input_slice] = chunk
        return x


class CustomGroupNorm(nn.Module):
    def __init__(self, group_sizes, eps=1e-8, affine=True):
        """
        Apply LayerNorm independently to consecutive groups of channels.

        Args:
            group_sizes (list[int]): List of channel counts for each group.
                Must sum to total input channels.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether to include learnable scale and bias per group.
        """
        super().__init__()
        self.group_sizes = group_sizes
        self.total_channels = sum(group_sizes)
        self.eps = eps

        # One LayerNorm per group, normalizing over that group's channels.
        self.group_norms = nn.ModuleList([
            nn.LayerNorm([size], eps=eps, elementwise_affine=affine)
            for size in group_sizes
        ])

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) with C == sum(group_sizes).

        Returns:
            Tensor: Same shape, group-wise normalized.

        Raises:
            ValueError: If ``x`` is not 2-D.
        """
        if x.ndim != 2:
            raise ValueError(f"Expected input of shape (B, C), got {tuple(x.shape)}")
        C = x.shape[1]
        assert C == self.total_channels, (
            f"Input has {C} channels, expected {self.total_channels} from group sizes {self.group_sizes}"
        )

        # Split input into channel groups and normalize each with its own LayerNorm.
        splits = torch.split(x, self.group_sizes, dim=1)
        return torch.cat([norm(g) for norm, g in zip(self.group_norms, splits)], dim=1)


@dataclass(eq=False)  # eq=False keeps identity-based __eq__/__hash__ for tensor fields
class AdamState:
    """Snapshot of an AdamInputSmoothing state."""

    m: Tensor  # First moment vector
    v: Tensor  # Second moment vector
    t: Tensor  # Time step


def slice_length(s: slice, dim: int) -> int:
    """
    Return how many elements ``s`` selects from a sequence of length ``dim``.

    Uses Python's own slice resolution, so negative indices, out-of-range
    bounds, and negative steps behave exactly like ``list(range(dim))[s]``.

    Raises:
        ValueError: If ``s.step`` is 0.
    """
    return len(range(dim)[s])


@torch.compile(dynamic=True)
def _adam_smooth_unmasked(m, v, t, chunk, beta1, beta2, eps) -> Tensor:
    """
    Fused moment update + bias-corrected output for the unmasked path.

    Updates ``m`` and ``v`` in place. ``t`` must already have been incremented
    by the caller; it holds one step count per row (dim 0) and is broadcast
    over the remaining dims of ``m``.
    """
    # m <- beta1 * m + (1 - beta1) * chunk
    m.lerp_(chunk, 1 - beta1)
    # v <- beta2 * v + (1 - beta2) * chunk**2
    v.mul_(beta2).addcmul_(chunk, chunk, value=1 - beta2)
    t_bc = t.reshape(t.shape[0], *([1] * (m.ndim - 1)))
    bias1 = 1 - beta1 ** t_bc
    bias2_sqrt = (1 - beta2 ** t_bc).sqrt_()
    # sqrt(v_hat) + eps, where v_hat = v / (1 - beta2**t)
    denom = v.sqrt().div_(bias2_sqrt).add_(eps)
    return m.div(bias1).div_(denom)


@torch.compile(dynamic=True)
def _adam_smooth_masked(m, v, t, sel, chunk, beta1, beta2, eps) -> Tensor:
    """
    Fused moment update + bias-corrected output for the masked path.

    Only the rows ``m[sel]`` / ``v[sel]`` / ``t[sel]`` are updated (written back
    in place); ``t[sel]`` is incremented here. ``chunk`` must match the shape
    of ``m[sel]``. Returns the smoothed values for the selected rows only.
    """
    m_sel = m[sel].lerp_(chunk, 1 - beta1)
    m[sel] = m_sel
    v_sel = v[sel].mul_(beta2).addcmul_(chunk, chunk, value=1 - beta2)
    v[sel] = v_sel
    t[sel] += 1
    t_sel = t[sel].reshape(-1, *([1] * (m.ndim - 1)))
    m_hat = m_sel / (1 - beta1 ** t_sel)
    v_hat = v_sel / (1 - beta2 ** t_sel)
    return m_hat / (torch.sqrt(v_hat) + eps)


class AdamInputSmoothing(nn.Module):
    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-15, input_slice: slice | None = None,
                 shape: tuple | None = None, device=None):
        """
        Implements Adam-like smoothing for input vectors.

        Keeps first/second moment estimates ``m``/``v`` and a per-row step
        counter ``t``. Each forward() replaces the selected part of the input
        with the bias-corrected Adam direction ``m_hat / (sqrt(v_hat) + eps)``.
        The state is stored as plain tensor attributes (not registered buffers).

        Args:
            beta1 (float): Exponential decay rate for the first moment estimates.
            beta2 (float): Exponential decay rate for the second moment estimates.
            eps (float): Small constant to prevent division by zero.
            input_slice (slice, optional): If provided, only apply smoothing to
                this slice of the last input dimension.
            shape (tuple, optional): If provided, allocate zeroed state of this
                shape immediately; otherwise state is allocated lazily on the
                first forward().
            device (optional): Device for the state allocated from ``shape``.
        """
        super().__init__()
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.input_slice: slice | None = input_slice
        if self.input_slice is not None:
            assert isinstance(self.input_slice, slice), "input_slice must be a slice or None"

        # Initialize first and second moment vectors
        if shape is None:
            self.reset()
        else:
            self.initialize(shape, device=device)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply Adam-like smoothing to the input.

        Args:
            x (Tensor): Input tensor of shape (..., input_dim). Dim 0 indexes
                the rows that own the step counters ``t``. When ``self.sel`` is
                set, x presumably holds only the selected rows (it must match
                ``m[sel]``) — NOTE(review): confirm against the caller that sets sel.

        Returns:
            Tensor: Smoothed tensor of same shape as input. The smoothed part is
            computed from a detached copy of the input, so no gradient flows
            through it.
        """
        # Select the relevant slice of the input
        chunk = x[..., self.input_slice] if self.input_slice is not None else x

        # Lazily allocate state matching the first input seen.
        if self.is_reset():
            self.initialize(chunk.shape, device=chunk.device)

        chunk_detached = chunk.detach()
        if self.sel is None:
            # Increment step first (matches PyTorch Adam convention)
            self.t += 1
            # Fused moment update + bias-corrected output (compiled kernel)
            chunk_smoothed = _adam_smooth_unmasked(self.m, self.v, self.t, chunk_detached,
                                                   self.beta1, self.beta2, self.eps)
        else:
            # Fused masked update (compiled kernel); it increments t[sel] itself.
            chunk_smoothed = _adam_smooth_masked(self.m, self.v, self.t, self.sel, chunk_detached,
                                                 self.beta1, self.beta2, self.eps)

        if self.input_slice is None:
            # The whole tensor was smoothed; no need to clone.
            return chunk_smoothed

        slice_len = slice_length(self.input_slice, x.shape[-1])
        if slice_len == x.shape[-1]:
            # The slice covers every column of the last dim.
            return chunk_smoothed

        # Only part of the last dim is replaced; clone to avoid modifying the input.
        x_out = x.clone()
        x_out[..., self.input_slice] = chunk_smoothed
        return x_out

    def reset(self) -> None:
        """Reset the internal state to 0-d placeholders (see is_reset())."""
        self.m = torch.tensor(0, dtype=torch.float32)
        self.v = torch.tensor(0, dtype=torch.float32)
        self.t = torch.tensor(0, dtype=torch.int64)
        self.sel = None

    def initialize(self, shape, device) -> None:
        """Allocate zeroed moments of ``shape`` and one int64 step counter per row (``shape[0]``)."""
        self.m = torch.zeros(shape, dtype=torch.float32, device=device)
        self.v = torch.zeros(shape, dtype=torch.float32, device=device)
        self.t = torch.zeros(shape[0], dtype=torch.int64, device=device)
        self.sel = None

    def update_state(self, adam_state: AdamState) -> None:
        """Adopt the tensors of ``adam_state`` (no copy) and clear the selection mask."""
        self.m = adam_state.m
        self.v = adam_state.v
        self.t = adam_state.t
        self.sel = None

    def prune(self, prune_mask: Tensor) -> None:
        """
        Remove the rows where ``prune_mask`` is True, keeping the rest in order.

        Args:
            prune_mask (Tensor): Boolean mask over rows (dim 0); True marks rows
                to drop.
        """
        assert not self.is_reset(), (
            "Cannot prune state that has not been initialized. Call forward() at least once first."
        )
        keep = torch.where(~prune_mask)[0]
        self.m = self.m[keep]
        self.v = self.v[keep]
        self.t = self.t[keep]
        if self.sel is not None:
            self.sel = self.sel[keep]

    def zero_out(self, zero_t=False) -> None:
        """Zero out the moments (and optionally the step counters). Called when resetting gaussians opacities."""
        assert not self.is_reset(), (
            "Cannot zero out state that has not been initialized. Call forward() at least once first."
        )
        self.m = torch.zeros_like(self.m)
        self.v = torch.zeros_like(self.v)
        if zero_t:
            self.t = torch.zeros_like(self.t)

    def replace(self, from_indices: Tensor, dest_indices: Tensor, zero_t=False) -> None:
        """
        Overwrite the rows at ``dest_indices`` (in place) with copies of the rows
        at ``from_indices``. Step counters are copied too, or zeroed if ``zero_t``.
        """
        assert not self.is_reset(), (
            "Cannot replace state that has not been initialized. Call forward() at least once first."
        )
        self.m[dest_indices] = self.m[from_indices]
        self.v[dest_indices] = self.v[from_indices]
        if zero_t:
            self.t[dest_indices] = 0
        else:
            self.t[dest_indices] = self.t[from_indices]

    def clone(self, clone_mask: Tensor, zero_t=False) -> None:
        """
        Append one new row per True entry of ``clone_mask``.

        New rows get zeroed moments; their step counter is copied from the
        source row, or zeroed if ``zero_t``.
        NOTE(review): ``self.sel`` is not extended here; if a selection mask is
        active it no longer matches the row count — confirm this is intended.
        """
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        num_new_rows = clone_mask.sum()
        new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        if zero_t:
            new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        else:
            src = torch.where(clone_mask)[0]
            new_t = self.t[src]
        self.m = torch.cat([self.m, new_zeros], dim=0)
        self.v = torch.cat([self.v, new_zeros], dim=0)
        self.t = torch.cat([self.t, new_t], dim=0)

    def add(self, nr_new: int) -> None:
        """
        Append ``nr_new`` rows with zeroed moments and zeroed step counters.

        NOTE(review): ``self.sel`` is not extended here.
        """
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        new_zeros = torch.zeros((nr_new, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        new_t = torch.zeros((nr_new, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        self.m = torch.cat([self.m, new_zeros], dim=0)
        self.v = torch.cat([self.v, new_zeros], dim=0)
        self.t = torch.cat([self.t, new_t], dim=0)

    def split(self, split_mask: Tensor, N: int, zero_t=False) -> None:
        """
        Replace each row where ``split_mask`` is True by ``N`` new rows.

        Unsplit rows keep their state and come first; the ``N`` rows per split
        row are appended afterwards with zeroed moments. Their step counter is
        repeated from the source row, or zeroed if ``zero_t``.
        NOTE(review): ``self.sel`` is not updated here.
        """
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        # Count how many new rows we need
        num_new_rows = split_mask.sum() * N

        if zero_t:
            new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        else:
            # Each split row contributes N consecutive copies of its t.
            src = torch.where(split_mask)[0]
            new_t = self.t[src].repeat_interleave(N, dim=0)

        rest_sel = torch.where(~split_mask)[0]

        # Moments of the new rows start from zero.
        new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        self.m = torch.cat([self.m[rest_sel], new_zeros], dim=0)
        self.v = torch.cat([self.v[rest_sel], new_zeros], dim=0)
        self.t = torch.cat([self.t[rest_sel], new_t], dim=0)

    def get_state(self) -> AdamState:
        """Return the current state tensors (not copies)."""
        return AdamState(self.m, self.v, self.t)

    def subgroups_view(self, slices: dict[str, slice]) -> dict[str, "AdamInputSmoothing"]:
        """
        Create lightweight subgroups that share memory with the main tensor states.

        Args:
            slices (dict[str, slice]): Mapping from subgroup name to slice of the
                last dim, e.g.:
                {"means": slice(0, 3), "scale": slice(3, 6), "rotation": slice(6, 10),
                 "opacity": slice(10, 11), "sh": slice(11, 59)}

        Returns:
            dict[str, AdamInputSmoothing]: Submodules whose m/v are views of
            self.m/self.v and whose t is the same tensor object as self.t.

        Raises:
            RuntimeError: If called before the state has been initialized.
        """
        if not hasattr(self, "m") or self.m.ndim == 0:
            raise RuntimeError("Cannot create subgroups before the first forward() call.")

        subgroups = {}
        for name, slc in slices.items():
            sub = AdamInputSmoothing(
                beta1=self.beta1, beta2=self.beta2, eps=self.eps, input_slice=None
            )
            # share the same memory (not copy)
            sub.m = self.m[..., slc]
            sub.v = self.v[..., slc]
            # Shared time step. NOTE(review): forward() increments t in place, so
            # each subgroup's forward() advances the shared counter — confirm this
            # is intended when several subgroups are stepped per iteration.
            sub.t = self.t
            subgroups[name] = sub
        return subgroups

    def aggregate_from_subgroups(self, subgroups: dict[str, "AdamInputSmoothing"],
                                 slices: dict[str, slice]) -> None:
        """
        Aggregate states from subgroups back into the main module.

        Args:
            subgroups (dict[str, AdamInputSmoothing]): Submodules created via subgroups_view.
            slices (dict[str, slice]): Mapping from subgroup name to slice of the
                last dim, e.g.:
                {"means": slice(0, 3), "scale": slice(3, 6), "rotation": slice(6, 10),
                 "opacity": slice(10, 11), "sh": slice(11, 59)}

        Raises:
            RuntimeError: If called before the state has been initialized.
            ValueError: If ``subgroups`` is empty.
        """
        if not hasattr(self, "m") or self.is_reset():
            raise RuntimeError("Cannot aggregate states before the first forward() call.")
        if not subgroups:
            raise ValueError("subgroups must not be empty")

        first_sub = next(iter(subgroups.values()))
        first_m_val = first_sub.m
        # Subgroups may have been pruned/extended independently; reallocate
        # the main state when the leading (row) dims no longer match.
        if self.m.shape[:-1] != first_m_val.shape[:-1]:
            self.m = torch.zeros((*first_m_val.shape[:-1], self.m.shape[-1]),
                                 dtype=first_m_val.dtype, device=first_m_val.device)
            self.v = torch.zeros((*first_m_val.shape[:-1], self.v.shape[-1]),
                                 dtype=first_m_val.dtype, device=first_m_val.device)

        for name, slc in slices.items():
            sub = subgroups[name]
            self.m[..., slc] = sub.m
            self.v[..., slc] = sub.v

        # Assume time step is the same across all subgroups
        self.t = first_sub.t

    def is_reset(self) -> bool:
        """Return True when the state holds only the 0-d placeholders written by reset()."""
        assert self.m.shape == self.v.shape, "First and second moment vectors must have the same shape."
        return bool(self.m.ndim == 0 and self.v.ndim == 0 and self.t == 0)