Spaces:
Sleeping
Sleeping
| """ | |
| SIREN-based implicit decoder. | |
| For any continuous coordinate x = (u, v) ∈ [0,1]²: | |
| RGB(x) = G_θ( γ(x), z_x ) | |
| where γ is Fourier positional encoding and z_x is bilinearly | |
| interpolated from the encoder feature grid Z. | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class SineLayer(nn.Module):
    """One SIREN building block computing ``sin(omega_0 * (W x + b))``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        omega_0: Frequency factor multiplied into the pre-activation.
        is_first: When true, weights are drawn from ``±1 / in_features``;
            otherwise from ``±sqrt(6 / in_features) / omega_0``.
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        """Overwrite the linear weights in place with the SIREN uniform range.

        Only the weight matrix is touched; the bias keeps whatever
        ``nn.Linear`` assigned at construction.
        """
        limit = 1.0 / fan_in if is_first else math.sqrt(6.0 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        """Apply the affine map, scale by ``omega_0`` and take the sine."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)
class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) over ``n_bands`` octave frequencies.

    Every input component ``c`` is expanded to
    ``[sin(2^0 π c), …, sin(2^(L-1) π c), cos(2^0 π c), …, cos(2^(L-1) π c)]``
    with ``L = n_bands``; the per-component expansions are laid out one
    after another along the last axis.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.n_bands = n_bands
        self.input_dim = input_dim
        # One sin and one cos value per band for each input component.
        self.out_dim = input_dim * 2 * n_bands

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode ``coords`` of shape (..., input_dim) into (..., out_dim)."""
        exponents = torch.arange(
            self.n_bands, device=coords.device, dtype=coords.dtype
        )
        octaves = 2.0 ** exponents  # (n_bands,)
        # (..., D, 1) broadcast against (n_bands,) -> (..., D, n_bands)
        phase = coords[..., None] * octaves * math.pi
        sin_part = torch.sin(phase)
        cos_part = torch.cos(phase)
        # (..., D, 2 * n_bands): all sines first, then all cosines.
        stacked = torch.cat((sin_part, cos_part), dim=-1)
        # Merge the last two axes -> (..., D * 2 * n_bands)
        return stacked.flatten(start_dim=-2)
class SIRENDecoder(nn.Module):
    """Implicit decoder mapping a coordinate plus local feature to RGB.

        input  = Fourier(coord) ⊕ z_x
        output = RGB ∈ [-1, 1]   (via a final tanh)

    The network is a stack of ``SineLayer`` modules followed by a plain
    linear layer: one first sine layer, ``n_layers - 2`` hidden sine layers
    and one output projection. For ``n_layers < 2`` the hidden range is empty,
    so the stack still contains the first layer and the output projection.

    Args:
        feat_dim: Channel count C of the encoder feature grid.
        hidden_dim: Width of every sine layer.
        n_layers: Nominal depth (see above).
        omega_0: Frequency factor passed to every sine layer.
        fourier_bands: Number of frequency bands for the coordinate encoding.
        out_channels: Number of output channels (3 for RGB).
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim  # 2 dims * (sin + cos) * fourier_bands
        in_dim = coord_dim + feat_dim

        first = SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True)
        # The is_first init uses fan_in = coord_dim + feat_dim, which makes the
        # coordinate contribution negligible when feat_dim dominates (±1/808
        # with the defaults). Re-draw the weights so that the coordinate
        # columns use ±1/coord_dim and the feature columns use the standard
        # SIREN hidden-layer bound.
        with torch.no_grad():
            weight = first.linear.weight  # (hidden_dim, in_dim)
            weight[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
            weight[:, coord_dim:].uniform_(-feat_bound, feat_bound)

        layers = [first]
        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))

        # Final projection has no sine; tanh is applied in forward().
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)

        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """Predict one colour per query coordinate.

        Args:
            coords: (B, N, 2) continuous (x, y) coordinates in [0, 1]. A
                leading dimension of 1 (a coordinate grid shared by the whole
                batch) is broadcast to the batch size of ``features``.
            features: (B, C, H_z, W_z) spatial feature grid from the encoder.

        Returns:
            (B, N, out_channels) tensor with values in [-1, 1].

        Raises:
            ValueError: If ``coords`` is not 3-dimensional.
        """
        if coords.dim() != 3:
            raise ValueError(
                f"coords must have shape (B, N, 2), got {tuple(coords.shape)}"
            )

        # A single shared grid would otherwise fail in grid_sample and
        # torch.cat, which both need matching batch sizes.
        feat_batch = features.shape[0]
        if coords.shape[0] == 1 and feat_batch != 1:
            coords = coords.expand(feat_batch, -1, -1)

        # 1) Fourier-encode the coordinates.
        enc = self.fourier(coords)  # (B, N, coord_dim)

        # 2) Bilinearly sample the feature grid; grid_sample expects
        #    coordinates in [-1, 1] and a 4-D grid of shape (B, H_out, W_out, 2).
        grid = (coords * 2.0 - 1.0).unsqueeze(1)  # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)  # (B, N, C)

        # 3) Concatenate and run the SIREN stack (sine layers + final linear).
        h = torch.cat([enc, z_x], dim=-1)  # (B, N, coord_dim + C)
        for layer in self.net:
            h = layer(h)

        # Squash the linear output into [-1, 1].
        return torch.tanh(h)
def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Build a flattened coordinate grid covering [0, 1]².

    Returns a tensor of shape (1, H*W, 2). Entries are ordered row by row
    (index = row * W + col) and each entry is stored as (x, y).
    """
    row_pos = torch.linspace(0, 1, H, device=device)  # y values, one per row
    col_pos = torch.linspace(0, 1, W, device=device)  # x values, one per column
    # Broadcast both axes to the full (H, W) plane.
    x_plane = col_pos.unsqueeze(0).expand(H, W)
    y_plane = row_pos.unsqueeze(1).expand(H, W)
    plane = torch.stack((x_plane, y_plane), dim=-1)  # (H, W, 2), x then y
    return plane.reshape(1, H * W, 2)