AbstractPhil committed on
Commit
b4a22b4
·
verified ·
1 Parent(s): 094c5fd

Create constellation_diffusion.py

Browse files
Files changed (1) hide show
  1. constellation_diffusion.py +561 -0
constellation_diffusion.py ADDED
@@ -0,0 +1,561 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flow Matching Diffusion with Constellation Relay Regulator
4
+ =============================================================
5
+ ODE-based flow matching (not DDPM) on CIFAR-10.
6
+ Constellation relay inserted at LayerNorm boundaries as
7
+ geometric regulator.
8
+
9
+ Flow matching:
10
+ Forward: x_t = (1-t) * x_0 + t * Ξ΅
11
+ Target: v = Ξ΅ - x_0
12
+ Loss: ||v_pred(x_t, t) - v||Β²
13
+ Sample: Euler ODE from t=1 β†’ t=0
14
+
15
+ Architecture:
16
+ Small UNet with ConvNeXt blocks
17
+ Middle: self-attention + constellation relay after each norm
18
+ Time + class conditioning via adaptive normalization
19
+
20
+ The relay operates at the normalized manifold between blocks,
21
+ snapping geometry back to the constellation reference frame
22
+ after each attention + conv perturbation.
23
+ """
24
+
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ import numpy as np
29
+ import math
30
+ import os
31
+ import time
32
+ from tqdm import tqdm
33
+ from torchvision import datasets, transforms
34
+ from torchvision.utils import save_image, make_grid
35
+
36
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
+ torch.backends.cuda.matmul.allow_tf32 = True
38
+ torch.backends.cudnn.allow_tf32 = True
39
+
40
+
41
+ # ══════════════════════════════════════════════════════════════════
42
+ # CONSTELLATION RELAY (adapted for feature maps)
43
+ # ══════════════════════════════════════════════════════════════════
44
+
45
class ConstellationRelay(nn.Module):
    """
    Geometric regulator for feature maps.

    Keeps a fixed reference "constellation" of unit anchor vectors
    (``home``, a non-trainable buffer) alongside a learnable copy
    (``anchors``).  Features are split along the channel axis into
    patches; each patch is triangulated (cosine distance) against slerp
    interpolations of the two anchor sets, and a per-patch MLP predicts a
    correction that is blended back in through a sigmoid-gated residual.

    Input:  (B, C, H, W) feature map
    Mode:   'channel' — pool spatial dims, relay on (B, C), rescale map
            'pixel'   — relay on (B*H*W, C) — expensive but thorough
    """
    def __init__(self, channels, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0, mode='channel'):
        super().__init__()
        # Channels split into P = channels // patch_dim patches of size d.
        assert channels % patch_dim == 0
        self.channels = channels
        self.patch_dim = patch_dim
        self.n_patches = channels // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        self.mode = mode

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Fixed reference frame: unit anchors registered as a buffer (never
        # trained); 'anchors' starts as a trainable clone of it.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Per-patch two-layer MLP: triangulation features (n_phases * A per
        # patch) -> pw_hidden -> d-dim correction, batched over P via einsum.
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        # gate_init = -3.0 => sigmoid ~= 0.047, so the MLP path starts
        # nearly switched off and the blend is dominated by the raw patches.
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(channels)

    def drift(self):
        # Geodesic angle (radians) between each learned anchor and its fixed
        # home position; the clamp keeps acos away from the +/-1 poles.
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        # Spherical linear interpolation: t=0 gives home, t=1 the learned
        # anchors (standard slerp formula with omega = drift angle).
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard the sin(omega) division at omega ~ 0
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def _relay_core(self, x_flat):
        """x_flat: (N, C) → (N, C); adds the gated correction residually."""
        N, C = x_flat.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim

        x_n = self.norm(x_flat)
        patches = x_n.reshape(N, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # Cosine-distance triangulation of each patch against the anchors at
        # n_phases evenly spaced slerp phases between home and learned.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('npd,pad->npa', patches_n, at))
        tri = torch.cat(tris, dim=-1)

        # Per-patch MLP over the triangulation features.
        h = F.gelu(torch.einsum('npt,pth->nph', tri, self.pw_w1) + self.pw_b1)
        pw = self.pw_norm(torch.einsum('nph,phd->npd', h, self.pw_w2) + self.pw_b2)

        # Gated blend of MLP output with the normalized patches; gates start
        # near zero, so initially the residual is ~ the normalized input.
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        blended = g * pw + (1-g) * patches
        return x_flat + blended.reshape(N, C)

    def forward(self, x):
        """x: (B, C, H, W) → same shape."""
        B, C, H, W = x.shape
        if self.mode == 'channel':
            # Global average pool → relay → broadcast back
            pooled = x.mean(dim=(-2, -1))  # (B, C)
            relayed = self._relay_core(pooled)  # (B, C)
            # Scale feature map by relay correction
            # NOTE(review): relayed / pooled blows up when a channel mean is
            # near zero (the +1e-8 does not help for negative means); the
            # clamp to [-3, 3] bounds the per-channel scale but can still
            # flip its sign — confirm this is intended.
            scale = (relayed / (pooled + 1e-8)).unsqueeze(-1).unsqueeze(-1)
            return x * scale.clamp(-3, 3)  # prevent extreme scaling
        else:
            # Per-pixel relay — (B*H*W, C)
            x_flat = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
            out = self._relay_core(x_flat)
            return out.reshape(B, H, W, C).permute(0, 3, 1, 2)
133
+
134
+
135
+ # ══════════════════════════════════════════════════════════════════
136
+ # BUILDING BLOCKS
137
+ # ══════════════════════════════════════════════════════════════════
138
+
139
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of scalar timesteps.

    Maps ``t`` of shape ``(B,)`` to ``(B, dim)`` by concatenating the sine
    and cosine of ``t`` scaled by ``dim // 2`` geometrically spaced
    frequencies (from 1 down to 1/10000).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        n_freqs = self.dim // 2
        # Log-spaced decay so frequencies span [1, 1/10000] geometrically.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=t.device, dtype=t.dtype))
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
150
+
151
+
152
class AdaGroupNorm(nn.Module):
    """GroupNorm whose per-channel scale and shift come from a conditioning vector.

    The projection is zero-initialized, so at initialization the layer
    behaves exactly like a plain affine-free GroupNorm.
    """

    def __init__(self, channels, cond_dim, n_groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(n_groups, channels), channels, affine=False)
        self.proj = nn.Linear(cond_dim, channels * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        normed = self.gn(x)
        # (B, 2C) -> (B, 2C, 1, 1), then split into scale and shift halves.
        params = self.proj(cond)[..., None, None]
        scale, shift = params.chunk(2, dim=1)
        return normed * (1 + scale) + shift
165
+
166
+
167
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block with adaptive group norm.

    depthwise 7x7 conv -> AdaGroupNorm(cond) -> 1x1 expand (4x) -> GELU
    -> 1x1 contract, added back to the input.  Optionally followed by a
    channel-mode ConstellationRelay regulator on the residual output.
    """

    def __init__(self, channels, cond_dim, use_relay=False):
        super().__init__()
        self.dw_conv = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.norm = AdaGroupNorm(channels, cond_dim)
        self.pw1 = nn.Conv2d(channels, channels * 4, 1)
        self.pw2 = nn.Conv2d(channels * 4, channels, 1)
        self.act = nn.GELU()

        if use_relay:
            self.relay = ConstellationRelay(
                channels,
                patch_dim=min(16, channels),
                n_anchors=min(16, channels),
                n_phases=3,
                pw_hidden=32,
                gate_init=-3.0,
                mode='channel',
            )
        else:
            self.relay = None

    def forward(self, x, cond):
        h = self.dw_conv(x)
        h = self.norm(h, cond)
        h = self.pw2(self.act(self.pw1(h)))
        out = x + h
        if self.relay is not None:
            out = self.relay(out)
        return out
