Spaces:
Sleeping
Sleeping
File size: 16,972 Bytes
78d2329 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 | import torch
from jaxtyping import Float
from torch import nn as nn, Tensor
class SlicedG3RNorm(nn.Module):
    """G3R normalization applied to a selected slice of the feature axis.

    Every selected channel is divided by its largest absolute value taken over
    the batch axis; channels outside the slice pass through unchanged.
    """

    def __init__(self, num_features, input_slice, eps=1e-8):
        """
        Args:
            num_features (int): Total number of features (channels) of the input.
            input_slice (slice): Index into the channel axis selecting the channels to normalize.
            eps (float): Constant added to the per-channel maximum to prevent division by zero.
        """
        super().__init__()
        self.eps = eps
        self.input_slice = input_slice
        # Index a throwaway (1, num_features) row to learn how many channels
        # the slice selects.
        probe = torch.zeros(1, num_features)
        self.slice_size = probe[:, input_slice].shape[-1]

    def forward(self, x: Float[Tensor, "B C"]):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features
        Returns:
            Tensor: New tensor of the same shape; only the selected channels are rescaled.
        """
        selected = x[:, self.input_slice]
        # Largest magnitude of each selected channel over the batch axis.
        # Detached so that no gradient flows through the max operation.
        peak = selected.abs().max(dim=0, keepdim=True).values.detach()
        scale = peak + self.eps
        # Write the rescaled channels into a copy so the caller's tensor is untouched.
        out = x.clone()
        out[:, self.input_slice] = selected / scale
        return out
class SlicedBatchNorm1d(nn.Module):
    """Batch normalization applied to a selected slice of the feature axis.

    The selected channels go through a single ``nn.BatchNorm1d``; all other
    channels pass through unchanged.
    """

    def __init__(self, num_features, input_slice, eps=1e-8, affine=False, track_running_stats=True):
        """
        Args:
            num_features (int): Total number of features (channels) of the input.
            input_slice (slice): Index into the channel axis selecting the channels to normalize.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether the batch norm has a learnable scale and bias.
            track_running_stats (bool): Forwarded unchanged to ``nn.BatchNorm1d``.
        """
        super().__init__()
        self.eps = eps
        self.input_slice = input_slice
        # Index a throwaway (1, num_features) row to learn how many channels
        # the slice selects.
        probe = torch.zeros(1, num_features)
        self.slice_size = probe[:, input_slice].shape[-1]
        # One BatchNorm1d sized to the selected channels only.
        self.slice_norm = nn.BatchNorm1d(
            self.slice_size,
            eps=eps,
            affine=affine,
            track_running_stats=track_running_stats,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) where C = num_features
        Returns:
            Tensor: New tensor of the same shape; only the selected channels are normalized.
        """
        # The unpack doubles as a rank check: anything but a 2-D input fails here.
        _batch, _channels = x.shape
        normalized = self.slice_norm(x[:, self.input_slice])
        # Write the normalized channels into a copy so the caller's tensor is untouched.
        out = x.clone()
        out[:, self.input_slice] = normalized
        return out
class CustomGroupNorm(nn.Module):
    """Normalize consecutive channel groups of a (B, C) input independently.

    Each group owns an ``nn.LayerNorm`` over that group's channels, so every
    row is normalized group by group.
    """

    def __init__(self, group_sizes, eps=1e-8, affine=True):
        """
        Args:
            group_sizes (list[int]): Channel count of each group. Must sum to the total input channels.
            eps (float): Small constant to prevent division by zero.
            affine (bool): Whether each group has a learnable scale and bias.
        """
        super().__init__()
        self.eps = eps
        self.group_sizes = group_sizes
        self.total_channels = sum(group_sizes)
        # One LayerNorm per group, in group order.
        per_group = [
            nn.LayerNorm([width], eps=eps, elementwise_affine=affine)
            for width in group_sizes
        ]
        self.group_norms = nn.ModuleList(per_group)

    def forward(self, x):
        """
        Args:
            x (Tensor): Shape (B, C) with C equal to the sum of the group sizes
        Returns:
            Tensor: Same shape, group-wise normalized
        """
        _, C = x.shape
        assert C == self.total_channels, (
            f"Input has {C} channels, expected {self.total_channels} from group sizes {self.group_sizes}"
        )
        # Cut the channel axis into the configured groups, normalize each with
        # its own LayerNorm, then stitch them back together in order.
        pieces = torch.split(x, self.group_sizes, dim=1)
        return torch.cat(
            [norm(piece) for norm, piece in zip(self.group_norms, pieces)],
            dim=1,
        )
class AdamState:
    """Plain container bundling the three pieces of Adam-style state."""

    def __init__(self, m, v, t):
        # m: first moment vector, v: second moment vector, t: time step
        self.m, self.v, self.t = m, v, t
