File size: 5,232 Bytes
6f17d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
SIREN-based implicit decoder.

For any continuous coordinate x = (u, v) ∈ [0,1]²:
    RGB(x) = G_θ( γ(x), z_x )

where γ is Fourier positional encoding and z_x is bilinearly
interpolated from the encoder feature grid Z.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SineLayer(nn.Module):
    """Single SIREN layer:  sin(ω₀ · (Wx + b)).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        omega_0: frequency factor ω₀ applied before the sine.
        is_first: selects the weight-initialisation range (see
            ``_init_weights``).
    """

    def __init__(self, in_features, out_features, omega_0=30.0, is_first=False):
        super().__init__()
        self.omega_0 = omega_0
        self.linear = nn.Linear(in_features, out_features)
        self._init_weights(is_first, in_features)

    def _init_weights(self, is_first, fan_in):
        """Re-draw the linear weights uniformly in a symmetric range.

        The range is ±1/fan_in when ``is_first`` is true and
        ±sqrt(6/fan_in)/ω₀ otherwise.  The bias keeps the ``nn.Linear``
        default initialisation.
        """
        limit = (
            1.0 / fan_in
            if is_first
            else math.sqrt(6.0 / fan_in) / self.omega_0
        )
        with torch.no_grad():
            self.linear.weight.uniform_(-limit, limit)

    def forward(self, x):
        """Apply the affine map, scale by ω₀ and take the sine."""
        pre_activation = self.linear(x)
        return torch.sin(self.omega_0 * pre_activation)


class FourierEncoding(nn.Module):
    """Fourier positional encoding γ(x) with L = ``n_bands`` frequency bands.

    Each input component c is mapped to sin(2^k · π · c) and
    cos(2^k · π · c) for k = 0 … L-1.

    Attributes set in the constructor:
        n_bands:   number of frequency bands L.
        input_dim: number of coordinate components D (stored; not checked
                   against the input in ``forward``).
        out_dim:   D * 2 * L, the size of the encoded last dimension.
    """

    def __init__(self, n_bands: int = 10, input_dim: int = 2):
        super().__init__()
        self.n_bands = n_bands
        self.input_dim = input_dim
        # One sine and one cosine per band, per coordinate component.
        self.out_dim = input_dim * 2 * n_bands

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        """Encode ``coords`` of shape (..., D) into shape (..., D * 2 * L).

        For every coordinate component the output holds all L sines
        followed by all L cosines; components are laid out one after
        another along the last dimension.
        """
        band_index = torch.arange(
            self.n_bands, device=coords.device, dtype=coords.dtype
        )
        freqs = 2.0 ** band_index                          # (L,)
        # (..., D, 1) broadcast against (L,) → (..., D, L)
        phase = coords[..., None] * freqs * math.pi
        sin_part = torch.sin(phase)
        cos_part = torch.cos(phase)
        stacked = torch.cat((sin_part, cos_part), dim=-1)  # (..., D, 2L)
        return stacked.flatten(-2)                         # (..., D*2L)