194
+
195
+
196
class SelfAttnBlock(nn.Module):
    """Spatial self-attention over an (H*W)-token sequence of feature pixels.

    Pre-norm residual block: GroupNorm -> 1x1 qkv conv -> multi-head SDPA
    over spatial positions -> zero-initialized 1x1 output conv, so the
    block is an identity at initialization.

    Fix: the original passed q/k/v to scaled_dot_product_attention shaped
    (B, heads, head_dim, H*W), which attends over the *channel* axis with
    the spatial axis as the embedding.  SDPA expects (..., L, E), so q/k/v
    must be (B, heads, H*W, head_dim) for spatial attention.
    """
    def __init__(self, channels, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = channels // n_heads
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.out = nn.Conv2d(channels, channels, 1)
        # Zero-init output projection: residual path dominates at init.
        nn.init.zeros_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        B, C, H, W = x.shape
        residual = x
        x = self.norm(x)
        qkv = self.qkv(x).reshape(B, 3, self.n_heads, self.head_dim, H * W)
        # Move the spatial axis into the sequence position expected by SDPA:
        # each of q, k, v becomes (B, heads, H*W, head_dim).
        q, k, v = (qkv[:, j].transpose(-2, -1) for j in range(3))
        attn = F.scaled_dot_product_attention(q, k, v)  # (B, heads, H*W, head_dim)
        # Back to channels-first: (B, heads, head_dim, H*W) -> (B, C, H, W).
        out = attn.transpose(-2, -1).reshape(B, C, H, W)
        return residual + self.out(out)
217
+
218
+
219
class Downsample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, feat):
        return self.conv(feat)