def slice_length(s, dim):
    """Return how many elements slice ``s`` selects from an axis of length ``dim``.

    The count is taken from ``slice.indices``, so defaults, negative bounds,
    out-of-range bounds and negative steps all follow Python's own slicing
    rules (e.g. ``slice(None, None, -1)`` selects every element).

    Args:
        s (slice): The slice to measure.
        dim (int): Length of the axis being sliced.
    Returns:
        int: Number of selected elements (0 for an empty selection).
    """
    # ``s.step or 1`` keeps this function's lenient handling of a missing (or
    # zero) step, which is treated as a step of 1.
    normalized = slice(s.start, s.stop, s.step or 1)
    return len(range(*normalized.indices(dim)))
@torch.compile(dynamic=True)
def _adam_smooth_unmasked(m, v, t, chunk, beta1, beta2, eps) -> Tensor:
    """Update both moments in place and return the bias-corrected ratio (unmasked path).

    ``m`` and ``v`` are modified in place; ``t`` is only read, so the caller
    is responsible for advancing it beforehand.
    """
    # Exponential moving averages of the input and of its square.
    m.lerp_(chunk, 1 - beta1)
    v.mul_(beta2)
    v.addcmul_(chunk, chunk, value=1 - beta2)

    # Give the per-row step counter trailing singleton axes so it broadcasts
    # against the moments.
    trailing_ones = [1] * (m.ndim - 1)
    steps = t.reshape(t.shape[0], *trailing_ones)
    bias1 = 1 - beta1 ** steps
    bias2 = 1 - beta2 ** steps
    bias2_sqrt = bias2.sqrt()

    # denom = sqrt(v) / sqrt(bias2) + eps, built on a temporary.
    denom = v.sqrt()
    denom.div_(bias2_sqrt)
    denom.add_(eps)

    smoothed = m.div(bias1)
    smoothed.div_(denom)
    return smoothed
@torch.compile(dynamic=True)
def _adam_smooth_masked(m, v, t, sel, chunk, beta1, beta2, eps) -> Tensor:
    """Update the selected rows of both moments and return their bias-corrected ratio (masked path).

    Only the rows of ``m``, ``v`` and ``t`` addressed by ``sel`` are touched;
    their step counters are advanced by one here.
    """
    # First moment of the selected rows: update, then write back.
    m_rows = m[sel]
    m_rows.lerp_(chunk, 1 - beta1)
    m[sel] = m_rows

    # Second moment of the selected rows: update, then write back.
    v_rows = v[sel]
    v_rows.mul_(beta2)
    v_rows.addcmul_(chunk, chunk, value=1 - beta2)
    v[sel] = v_rows

    t[sel] += 1

    # Give the step counter trailing singleton axes so it broadcasts against
    # the selected moments.
    trailing_ones = [1] * (m.ndim - 1)
    steps = t[sel].reshape(-1, *trailing_ones)
    bias1 = 1 - beta1 ** steps
    bias2 = 1 - beta2 ** steps
    m_hat = m_rows / bias1
    v_hat = v_rows / bias2
    return m_hat / (torch.sqrt(v_hat) + eps)
