itdf-space / models /decoder.py
priyadip's picture
Upload models/decoder.py with huggingface_hub
6f17d2a verified
"""
SIREN-based implicit decoder.
For any continuous coordinate x = (u, v) ∈ [0,1]²:
RGB(x) = G_θ( γ(x), z_x )
where γ is Fourier positional encoding and z_x is bilinearly
interpolated from the encoder feature grid Z.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SineLayer(nn.Module):
    """One SIREN layer computing sin(ω₀ · (Wx + b)).

    The weight matrix is drawn uniformly; the range depends on whether the
    layer is the first one of the network (see ``_init_weights``). The bias
    keeps the default ``nn.Linear`` initialisation.
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        """Re-draw the weights uniformly in ±bound.

        First layer: bound = 1 / fan_in.
        Any other layer: bound = sqrt(6 / fan_in) / ω₀.
        """
        if is_first:
            limit = 1.0 / fan_in
        else:
            limit = math.sqrt(6.0 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        """Apply the affine map, scale by ω₀ and take the sine."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)
class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) built from ``n_bands`` frequency bands.

    Band k uses the angular frequency 2^k · π. For every input dimension the
    output holds all sines first, then all cosines; the per-dimension groups
    are laid out one after another along the last axis.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.n_bands = n_bands
        # One sine and one cosine per band, for each input dimension.
        self.out_dim = input_dim * 2 * n_bands

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode ``coords`` of shape (..., input_dim) into (..., out_dim)."""
        band_index = torch.arange(
            self.n_bands, device=coords.device, dtype=coords.dtype
        )
        freqs = 2.0 ** band_index  # (n_bands,)
        # (..., D, 1) broadcast against (n_bands,) gives (..., D, n_bands).
        angles = coords.unsqueeze(-1) * freqs * math.pi
        sin_part = torch.sin(angles)
        cos_part = torch.cos(angles)
        # (..., D, 2 * n_bands), then merge the last two axes.
        stacked = torch.cat((sin_part, cos_part), dim=-1)
        return stacked.flatten(start_dim=-2)
class SIRENDecoder(nn.Module):
    """Implicit SIREN decoder mapping coordinates plus local features to RGB.

    For every query coordinate the network input is ``Fourier(coord) ⊕ z_x``,
    where ``z_x`` is bilinearly sampled from the encoder feature grid. The
    output passes through ``tanh`` so each channel lies in [-1, 1].

    Args:
        feat_dim: Channel count C of the encoder feature grid.
        hidden_dim: Width of every sine layer.
        n_layers: Total layer count: one input sine layer, ``n_layers - 2``
            hidden sine layers and one final linear layer.
        omega_0: Frequency factor ω₀ used by every sine layer.
        fourier_bands: Number of frequency bands L of the coordinate encoding.
        out_channels: Number of output channels (3 for RGB).
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim  # 2 * 2 * L (40 with the defaults)
        in_dim = coord_dim + feat_dim

        layers = [SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True)]
        # The is_first init uses fan_in = in_dim (808 with the defaults), which
        # makes the coordinate contribution negligible (±0.00124). Re-draw the
        # first layer so the coordinate columns use 1 / coord_dim (±0.025) and
        # the feature columns use the standard SIREN hidden-layer bound.
        with torch.no_grad():
            first_weight = layers[0].linear.weight  # (hidden_dim, in_dim)
            first_weight[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
            first_weight[:, coord_dim:].uniform_(-feat_bound, feat_bound)

        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))

        # Final plain linear layer (no sine) producing the output channels.
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)

        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """Decode colours at the given continuous coordinates.

        Args:
            coords: (B, N, 2) coordinates in [0, 1], ordered (x, y). A single
                shared set of shape (1, N, 2) is also accepted and is
                broadcast over the batch of ``features``.
            features: (B, C, H_z, W_z) spatial feature grid from the encoder.

        Returns:
            (B, N, out_channels) tensor with values in [-1, 1].

        Raises:
            ValueError: If ``coords`` is not a 3-D tensor.
        """
        if coords.dim() != 3:
            raise ValueError(
                f"coords must have shape (B, N, 2), got {tuple(coords.shape)}"
            )

        # grid_sample needs matching batch sizes and torch.cat does not
        # broadcast, so expand a shared coordinate set up front.
        feat_batch = features.shape[0]
        if coords.shape[0] == 1 and feat_batch > 1:
            coords = coords.expand(feat_batch, -1, -1)

        # 1) Fourier-encode the coordinates.
        enc = self.fourier(coords)  # (B, N, coord_dim)

        # 2) Bilinearly sample the feature grid; grid_sample expects [-1, 1].
        grid = (coords * 2.0 - 1.0).unsqueeze(1)  # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)  # (B, N, C)

        # 3) Concatenate and run the SIREN stack.
        h = torch.cat([enc, z_x], dim=-1)  # (B, N, coord_dim + C)
        for layer in self.net[:-1]:
            h = layer(h)
        rgb = self.net[-1](h)  # (B, N, out_channels)
        # tanh squashes (rather than hard-clamps) the output into [-1, 1].
        return torch.tanh(rgb)
def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Build a flattened (x, y) coordinate grid spanning [0, 1]².

    Both axes are sampled with ``torch.linspace(0, 1, ·)``. Points are listed
    row by row (x varies fastest) and each point is stored as (x, y).

    Returns:
        Tensor of shape (1, H*W, 2).
    """
    row_pos = torch.linspace(0, 1, H, device=device)
    col_pos = torch.linspace(0, 1, W, device=device)
    # Broadcast both axes to (H, W): x follows the column, y follows the row.
    x_plane = col_pos.unsqueeze(0).expand(H, W)
    y_plane = row_pos.unsqueeze(1).expand(H, W)
    xy = torch.stack((x_plane, y_plane), dim=-1)  # (H, W, 2)
    return xy.reshape(1, H * W, 2)