priyadip committed on
Commit
6f17d2a
·
verified ·
1 Parent(s): 0046a78

Upload models/decoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/decoder.py +137 -0
models/decoder.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
SIREN-based implicit decoder.

For any continuous coordinate x = (u, v) ∈ [0, 1]²:
    RGB(x) = G_θ(γ(x), z_x)

where γ is the Fourier positional encoding and z_x is bilinearly
interpolated from the encoder feature grid Z.
"""
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+
17
class SineLayer(nn.Module):
    """Single SIREN layer: sin(ω₀ · (Wx + b)).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        omega_0: frequency factor applied to the linear output before the sine.
        is_first: selects the first-layer weight initialisation (see
            ``_init_weights``).
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        """Re-draw the linear weights in place; the bias keeps nn.Linear's default.

        First layer:  U(-1 / fan_in, 1 / fan_in)
        Other layers: U(-sqrt(6 / fan_in) / ω₀, sqrt(6 / fan_in) / ω₀)
        """
        with torch.no_grad():
            if is_first:
                self.linear.weight.uniform_(-1.0 / fan_in, 1.0 / fan_in)
            else:
                bound = math.sqrt(6.0 / fan_in) / self.omega_0
                self.linear.weight.uniform_(-bound, bound)

    def forward(self, x):
        """Apply the affine map, scale by ω₀, and take the sine.

        x: (..., in_features) → (..., out_features)
        """
        return torch.sin(self.omega_0 * self.linear(x))
36
+
37
+
38
class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) with L = ``n_bands`` frequency bands.

    Each input dimension d is mapped to
        [sin(π·2⁰·x_d), …, sin(π·2^(L-1)·x_d), cos(π·2⁰·x_d), …, cos(π·2^(L-1)·x_d)]
    and the per-dimension blocks are concatenated in dimension order.
    The module has no learnable parameters.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.n_bands = n_bands
        self.input_dim = input_dim
        # output dim = input_dim * 2 * n_bands (sin + cos per band per dim)
        self.out_dim = input_dim * 2 * n_bands

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """coords: (..., input_dim) → (..., out_dim)"""
        # Frequencies 2^0 … 2^(n_bands-1), created on the input's device/dtype.
        freqs = 2.0 ** torch.arange(
            self.n_bands, device=coords.device, dtype=coords.dtype
        )  # (n_bands,)
        # coords: (..., D) → (..., D, 1) × (n_bands,) → (..., D, n_bands)
        scaled = coords.unsqueeze(-1) * freqs * math.pi
        # All sines first, then all cosines, for each input dimension.
        enc = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
        # (..., D, 2*n_bands) → (..., D*2*n_bands)
        return enc.flatten(-2)
57
+
58
+
59
class SIRENDecoder(nn.Module):
    """
    Implicit decoder:
        input  = Fourier(coord) ⊕ z_x
        output = RGB ∈ [-1, 1]

    The network is one first ``SineLayer``, ``n_layers - 2`` hidden
    ``SineLayer``s, and a final plain ``nn.Linear`` followed by ``tanh``.
    For ``n_layers`` below 2 the hidden stack is simply empty.

    Args:
        feat_dim: channel count C of the encoder feature grid.
        hidden_dim: width of every sine layer.
        n_layers: total layer count, including the first and final layers.
        omega_0: SIREN frequency factor shared by all sine layers.
        fourier_bands: number of frequency bands in the coordinate encoding.
        out_channels: size of the output vector per coordinate (3 for RGB).
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim  # 2 * 2 * L = 40 with the default L = 10
        in_dim = coord_dim + feat_dim

        layers = []
        layers.append(SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True))
        # Fix: is_first init uses fan_in=in_dim=808 (with defaults), which makes
        # the coord contribution negligible (±0.00124). Re-init so coord columns
        # use 1/coord_dim (±0.025) and feature columns use standard SIREN
        # hidden-layer bounds.
        with torch.no_grad():
            w = layers[0].linear.weight  # (hidden_dim, in_dim)
            w[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
            w[:, coord_dim:].uniform_(-feat_bound, feat_bound)

        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))

        # final linear (no sine) → RGB, initialised with the hidden-layer bound
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)

        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords:   (B, N, 2) continuous coordinates in [0, 1], ordered (x, y).
                      A batch of 1 is broadcast across the feature batch.
            features: (B, C, H_z, W_z) spatial feature grid from encoder
        Returns:
            rgb:      (B, N, out_channels) predictions in [-1, 1]
        Raises:
            ValueError: if ``coords`` is not a 3-D tensor.
        """
        if coords.dim() != 3:
            raise ValueError(
                f"coords must have shape (B, N, 2), got {tuple(coords.shape)}"
            )

        # grid_sample needs matching batch sizes; let a single shared coordinate
        # set (batch of 1) serve every sample in the feature batch.
        batch = features.shape[0]
        if coords.shape[0] == 1 and batch != 1:
            coords = coords.expand(batch, -1, -1)

        # 1) Fourier encode coordinates
        enc = self.fourier(coords)  # (B, N, coord_dim)

        # 2) bilinear sample from feature grid
        #    grid_sample wants coords in [-1, 1]
        grid = coords * 2.0 - 1.0  # [0,1] → [-1,1]
        grid = grid.unsqueeze(1)  # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)  # (B, N, C)

        # 3) concatenate and pass through SIREN
        h = torch.cat([enc, z_x], dim=-1)  # (B, N, coord_dim+C)
        for layer in self.net[:-1]:
            h = layer(h)
        rgb = self.net[-1](h)  # (B, N, out_channels)
        return torch.tanh(rgb)  # squash to [-1, 1]
129
+
130
+
131
def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Create a flat coordinate grid in [0, 1]². Returns (1, H*W, 2).

    Points are listed row by row (y outer, x inner) and each point is stored
    as (x, y). Both axes include the endpoints 0 and 1.

    Args:
        H: number of rows (samples along y).
        W: number of columns (samples along x).
        device: device on which the grid is created.
    """
    ys = torch.linspace(0, 1, H, device=device)
    xs = torch.linspace(0, 1, W, device=device)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")  # each (H, W)
    coords = torch.stack([grid_x, grid_y], dim=-1)  # (H, W, 2) x then y
    return coords.reshape(1, H * W, 2)