"""Building blocks for the flow matching U-Net.

Contains: SinusoidalEmbedding, ResBlock (with AdaGN), SelfAttention.
These are the primitives that the U-Net and VQ-VAE are composed from.

Reference: docs/build_spec.md §2.3 (U-Net architecture, ResBlock with AdaGN,
conditioning embedding) and docs/foundations_guide.md Part 7 (how AdaGN works).
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalEmbedding(nn.Module):
    """Map a scalar t ∈ [0, 1] to a fixed-frequency sinusoidal embedding.

    Uses the same encoding as Vaswani et al. (2017) "Attention Is All You Need",
    but applied to a continuous scalar instead of discrete positions.

    Math: with half = dim // 2, for index i in [0, half):
        embed[i]        = sin(t * 10000^(-i/half))
        embed[half + i] = cos(t * 10000^(-i/half))

    (The sin terms fill the first half of the vector and the cos terms the
    second half; the halves are concatenated, not interleaved.)

    This gives the flow time a rich, high-dimensional representation where
    nearby t values have similar embeddings (smooth) but the model can still
    distinguish fine differences (high-frequency components).

    Args:
        dim: Output embedding dimension. Must be even.
    """

    def __init__(self, dim: int) -> None:
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"SinusoidalEmbedding dim must be even, got {dim}")
        self.dim = dim

        # Precompute the frequency denominators: 10000^(2i/dim) for i=0..dim/2-1
        # Stored as a buffer (not a parameter — no gradients needed).
        half_dim = dim // 2
        exponents = torch.arange(half_dim, dtype=torch.float32) / half_dim  # [0, 1)
        inv_freq = 1.0 / (10000.0 ** exponents)  # [half_dim]
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed scalar timesteps.

        Args:
            t: Flow time values, shape [B] or [B, 1]. Values in [0, 1].

        Returns:
            Embedding of shape [B, dim].
        """
        # Flatten to [B]
        if t.ndim == 0:
            t = t.unsqueeze(0)
        t = t.view(-1).float()  # [B]

        # Outer product: [B, 1] * [1, half_dim] → [B, half_dim]
        angles = t.unsqueeze(1) * self.inv_freq.unsqueeze(0)  # [B, half_dim]

        # Concatenate the sin and cos halves → [B, dim]
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, dim]
        return emb


class ResBlock(nn.Module):
    """Residual block with Adaptive Group Normalization (AdaGN).

    Architecture:
        GroupNorm → SiLU → Conv3x3 → GroupNorm (modulated by cond) → SiLU → Conv3x3 → + skip

    The conditioning vector (time + action embedding) modulates the second
    normalization layer via learned scale and shift. This is how the U-Net
    knows *what action was taken* and *what flow time step we're at* — it
    adjusts the internal feature processing based on these signals.

    Why AdaGN and not concatenation or cross-attention?
    - The action doesn't change what's in the image — it changes *how* the image
      should change. AdaGN modulates processing, which is the right inductive bias.
    - Cross-attention is expensive; AdaGN is a single linear layer per block.

    Reference: docs/build_spec.md §2.3 (ResBlock with AdaGN code),
               docs/foundations_guide.md Part 7.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        cond_dim: Conditioning embedding dimension (time + action).
        num_groups: Number of groups for GroupNorm.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        cond_dim: int = 512,
        num_groups: int = 32,
    ) -> None:
        super().__init__()

        # First conv path: norm → activation → conv
        self.norm1 = nn.GroupNorm(num_groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        # Second conv path: norm (modulated by AdaGN) → activation → conv
        self.norm2 = nn.GroupNorm(num_groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)

        # AdaGN: conditioning → scale and shift for norm2
        # Output is 2 * out_ch: first half is scale, second half is shift.
        self.adagn = nn.Linear(cond_dim, out_ch * 2)

        # Skip / residual connection: 1x1 conv if channel count changes, else identity
        self.skip = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Feature map, shape [B, in_ch, H, W].
            cond: Conditioning embedding, shape [B, cond_dim].

        Returns:
            Output feature map, shape [B, out_ch, H, W].
        """
        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        # Adaptive GroupNorm: modulate normalized features with conditioning.
        # adagn(cond) is [B, 2*out_ch]; unsqueeze to [B, 2*out_ch, 1, 1] so the
        # chunked scale/shift (each [B, out_ch, 1, 1]) broadcast over H and W.
        scale, shift = self.adagn(cond).unsqueeze(-1).unsqueeze(-1).chunk(2, dim=1)
        h = self.norm2(h) * (1 + scale) + shift

        h = F.silu(h)
        h = self.conv2(h)

        return h + self.skip(x)


