"""SIREN-based implicit decoder.

For any continuous coordinate x = (u, v) ∈ [0, 1]²:

    RGB(x) = G_θ(γ(x), z_x)

where γ is a Fourier positional encoding and z_x is bilinearly interpolated
from the encoder feature grid Z.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SineLayer(nn.Module):
    """Single SIREN layer: sin(ω₀ · (Wx + b)).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        omega_0: frequency factor applied to the pre-activation.
        is_first: selects the first-layer weight initialisation
            (uniform in ±1/fan_in) instead of the hidden-layer one
            (uniform in ±sqrt(6/fan_in)/ω₀).
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        """Overwrite the linear weights in place; the bias keeps nn.Linear's default."""
        if is_first:
            bound = 1.0 / fan_in
        else:
            bound = math.sqrt(6.0 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the affine map, scale by ω₀ and take the sine."""
        return torch.sin(self.omega_0 * self.linear(x))


class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) with ``n_bands`` frequency bands.

    Each input dimension d is mapped to
    ``[sin(2^k π x_d) for k < n_bands] + [cos(2^k π x_d) for k < n_bands]``,
    and the per-dimension groups are concatenated in dimension order.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.n_bands = n_bands
        self.input_dim = input_dim
        # sin + cos per band per input dimension
        self.out_dim = input_dim * 2 * n_bands

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode ``coords`` of shape (..., input_dim) into (..., out_dim)."""
        freqs = 2.0 ** torch.arange(
            self.n_bands, device=coords.device, dtype=coords.dtype
        )  # (n_bands,)
        # (..., D, 1) * (n_bands,) -> (..., D, n_bands)
        scaled = coords.unsqueeze(-1) * freqs * math.pi
        # (..., D, 2 * n_bands): all sines first, then all cosines
        enc = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
        # merge the last two axes -> (..., D * 2 * n_bands)
        return enc.flatten(-2)


class SIRENDecoder(nn.Module):
    """Implicit decoder.

    input  = Fourier(coord) ⊕ z_x
    output = RGB ∈ [-1, 1]

    Args:
        feat_dim: channel count C of the encoder feature grid.
        hidden_dim: width of every sine layer.
        n_layers: total layer count; one first sine layer, ``n_layers - 2``
            hidden sine layers and one final linear layer are built.
        omega_0: frequency factor shared by all sine layers.
        fourier_bands: number of frequency bands of the coordinate encoding.
        out_channels: size of the last output dimension (3 for RGB).
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim  # 2 * 2 * fourier_bands (40 by default)
        in_dim = coord_dim + feat_dim

        first = SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True)
        # The is_first init uses fan_in = in_dim (808 with the defaults), which
        # gives bounds of about ±0.00124 and makes the coordinate contribution
        # negligible. Re-initialise so the coordinate columns use ±1/coord_dim
        # (±0.025 by default) and the feature columns use the standard SIREN
        # hidden-layer bound.
        with torch.no_grad():
            weight = first.linear.weight  # (hidden_dim, in_dim)
            weight[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
            weight[:, coord_dim:].uniform_(-feat_bound, feat_bound)

        layers = [first]
        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))

        # final linear (no sine) -> RGB
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)

        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """Decode RGB values at continuous coordinates.

        Args:
            coords: (B, N, 2) continuous (x, y) coordinates in [0, 1]. A batch
                size of 1 is also accepted and is broadcast to the batch size
                of ``features``, so one grid from ``make_coord_grid`` can be
                shared by the whole batch.
            features: (B, C, H_z, W_z) spatial feature grid from the encoder.

        Returns:
            rgb: (B, N, out_channels) prediction squashed into [-1, 1] by tanh.
        """
        # grid_sample needs matching batch sizes; expand() is a zero-copy view.
        batch = features.shape[0]
        if coords.shape[0] == 1 and batch != 1:
            coords = coords.expand(batch, -1, -1)

        # 1) Fourier-encode the coordinates
        enc = self.fourier(coords)  # (B, N, coord_dim)

        # 2) bilinear sample from the feature grid;
        #    grid_sample wants coordinates in [-1, 1]
        grid = (coords * 2.0 - 1.0).unsqueeze(1)  # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)  # (B, N, C)

        # 3) concatenate and run through the sine layers, then the linear head
        h = torch.cat([enc, z_x], dim=-1)  # (B, N, coord_dim + C)
        *sine_layers, head = self.net
        for layer in sine_layers:
            h = layer(h)
        # tanh squashes (rather than hard-clamps) the output into [-1, 1]
        return torch.tanh(head(h))


def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Create a flat coordinate grid in [0, 1]².

    Both axes are sampled with ``torch.linspace(0, 1, ·)``, so the endpoints
    0 and 1 are included. The last dimension is ordered (x, y) and the points
    are laid out row by row.

    Returns:
        Tensor of shape (1, H*W, 2).
    """
    ys = torch.linspace(0, 1, H, device=device)
    xs = torch.linspace(0, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    coords = torch.stack([grid_x, grid_y], dim=-1)  # (H, W, 2), x then y
    return coords.reshape(1, H * W, 2)