class AdamInputSmoothing(nn.Module):
    """Adam-like smoothing of input rows with resizable per-row state.

    The module keeps a first moment ``m``, a second moment ``v`` and a per-row
    step counter ``t`` (one entry per index of the first axis).  ``forward``
    updates the moments from the detached input and returns the bias-corrected
    ratio; the remaining methods reset, resize or re-slice that state along
    the first axis.

    ``m``, ``v``, ``t`` and ``sel`` are plain attributes, not registered
    buffers.  Inside this class ``sel`` is only ever set to ``None``; a
    non-``None`` value routes ``forward`` through the masked kernel.
    """

    def __init__(self, beta1=0.9, beta2=0.999, eps=1e-15, input_slice: slice | None = None,
                 shape: tuple | None = None,
                 device=None):
        """
        Implements Adam-like smoothing for input vectors.
        Args:
            beta1 (float): Exponential decay rate for the first moment estimates.
            beta2 (float): Exponential decay rate for the second moment estimates.
            eps (float): Small constant to prevent division by zero.
            input_slice (slice, optional): If provided, only apply smoothing to this slice of the last axis.
            shape (tuple, optional): If provided, allocate the state with this shape right away;
                otherwise it is allocated on the first forward() call.
            device: Device for the state allocated from ``shape``.
        """
        super().__init__()
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.input_slice: slice | None = input_slice
        if self.input_slice is not None:
            assert isinstance(self.input_slice, slice), "input_slice must be a slice or None"
        # Start either in the "reset" placeholder state or with allocated moments.
        if shape is None:
            self.reset()
        else:
            self.initialize(shape, device=device)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply Adam-like smoothing to the input.
        Args:
            x (Tensor): Input tensor of shape (..., input_dim)
        Returns:
            Tensor: Tensor of the same shape as the input.  The smoothed values are
            computed from the detached input, so they carry no gradient.
        """
        # Select the relevant slice of the input
        chunk = x[..., self.input_slice] if self.input_slice is not None else x
        # Allocate the state lazily, sized after the first chunk seen
        if self.is_reset():
            self.initialize(chunk.shape, device=chunk.device)
        chunk_detached = chunk.detach()
        if self.sel is None:
            # Increment step first (matches PyTorch Adam convention); the
            # unmasked kernel only reads t.
            self.t += 1
            chunk_smoothed = _adam_smooth_unmasked(self.m, self.v, self.t, chunk_detached,
                                                   self.beta1, self.beta2, self.eps)
        else:
            # The masked kernel advances t for the selected rows itself.
            chunk_smoothed = _adam_smooth_masked(self.m, self.v, self.t, self.sel, chunk_detached,
                                                 self.beta1, self.beta2, self.eps)
        if self.input_slice is None:
            # The whole tensor is replaced, no need to clone
            return chunk_smoothed
        selected_width = slice_length(self.input_slice, x.shape[-1])
        if selected_width == x.shape[-1]:
            # The slice covers the full last axis, so again nothing is left to keep
            return chunk_smoothed
        # Only part of the last axis is replaced: clone to avoid modifying the input
        x_out = x.clone()
        x_out[..., self.input_slice] = chunk_smoothed
        return x_out

    def reset(self):
        """Reset the internal state to 0-dim placeholders (see is_reset)."""
        self.m = torch.tensor(0, dtype=torch.float32)
        self.v = torch.tensor(0, dtype=torch.float32)
        self.t = torch.tensor(0, dtype=torch.int64)
        self.sel = None

    def initialize(self, shape, device) -> None:
        """Allocate zeroed moments of the given shape and one step counter per first-axis entry."""
        self.m = torch.zeros(shape, dtype=torch.float32, device=device)
        self.v = torch.zeros(shape, dtype=torch.float32, device=device)
        self.t = torch.zeros(shape[0], dtype=torch.int64, device=device)
        self.sel = None

    def update_state(self, adam_state: AdamState) -> None:
        """Adopt the tensors of ``adam_state`` as the internal state (no copy) and clear ``sel``."""
        self.m = adam_state.m
        self.v = adam_state.v
        self.t = adam_state.t
        self.sel = None

    def prune(self, prune_mask: Tensor) -> None:
        """Drop every row whose entry in ``prune_mask`` is True."""
        assert not self.is_reset(), (
            "Cannot prune state that has not been initialized. Call forward() at least once first."
        )
        keep = torch.where(~prune_mask)[0]
        self.m = self.m[keep]
        self.v = self.v[keep]
        self.t = self.t[keep]
        if self.sel is not None:
            # NOTE(review): sel is indexed with the kept row positions like the
            # other state tensors - confirm this matches how sel is defined.
            self.sel = self.sel[keep]

    def zero_out(self, zero_t=False) -> None:
        """Zero out the moments (and the step counters if ``zero_t``). Called when resetting gaussians opacities."""
        assert not self.is_reset(), (
            "Cannot zero out state that has not been initialized. Call forward() at least once first."
        )
        self.m = torch.zeros_like(self.m)
        self.v = torch.zeros_like(self.v)
        if zero_t:
            self.t = torch.zeros_like(self.t)

    def replace(self, from_indices: Tensor, dest_indices: Tensor, zero_t=False) -> None:
        """Overwrite the rows at ``dest_indices`` with the rows at ``from_indices``.

        The step counters of the destination rows are copied as well, or set to 0 if ``zero_t``.
        """
        assert not self.is_reset(), (
            "Cannot replace state that has not been initialized. Call forward() at least once first."
        )
        self.m[dest_indices] = self.m[from_indices]
        self.v[dest_indices] = self.v[from_indices]
        if zero_t:
            self.t[dest_indices] = 0
        else:
            self.t[dest_indices] = self.t[from_indices]

    def clone(self, clone_mask: Tensor, zero_t=False) -> None:
        """Append one new row per True entry of ``clone_mask``.

        New rows start with zero moments; their step counters are copied from
        the masked rows, or set to 0 if ``zero_t``.
        """
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        num_new_rows = int(clone_mask.sum())
        new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        if zero_t:
            new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        else:
            new_t = self.t[torch.where(clone_mask)[0]]
        self.m = torch.cat([self.m, new_zeros], dim=0)
        self.v = torch.cat([self.v, new_zeros], dim=0)
        self.t = torch.cat([self.t, new_t], dim=0)

    def add(self, nr_new: int) -> None:
        """Append ``nr_new`` rows with zero moments and zero step counters."""
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        new_zeros = torch.zeros((nr_new, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        new_t = torch.zeros((nr_new, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        self.m = torch.cat([self.m, new_zeros], dim=0)
        self.v = torch.cat([self.v, new_zeros], dim=0)
        self.t = torch.cat([self.t, new_t], dim=0)

    def split(self, split_mask: Tensor, N: int, zero_t=False) -> None:
        """Remove the rows selected by ``split_mask`` and append ``N`` new rows for each of them.

        New rows start with zero moments; their step counters repeat the
        removed row's counter ``N`` times, or are set to 0 if ``zero_t``.
        """
        assert not self.is_reset(), (
            "Cannot extend state that has not been initialized. Call forward() at least once first."
        )
        # Count how many new rows we need
        num_new_rows = int(split_mask.sum()) * N
        if zero_t:
            new_t = torch.zeros((num_new_rows, *self.t.shape[1:]), device=self.t.device, dtype=self.t.dtype)
        else:
            # Only t needs to copy repeated original values
            new_t = self.t[torch.where(split_mask)[0]].repeat_interleave(N, dim=0)
        rest_sel = torch.where(~split_mask)[0]
        # Moments of the new rows are plain zeros
        new_zeros = torch.zeros((num_new_rows, *self.m.shape[1:]), device=self.m.device, dtype=self.m.dtype)
        self.m = torch.cat([self.m[rest_sel], new_zeros], dim=0)
        self.v = torch.cat([self.v[rest_sel], new_zeros], dim=0)
        self.t = torch.cat([self.t[rest_sel], new_t], dim=0)

    def get_state(self) -> AdamState:
        """Return the current state tensors wrapped in an AdamState (no copy)."""
        return AdamState(self.m, self.v, self.t)

    def subgroups_view(self, slices: dict[str, slice]) -> dict[str, "AdamInputSmoothing"]:
        """
        Create lightweight subgroups that share memory with the main tensor states.
        Args:
            slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.:
                {"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)}
        Returns:
            dict[str, AdamInputSmoothing]: Submodules whose m and v are views into self.m and self.v.
        """
        if not hasattr(self, "m") or self.m.ndim == 0:
            raise RuntimeError("Cannot create subgroups before the first forward() call.")
        subgroups = {}
        for name, slc in slices.items():
            sub = AdamInputSmoothing(
                beta1=self.beta1,
                beta2=self.beta2,
                eps=self.eps,
                input_slice=None
            )
            # Views into the parent's moments: memory is shared, not copied
            sub.m = self.m[..., slc]
            sub.v = self.v[..., slc]
            # The very same step-counter tensor is handed to every subgroup.
            # NOTE(review): forward() advances t in place, so each subgroup's
            # forward() also advances it for the parent and the other
            # subgroups - confirm this is intended.
            sub.t = self.t
            subgroups[name] = sub
        return subgroups

    def aggregate_from_subgroups(self, subgroups: dict[str, "AdamInputSmoothing"], slices: dict[str, slice]) -> None:
        """
        Aggregate states from subgroups back into the main module.
        Args:
            subgroups (dict[str, AdamInputSmoothing]): Submodules created via subgroups_view.
            slices (dict[str, slice]): Mapping from subgroup name to slice, e.g.:
                {"means": slice(0, 3) ,"scale": slice(3, 6), "rotation": slice(6, 10), "opacity": slice(10, 11), "sh": slice(11, 59)}
        """
        if not hasattr(self, "m") or self.is_reset():
            raise RuntimeError("Cannot aggregate states before the first forward() call.")
        first_sub = next(iter(subgroups.values()))
        first_m_val = first_sub.m
        # If the subgroups were resized, re-allocate the moments with their
        # leading shape before copying the slices in.
        if self.m.shape[:-1] != first_m_val.shape[:-1]:
            self.m = torch.zeros((*first_m_val.shape[:-1], self.m.shape[-1]), dtype=first_m_val.dtype,
                                 device=first_m_val.device)
            self.v = torch.zeros((*first_m_val.shape[:-1], self.v.shape[-1]), dtype=first_m_val.dtype,
                                 device=first_m_val.device)
        for name, slc in slices.items():
            sub = subgroups[name]
            self.m[..., slc] = sub.m
            self.v[..., slc] = sub.v
        # Assume time step is the same across all subgroups
        self.t = first_sub.t

    def is_reset(self) -> bool:
        """Return True while the state is still the 0-dim placeholder set by reset()."""
        assert self.m.shape == self.v.shape, "First and second moment vectors must have the same shape."
        return bool(self.m.ndim == 0 and self.v.ndim == 0 and self.t == 0)
|