class SIRENDecoder(nn.Module):
    """
    Implicit decoder:
        input  = Fourier(coord) ⊕ z_x
        output = RGB ∈ [-1, 1]

    ``z_x`` is the feature vector bilinearly sampled from the encoder
    feature grid at ``coord``.

    Args:
        feat_dim:      channel count C of the feature grid.
        hidden_dim:    width of every sine layer.
        n_layers:      total layer count: one input sine layer,
                       ``n_layers - 2`` hidden sine layers and one final
                       linear layer.  Values below 2 build the same
                       two-layer network as ``n_layers=2``.
        omega_0:       frequency factor ω₀ shared by all sine layers.
        fourier_bands: number of Fourier bands L for the coordinates.
        out_channels:  size of the output vector (3 for RGB).
    """

    def __init__(
        self,
        feat_dim: int = 768,
        hidden_dim: int = 256,
        n_layers: int = 5,
        omega_0: float = 30.0,
        fourier_bands: int = 10,
        out_channels: int = 3,
    ):
        super().__init__()
        self.fourier = FourierEncoding(n_bands=fourier_bands, input_dim=2)
        coord_dim = self.fourier.out_dim          # 2 dims * 2 (sin, cos) * L
        in_dim = coord_dim + feat_dim

        layers = []
        layers.append(SineLayer(in_dim, hidden_dim, omega_0=omega_0, is_first=True))
        # The is_first init above draws from ±1/in_dim.  Because in_dim
        # includes feat_dim, that range is far too narrow for the
        # coordinate columns, so the first layer is re-initialised here:
        # coordinate columns use ±1/coord_dim and feature columns use the
        # SIREN hidden-layer bound ±sqrt(6/feat_dim)/ω₀.
        with torch.no_grad():
            w = layers[0].linear.weight        # (hidden_dim, in_dim)
            w[:, :coord_dim].uniform_(-1.0 / coord_dim, 1.0 / coord_dim)
            feat_bound = math.sqrt(6.0 / feat_dim) / omega_0
            w[:, coord_dim:].uniform_(-feat_bound, feat_bound)
        for _ in range(n_layers - 2):
            layers.append(SineLayer(hidden_dim, hidden_dim, omega_0=omega_0))
        # final linear (no sine) → RGB
        final = nn.Linear(hidden_dim, out_channels)
        with torch.no_grad():
            bound = math.sqrt(6.0 / hidden_dim) / omega_0
            final.weight.uniform_(-bound, bound)
        layers.append(final)
        self.net = nn.ModuleList(layers)

    def forward(self, coords: torch.Tensor, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            coords:   (B, N, 2) continuous coordinates in [0, 1], ordered
                      (x, y).  A batch of size 1 is broadcast across the
                      batch of ``features``, so one shared coordinate set
                      can be decoded for every item of a feature batch.
            features: (B, C, H_z, W_z) spatial feature grid from encoder
        Returns:
            rgb: (B, N, out_channels) prediction in [-1, 1], where B is
                 the batch size of ``features`` when coords were broadcast.
        Raises:
            ValueError: if ``coords`` is not a 3-D tensor.
        """
        if coords.dim() != 3:
            raise ValueError(
                f"coords must have shape (B, N, 2), got {tuple(coords.shape)}"
            )

        # grid_sample and the concatenation below both need matching batch
        # sizes; expand a shared (1, N, 2) coordinate set (no copy) so it
        # lines up with the feature batch.
        n_batch = features.shape[0]
        if coords.shape[0] == 1 and n_batch > 1:
            coords = coords.expand(n_batch, -1, -1)

        # 1) Fourier encode coordinates
        enc = self.fourier(coords)                 # (B, N, coord_dim)

        # 2) bilinear sample from feature grid
        #    grid_sample wants coords in [-1, 1]
        grid = coords * 2.0 - 1.0                 # [0,1] → [-1,1]
        grid = grid.unsqueeze(1)                   # (B, 1, N, 2)
        z_x = F.grid_sample(
            features, grid, mode="bilinear", align_corners=True
        )  # (B, C, 1, N)
        z_x = z_x.squeeze(2).permute(0, 2, 1)     # (B, N, C)

        # 3) concatenate and pass through SIREN
        h = torch.cat([enc, z_x], dim=-1)          # (B, N, coord_dim+C)
        for layer in self.net[:-1]:
            h = layer(h)
        rgb = self.net[-1](h)                      # (B, N, out_channels)
        return torch.tanh(rgb)                     # squash to [-1, 1]


def make_coord_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
    """Build a flat coordinate grid covering [0, 1]².

    Rows and columns are evenly spaced with both end points included.
    Points are listed row by row and each point is stored as (x, y).

    Returns:
        Tensor of shape (1, H*W, 2) on ``device``.
    """
    row_pos = torch.linspace(0, 1, H, device=device)   # y of each row
    col_pos = torch.linspace(0, 1, W, device=device)   # x of each column
    # Broadcast the two 1-D axes to full (H, W) planes.
    y_plane = row_pos.view(H, 1).expand(H, W)
    x_plane = col_pos.view(1, W).expand(H, W)
    xy = torch.stack((x_plane, y_plane), dim=-1)       # (H, W, 2), x first
    return xy.reshape(1, H * W, 2)