226
+
227
+
228
class Upsample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample, then 3x3 conv."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, feat):
        upsampled = F.interpolate(feat, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
236
+
237
+
238
+ # ══════════════════════════════════════════════════════════════════
239
+ # FLOW MATCHING UNET
240
+ # ══════════════════════════════════════════════════════════════════
241
+
242
class FlowMatchUNet(nn.Module):
    """
    Clean UNet for flow matching — predicts the velocity field v(x_t, t, y).

    Explicit skip tracking — no dynamic insertion.

    Encoder: [64@32] → down → [128@16] → down → [256@8]
    Middle:  [256@8] with attention + relay
    Decoder: [256@8] → up → [128@16] → up → [64@32]

    Conditioning: sinusoidal time embedding + class embedding, summed and
    fed to every ConvBlock via adaptive group norm.
    """
    def __init__(
        self,
        in_channels=3,
        base_channels=64,
        channel_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        use_relay=True,
    ):
        super().__init__()
        self.use_relay = use_relay
        self.channel_mults = channel_mults

        # Time + class conditioning
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Input projection
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Build encoder: 2 conv blocks per level, then downsample
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch_in = base_channels
        enc_channels = [base_channels]  # track channels at each skip point

        for i, mult in enumerate(channel_mults):
            ch_out = base_channels * mult
            # First block widens channels with a 1x1 conv when the level's
            # multiplier changes; otherwise a plain ConvBlock.
            self.enc.append(nn.ModuleList([
                ConvBlock(ch_in, cond_dim) if ch_in == ch_out
                else nn.Sequential(nn.Conv2d(ch_in, ch_out, 1),
                                   ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            enc_channels.append(ch_out)
            # No downsample after the last (deepest) level.
            if i < len(channel_mults) - 1:
                self.enc_down.append(Downsample(ch_out))

        # Middle — relay regulators (if enabled) live only here.
        mid_ch = ch_in
        self.mid_block1 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)
        self.mid_attn = SelfAttnBlock(mid_ch, n_heads=4)
        self.mid_block2 = ConvBlock(mid_ch, cond_dim, use_relay=use_relay)

        # Build decoder: upsample, concat skip, 2 conv blocks per level.
        # Iterates levels deepest-first; enc_channels is popped in step so
        # each skip projection gets the matching encoder width.
        self.dec_up = nn.ModuleList()
        self.dec_skip_proj = nn.ModuleList()
        self.dec = nn.ModuleList()

        for i in range(len(channel_mults) - 1, -1, -1):
            mult = channel_mults[i]
            ch_out = base_channels * mult
            skip_ch = enc_channels.pop()

            # Project concatenated channels
            self.dec_skip_proj.append(nn.Conv2d(ch_in + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch_in = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch_out))

        # Output — zero-init so the untrained net predicts v = 0.
        self.out_norm = nn.GroupNorm(8, ch_in)
        self.out_conv = nn.Conv2d(ch_in, in_channels, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        # Summed conditioning vector drives every AdaGroupNorm.
        cond = self.time_emb(t) + self.class_emb(class_labels)

        h = self.in_conv(x)
        # NOTE: skips[0] (post-in_conv features) is never popped — the
        # decoder pops one skip per level, leaving this first entry unused.
        skips = [h]

        # Encoder
        for i in range(len(self.channel_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, cond)
                elif isinstance(block, nn.Sequential):
                    # Conv1x1 then ConvBlock
                    h = block[0](h)
                    h = block[1](h, cond)
                else:
                    h = block(h)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # Middle
        h = self.mid_block1(h, cond)
        h = self.mid_attn(h)
        h = self.mid_block2(h, cond)

        # Decoder — pops skips deepest-first, mirroring the encoder.
        for i in range(len(self.channel_mults)):
            skip = skips.pop()
            # Upsample first if needed (except first decoder level)
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip_proj[i](h)
            for block in self.dec[i]:
                h = block(h, cond)

        h = self.out_norm(h)
        h = F.silu(h)
        return self.out_conv(h)
365
+
366
+
367
+ # ══════════════════════════════════════════════════════════════════
368
+ # FLOW MATCHING TRAINING
369
+ # ══════════════════════════════════════════════════════════════════
370
+
371
# Hyperparams
BATCH = 128
EPOCHS = 50
LR = 3e-4
BASE_CH = 64
USE_RELAY = True
N_CLASSES = 10
SAMPLE_EVERY = 5        # epochs between sample grids
N_SAMPLE_STEPS = 50  # Euler ODE steps for sampling

print("=" * 70)
print("FLOW MATCHING + CONSTELLATION RELAY REGULATOR")
print(f" Dataset: CIFAR-10")
print(f" Base channels: {BASE_CH}")
print(f" Relay: {USE_RELAY}")
print(f" Flow matching: ODE (conditional)")
print(f" Sampler: Euler, {N_SAMPLE_STEPS} steps")
print(f" Device: {DEVICE}")
print("=" * 70)

# Data — normalized to [-1, 1] to match the flow-matching noise scale.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)

print(f" Train: {len(train_ds):,} images")

# Model
model = FlowMatchUNet(
    in_channels=3, base_channels=BASE_CH,
    channel_mults=(1, 2, 4), n_classes=N_CLASSES,
    cond_dim=256, use_relay=USE_RELAY
).to(DEVICE)

# Parameter accounting; relay params are identified by attribute name.
n_params = sum(p.numel() for p in model.parameters())
relay_params = sum(p.numel() for n, p in model.named_parameters() if 'relay' in n)
print(f" Total params: {n_params:,}")
print(f" Relay params: {relay_params:,} ({100*relay_params/n_params:.1f}%)")

# Count relay modules
n_relays = sum(1 for m in model.modules() if isinstance(m, ConstellationRelay))
print(f" Relay modules: {n_relays}")

# Cosine decay over the full run (per-step, not per-epoch).
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
# NOTE(review): GradScaler is intended for fp16 autocast; with the bf16
# autocast used below it is effectively redundant (harmless, but could be
# dropped together with the scaler calls in the training loop).
scaler = torch.amp.GradScaler("cuda")

os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
427
+
428
+
429
@torch.no_grad()
def sample(model, n_samples=64, n_steps=50, class_label=None):
    """Draw images by Euler-integrating the learned ODE from t=1 (noise) to t=0.

    Each step applies x_{t-dt} = x_t - v_pred(x_t, t) * dt.  When
    class_label is None, labels are drawn uniformly at random; otherwise
    every sample is conditioned on the given class.  Returns the clamped
    images in [-1, 1] together with the labels used.
    """
    model.eval()
    x = torch.randn(n_samples, 3, 32, 32, device=DEVICE)

    if class_label is None:
        labels = torch.randint(0, N_CLASSES, (n_samples,), device=DEVICE)
    else:
        labels = torch.full((n_samples,), class_label, dtype=torch.long, device=DEVICE)

    dt = 1.0 / n_steps
    for step in range(n_steps):
        t = torch.full((n_samples,), 1.0 - step * dt, device=DEVICE)

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v = model(x, t, labels)

        # Euler step toward the data end of the probability path.
        x = x - v * dt

    # Clamp to the valid image range before returning.
    return x.clamp(-1, 1), labels
454
+
455
+
456
+ # ══════════════════════════════════════════════════════════════════
457
+ # TRAINING LOOP
458
+ # ══════════════════════════════════════════════════════════════════
459
+
460
+ print(f"\n{'='*70}")
461
+ print(f"TRAINING β€” {EPOCHS} epochs")
462
+ print(f"{'='*70}")
463
+
464
+ best_loss = float('inf')
465
+ gs = 0
466
+
467
+ for epoch in range(EPOCHS):
468
+ model.train()
469
+ t0 = time.time()
470
+ total_loss = 0
471
+ n = 0
472
+
473
+ pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
474
+ for images, labels in pbar:
475
+ images = images.to(DEVICE, non_blocking=True) # (B, 3, 32, 32) in [-1, 1]
476
+ labels = labels.to(DEVICE, non_blocking=True)
477
+ B = images.shape[0]
478
+
479
+ # Flow matching: sample t, compute x_t and target velocity
480
+ t = torch.rand(B, device=DEVICE)
481
+ eps = torch.randn_like(images)
482
+
483
+ # x_t = (1-t) * x_0 + t * eps
484
+ t_b = t.view(B, 1, 1, 1)
485
+ x_t = (1 - t_b) * images + t_b * eps
486
+
487
+ # Target velocity: v = eps - x_0
488
+ v_target = eps - images
489
+
490
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16):
491
+ v_pred = model(x_t, t, labels)
492
+ loss = F.mse_loss(v_pred, v_target)
493
+
494
+ optimizer.zero_grad(set_to_none=True)
495
+ scaler.scale(loss).backward()
496
+ scaler.unscale_(optimizer)
497
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
498
+ scaler.step(optimizer)
499
+ scaler.update()
500
+ scheduler.step()
501
+ gs += 1
502
+
503
+ total_loss += loss.item()
504
+ n += 1
505
+
506
+ if n % 20 == 0:
507
+ pbar.set_postfix(loss=f"{total_loss/n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")
508
+
509
+ elapsed = time.time() - t0
510
+ avg_loss = total_loss / n
511
+
512
+ # Checkpoint
513
+ mk = ""
514
+ if avg_loss < best_loss:
515
+ best_loss = avg_loss
516
+ torch.save({
517
+ 'state_dict': model.state_dict(),
518
+ 'epoch': epoch + 1,
519
+ 'loss': avg_loss,
520
+ 'use_relay': USE_RELAY,
521
+ }, 'checkpoints/flow_match_best.pt')
522
+ mk = " β˜…"
523
+
524
+ print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
525
+ f"({elapsed:.0f}s){mk}")
526
+
527
+ # Sample
528
+ if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
529
+ samples, sample_labels = sample(model, n_samples=64, n_steps=N_SAMPLE_STEPS)
530
+ # Denormalize
531
+ samples = (samples + 1) / 2 # [-1,1] β†’ [0,1]
532
+ grid = make_grid(samples, nrow=8, normalize=False)
533
+ save_image(grid, f'samples/epoch_{epoch+1:03d}.png')
534
+ print(f" β†’ Saved samples/epoch_{epoch+1:03d}.png")
535
+
536
+ # Per-class samples
537
+ if (epoch + 1) % (SAMPLE_EVERY * 2) == 0:
538
+ class_names = ['plane', 'auto', 'bird', 'cat', 'deer',
539
+ 'dog', 'frog', 'horse', 'ship', 'truck']
540
+ for c in range(N_CLASSES):
541
+ cs, _ = sample(model, n_samples=8, n_steps=N_SAMPLE_STEPS, class_label=c)
542
+ cs = (cs + 1) / 2
543
+ save_image(make_grid(cs, nrow=8),
544
+ f'samples/epoch_{epoch+1:03d}_class_{class_names[c]}.png')
545
+
546
+ # Relay diagnostics
547
+ if USE_RELAY and (epoch + 1) % 10 == 0:
548
+ print(f" Relay diagnostics:")
549
+ for name, module in model.named_modules():
550
+ if isinstance(module, ConstellationRelay):
551
+ drift = module.drift().mean().item()
552
+ gate = module.gates.sigmoid().mean().item()
553
+ print(f" {name}: drift={drift:.4f} rad "
554
+ f"({math.degrees(drift):.1f}Β°) gate={gate:.4f}")
555
+
556
+
557
+ print(f"\n{'='*70}")
558
+ print(f"DONE β€” Best loss: {best_loss:.4f}")
559
+ print(f" Params: {n_params:,} (relay: {relay_params:,})")
560
+ print(f" Samples in: samples/")
561
+ print(f"{'='*70}")