"""
SIREN-based implicit decoder.

For any continuous coordinate x = (u, v) ∈ [0, 1]²:

    RGB(x) = G_θ(γ(x), z_x)

where γ is a Fourier positional encoding of x and z_x is bilinearly
interpolated from the encoder feature grid Z at x.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SineLayer(nn.Module):
    """Affine map followed by a scaled sine: ``sin(omega_0 * (W x + b))``.

    Weights follow the SIREN initialisation scheme:

    * first layer:  W ~ U(-1/fan_in, 1/fan_in)
    * other layers: W ~ U(-sqrt(6/fan_in)/omega_0, sqrt(6/fan_in)/omega_0)

    Biases keep the default ``nn.Linear`` initialisation.
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        # The first layer sees raw inputs, so it gets a tighter 1/fan_in range;
        # deeper layers divide by omega_0 to undo the sine's frequency scaling.
        if is_first:
            limit = 1.0 / fan_in
        else:
            limit = math.sqrt(6.0 / fan_in) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        pre_activation = self.linear(x)
        return torch.sin(pre_activation * self.omega_0)
class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) using ``n_bands`` octave-spaced frequencies.

    Every input coordinate c is expanded into sin(2^k · π · c) and
    cos(2^k · π · c) for k = 0 … n_bands-1. Within each input dimension all sin
    terms come first, followed by all cos terms, and the per-dimension blocks
    are concatenated in input order.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.n_bands = n_bands
        self.input_dim = input_dim
        # one sin and one cos per band, for every input dimension
        self.out_dim = 2 * n_bands * input_dim

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Map ``coords`` of shape (..., input_dim) to (..., out_dim)."""
        exponents = torch.arange(self.n_bands, device=coords.device, dtype=coords.dtype)
        frequencies = 2.0 ** exponents  # (n_bands,)
        # (..., D, 1) broadcast against (n_bands,) -> (..., D, n_bands)
        phases = coords[..., None] * frequencies * math.pi
        features = torch.cat((phases.sin(), phases.cos()), dim=-1)  # (..., D, 2*n_bands)
        # Merge the last two axes; flatten (unlike reshape(..., -1)) also copes
        # with the zero-element case n_bands == 0.
        return features.flatten(start_dim=-2)
class SIRENDecoder(nn.Module):
    """
    Implicit decoder:
        input  = Fourier(coord) ⊕ z_x
        output = RGB ∈ [-1, 1]

    The network has ``n_layers`` layers in total: a first ``SineLayer`` on the
    concatenated input, ``n_layers - 2`` hidden ``SineLayer``s, and a final
    plain ``nn.Linear`` producing ``out_channels`` values that are squashed
    with ``tanh``.

    Args:
        feat_dim: channel count C of the encoder feature grid.
        hidden_dim: width of every sine layer.
        n_layers: total layer count including the first and the final layer;
            must be at least 2.
        omega_0: SIREN frequency scale given to every sine layer; also used in
            the feature-column and final-layer weight bounds.
        fourier_bands: number of frequency bands L of the coordinate encoding.
        out_channels: number of output channels (3 for RGB).

    Raises:
        ValueError: if ``n_layers < 2``.
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        if n_layers < 2:
            # The first sine layer and the final linear layer always exist, so
            # smaller values used to be silently ignored.
            raise ValueError(f"n_layers must be >= 2, got {n_layers}")

        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim  # 2 * 2 * L = 40 for the default L = 10
        in_dim = coord_dim + feat_dim

        first = SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True)
        # SineLayer's first-layer init uses fan_in = in_dim (808 by default),
        # which makes the coordinate contribution negligible (±1/808 ≈ ±0.00124).
        # Re-initialise so the coordinate columns use ±1/coord_dim (±0.025 by
        # default) and the feature columns use the standard SIREN hidden-layer
        # bound sqrt(6 / feat_dim) / omega_0. Empty column groups are skipped
        # to avoid dividing by zero (e.g. fourier_bands=0).
        with torch.no_grad():
            w = first.linear.weight  # (hidden_dim, in_dim)
            if coord_dim > 0:
                w[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            if feat_dim > 0:
                feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
                w[:, coord_dim:].uniform_(-feat_bound, feat_bound)

        layers = [first]
        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))

        # Final linear layer (no sine) maps hidden features to output channels.
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)
        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords: (B, N, 2) continuous (x, y) coordinates in [0, 1]. A batch
                size of 1 is broadcast to the batch size of ``features`` (for
                example, the output of ``make_coord_grid``).
            features: (B, C, H_z, W_z) spatial feature grid from the encoder.
        Returns:
            (B, N, out_channels) prediction in [-1, 1].
        Raises:
            ValueError: if ``coords`` is not a 3-D (B, N, 2) tensor.
        """
        if coords.dim() != 3 or coords.shape[-1] != 2:
            raise ValueError(
                f"coords must have shape (B, N, 2), got {tuple(coords.shape)}"
            )
        batch = features.shape[0]
        if coords.shape[0] == 1 and batch != 1:
            # Share one coordinate set across the whole feature batch (view, no copy).
            coords = coords.expand(batch, -1, -1)

        # 1) Fourier-encode the coordinates.
        enc = self.fourier(coords)  # (B, N, coord_dim)

        # 2) Bilinearly sample z_x from the feature grid. grid_sample expects
        #    (x, y) in [-1, 1]; with align_corners=True, -1 and 1 land on the
        #    centres of the border cells, matching make_coord_grid's endpoints.
        grid = (coords * 2.0 - 1.0).unsqueeze(1)  # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)  # (B, N, C)

        # 3) Concatenate and run the SIREN stack.
        h = torch.cat([enc, z_x], dim=-1)  # (B, N, coord_dim + C)
        for layer in self.net[:-1]:
            h = layer(h)
        # tanh squashes (rather than clamps) the raw output into [-1, 1].
        return torch.tanh(self.net[-1](h))
def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Build a flat grid of (x, y) coordinates spanning [0, 1]².

    Points are ordered row-major (y is the outer index, x the inner one), and
    both axes include their endpoints 0 and 1. Returns a tensor of shape
    (1, H*W, 2).
    """
    x_axis = torch.linspace(0, 1, W, device=device)
    y_axis = torch.linspace(0, 1, H, device=device)
    # Broadcast both axes to (H, W): x varies along columns, y along rows.
    x_plane = x_axis.unsqueeze(0).expand(H, W)
    y_plane = y_axis.unsqueeze(1).expand(H, W)
    grid = torch.stack((x_plane, y_plane), dim=-1)  # (H, W, 2), x first
    return grid.view(1, H * W, 2)