class SelfAttention(nn.Module):
    """Standard QKV self-attention applied at 2D spatial positions.

    Used only at 16×16 resolution in the U-Net: compute is O(n²) in the number
    of spatial positions n = H·W, so at 128×128 it would be prohibitive
    (16384² ≈ 268M attention entries, vs 256² ≈ 65K at 16×16).

    Architecture: GroupNorm → QKV projection → scaled dot-product attention → output projection

    Args:
        channels: Number of input/output channels.
        num_heads: Number of attention heads. channels must be divisible by num_heads.
        num_groups: Number of groups for GroupNorm.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int = 4,
        num_groups: int = 32,
    ) -> None:
        super().__init__()
        if channels % num_heads != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = channels // num_heads

        self.norm = nn.GroupNorm(num_groups, channels)
        # Single projection for Q, K, V concatenated
        self.qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.out_proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over spatial positions.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Attended feature map, shape [B, C, H, W] (residual added).
        """
        residual = x
        B, C, H, W = x.shape

        x = self.norm(x)

        # Project to Q, K, V: [B, 3*C, H, W]
        qkv = self.qkv(x)
        # Reshape to [B, 3, num_heads, head_dim, H*W] then permute
        qkv = qkv.reshape(B, 3, self.num_heads, self.head_dim, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]  # each [B, num_heads, head_dim, H*W]

        # Transpose to [B, num_heads, H*W, head_dim] for attention
        q = q.permute(0, 1, 3, 2)  # [B, heads, H*W, head_dim]
        k = k.permute(0, 1, 3, 2)  # [B, heads, H*W, head_dim]
        v = v.permute(0, 1, 3, 2)  # [B, heads, H*W, head_dim]

        # Scaled dot-product attention via F.scaled_dot_product_attention
        # (PyTorch >= 2.0), which dispatches to fused kernels when available
        attn_out = F.scaled_dot_product_attention(q, k, v)  # [B, heads, H*W, head_dim]

        # Reshape back to [B, C, H, W]
        attn_out = attn_out.permute(0, 1, 3, 2)  # [B, heads, head_dim, H*W]
        attn_out = attn_out.reshape(B, C, H, W)

        # Output projection + residual
        return self.out_proj(attn_out) + residual


class Downsample(nn.Module):
    """Spatial downsampling via strided convolution.

    Reduces spatial dimensions by 2× using a 4×4 conv with stride 2.
    Learned downsampling (not just pooling) — the model learns what
    information to preserve vs discard at each resolution.

    Args:
        channels: Number of input/output channels (preserved).
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample by 2×.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Downsampled feature map, shape [B, C, H//2, W//2].
        """
        return self.conv(x)


class Upsample(nn.Module):
    """Spatial upsampling via nearest-neighbor interpolation + convolution.

    Increases spatial dimensions by 2×. Nearest-neighbor avoids checkerboard
    artifacts that transposed convolutions can produce. The subsequent conv
    learns to smooth and refine the upsampled features.

    Args:
        channels: Number of input/output channels (preserved).
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample by 2×.

        Args:
            x: Feature map, shape [B, C, H, W].

        Returns:
            Upsampled feature map, shape [B, C, H*2, W*2].
        """
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)