AbstractPhil committed on
Commit
0a70478
·
verified ·
1 Parent(s): 889b51c

Create GLFM_trainer_model.py

Browse files
Files changed (1) hide show
  1. GLFM_trainer_model.py +652 -0
GLFM_trainer_model.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Geometric Lookup Flow Matching (GLFM)
4
+ ========================================
5
+ A flow matching variant where velocity prediction is driven by
6
+ geometric address lookup on S^15.
7
+
8
+ Core insight (empirical):
9
+ The constellation bottleneck doesn't reconstruct encoder features.
10
+ It produces cos_sim ≈ 0 to its input. Instead, the triangulation
11
+ profile acts as a continuous ADDRESS on the unit hypersphere,
12
+ and the generator produces velocity fields from that address.
13
+
14
+ This is: v(x_t, t, c) = Generator(Address(x_t), t, c)
15
+ where Address(x) = triangulate(project_to_sphere(encode(x)))
16
+
17
+ GLFM formalizes this into three stages:
18
+
19
+ Stage 1 — GEOMETRIC ADDRESSING
20
+ Encoder maps x_t to multiple resolution embeddings on S^15.
21
+ Each resolution captures different spatial frequency information.
22
+ Triangulation against fixed anchors produces a structured address.
23
+
24
+ Stage 2 — ADDRESS CONDITIONING
25
+ The geometric address is concatenated with:
26
+ - Timestep embedding (sinusoidal)
27
+ - Class/text conditioning
28
+ - Noise level features
29
+ The conditioning modulates WHAT to generate at this address.
30
+
31
+ Stage 3 — VELOCITY GENERATION
32
+ A deep MLP generates the velocity field from the conditioned address.
33
+ This is NOT reconstruction — it's generation from a lookup.
34
+ The generator never sees the raw encoder features.
35
+
36
+ Key properties:
37
+ - Address space is geometrically structured (Voronoi cells on S^15)
38
+ - Anchors self-organize: <0.29 rad = frame holders, >0.29 = task encoders
39
+ - Precision-invariant (works at fp8)
40
+ - 21× compression with zero velocity quality loss
41
+ - Multi-scale addressing captures both coarse and fine structure
42
+ """
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+ import math
48
+ import os
49
+ import time
50
+ from tqdm import tqdm
51
+ from torchvision import datasets, transforms
52
+ from torchvision.utils import save_image, make_grid
53
+
54
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
55
+ torch.backends.cuda.matmul.allow_tf32 = True
56
+ torch.backends.cudnn.allow_tf32 = True
57
+
58
+
59
+ # ══════════════════════════════════════════════════════════════════
60
+ # STAGE 1: GEOMETRIC ADDRESSING
61
+ # ══════════════════════════════════════════════════════════════════
62
+
63
class GeometricAddressEncoder(nn.Module):
    """Map a spatial feature map to a flat "geometric address".

    The feature map is embedded at two scales:

    - Coarse: global average pool -> one ``embed_dim`` embedding per sample.
    - Fine: every spatial position -> its own ``embed_dim`` embedding.

    Each embedding is split into ``embed_dim // patch_dim`` patches, every
    patch is L2-normalized (so it lies on the unit sphere in ``patch_dim``
    dimensions, i.e. S^15 for the default ``patch_dim=16``) and compared
    against a shared set of anchors.  The resulting ``1 - cosine`` profiles
    of both scales are concatenated into the address.

    Args:
        spatial_channels: Channel count C of the incoming feature map.
        spatial_size: Spatial side length of the feature map.  Stored for
            reference only; ``forward`` reads H and W from its input.
        embed_dim: Width of the per-scale embedding.
        patch_dim: Size of each patch the embedding is split into.
        n_anchors: Number of anchors per patch.
        n_phases: Number of interpolation phases between the frozen
            ``home`` anchors and the learned ``anchors``.

    Raises:
        ValueError: If ``embed_dim`` is not a positive multiple of
            ``patch_dim`` (the embedding could not be reshaped into patches).
    """

    def __init__(
        self,
        spatial_channels,   # C from encoder output
        spatial_size,       # H (=W) from encoder output
        embed_dim=256,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
    ):
        super().__init__()
        if patch_dim <= 0 or embed_dim <= 0 or embed_dim % patch_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be a positive multiple of "
                f"patch_dim ({patch_dim})")

        self.spatial_channels = spatial_channels
        self.spatial_size = spatial_size
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Coarse address: global pool -> embedding
        self.coarse_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Fine address: per-position -> embedding
        self.fine_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Shared constellation: the same anchors serve both scales.
        # `home` is a frozen copy of the initial anchors (a buffer), while
        # `anchors` is trainable; drift() measures how far they have moved.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Triangulation width per address (768 with the defaults: 16*16*3).
        self.tri_dim = P * A * n_phases

        # Total address width: coarse profile + aggregated fine profile.
        self.address_dim = self.tri_dim * 2

    @staticmethod
    def _angle(h, c):
        """Angle (radians) between unit vectors ``h`` and ``c`` on the last dim."""
        # Clamp keeps acos away from +/-1 where its gradient is infinite.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def _unit_endpoints(self):
        """Return L2-normalized (home, anchors)."""
        return F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)

    def drift(self):
        """Angular distance of every anchor from its home position, shape (P, A)."""
        h, c = self._unit_endpoints()
        return self._angle(h, c)

    def at_phase(self, t):
        """Spherical interpolation home -> anchors at fraction ``t`` in [0, 1].

        Returns a (P, A, d) tensor; it is not re-normalized here.
        """
        h, c = self._unit_endpoints()
        omega = self._angle(h, c).unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard the division for tiny angles
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, patches_n):
        """patches_n: (..., P, d) -> (..., P*A*n_phases) of ``1 - cosine`` values.

        For each patch the last dimension is laid out phase-major:
        ``[phase0: A values, phase1: A values, ...]``.
        """
        shape = patches_n.shape[:-2]
        P, d = self.n_patches, self.patch_dim
        flat = patches_n.reshape(-1, P, d)
        # Build the phase grid on the host: doing it on `flat.device` and then
        # calling .tolist() would force a device->host sync on every call.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', flat, at))
        tri = torch.cat(tris, dim=-1).reshape(flat.shape[0], -1)
        return tri.reshape(*shape, -1)

    def forward(self, feature_map):
        """
        feature_map: (B, C, H, W) from encoder
        Returns: (B, address_dim) geometric address
        """
        B, C, H, W = feature_map.shape

        # Coarse: global pool -> single address per sample
        coarse = feature_map.mean(dim=(-2, -1))                      # (B, C)
        coarse_emb = self.coarse_proj(coarse)                        # (B, embed_dim)
        coarse_patches = F.normalize(
            coarse_emb.reshape(B, self.n_patches, self.patch_dim), dim=-1)
        coarse_addr = self.triangulate(coarse_patches)               # (B, tri_dim)

        # Fine: one address per spatial position, then aggregate
        fine = feature_map.permute(0, 2, 3, 1).reshape(B * H * W, C)  # (BHW, C)
        fine_emb = self.fine_proj(fine)                              # (BHW, embed_dim)
        fine_patches = F.normalize(
            fine_emb.reshape(B * H * W, self.n_patches, self.patch_dim), dim=-1)
        fine_addr = self.triangulate(fine_patches)                   # (BHW, tri_dim)

        fine_addr = fine_addr.reshape(B, H * W, -1)
        fine_mean = fine_addr.mean(dim=1)                            # (B, tri_dim)
        fine_max = fine_addr.max(dim=1).values                       # (B, tri_dim)
        # Fixed 50/50 blend of mean- and max-pooling (no learned gate).
        fine_combined = (fine_mean + fine_max) / 2                   # (B, tri_dim)

        # Full address = coarse profile followed by the fine profile
        return torch.cat([coarse_addr, fine_combined], dim=-1)       # (B, 2*tri_dim)
171
+
172
+
173
+ # ══════════════════════════════════════════════════════════════════
174
+ # STAGE 2: ADDRESS CONDITIONING
175
+ # ══════════════════════════════════════════════════════════════════
176
+
177
class AddressConditioner(nn.Module):
    """Fuse a geometric address with timestep, class and noise-level signals.

    Four vectors are concatenated -- the address, a sinusoidal+MLP timestep
    embedding, the caller-supplied class embedding and a learned embedding of
    the discretized timestep -- and projected to ``output_dim``.
    """

    def __init__(self, address_dim, cond_dim=256, output_dim=1024):
        super().__init__()
        time_layers = [
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim),
            nn.GELU(),
            nn.Linear(cond_dim, cond_dim),
        ]
        self.time_emb = nn.Sequential(*time_layers)

        # Learned table over 64 discrete noise levels (bins of t).
        self.noise_emb = nn.Embedding(64, cond_dim)

        # address + three cond_dim-wide conditioning vectors
        fused_width = address_dim + 3 * cond_dim
        self.fuse = nn.Sequential(
            nn.Linear(fused_width, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim),
        )

    def forward(self, address, t, class_emb):
        """
        address:   (B, address_dim) from the geometric encoder
        t:         (B,) timestep
        class_emb: (B, cond_dim) class embedding
        Returns:   (B, output_dim) conditioned address
        """
        # Map t onto one of the 64 noise-level bins (index clamped to 0..63).
        level = (t * 63).long().clamp(0, 63)
        pieces = (
            address,
            self.time_emb(t),
            class_emb,
            self.noise_emb(level),
        )
        return self.fuse(torch.cat(pieces, dim=-1))
212
+
213
+
214
+ # ══════════════════════════════════════════════════════════════════
215
+ # STAGE 3: VELOCITY GENERATOR
216
+ # ══════════════════════════════════════════════════════════════════
217
+
218
class VelocityGenerator(nn.Module):
    """Residual MLP that turns a conditioned address into flat velocity features.

    Layout: an input projection to ``hidden``, ``depth`` residual blocks, then
    a two-layer head producing ``spatial_dim`` outputs.
    """

    def __init__(self, cond_address_dim, spatial_dim, hidden=1024, depth=4):
        super().__init__()
        self.spatial_dim = spatial_dim

        # blocks[0] is the input projection; blocks[1:] are residual blocks.
        stem = nn.Sequential(
            nn.Linear(cond_address_dim, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
        )
        self.blocks = nn.ModuleList([stem])
        self.blocks.extend([ResBlock(hidden) for _ in range(depth)])

        self.head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, spatial_dim),
        )

    def forward(self, cond_address):
        """
        cond_address: (B, cond_address_dim)
        Returns: (B, spatial_dim) generated velocity features
        """
        hidden_state = cond_address
        for layer in self.blocks:
            hidden_state = layer(hidden_state)
        return self.head(hidden_state)
249
+
250
+
251
class ResBlock(nn.Module):
    """Residual MLP block: ``x + net(x)`` with two Linear-GELU-LayerNorm stages."""

    def __init__(self, dim):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [nn.Linear(dim, dim), nn.GELU(), nn.LayerNorm(dim)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        update = self.net(x)
        return x + update
260
+
261
+
262
+ # ══════════════════════════════════════════════════════════════════
263
+ # BUILDING BLOCKS
264
+ # ══════════════════════════════════════════════════════════════════
265
+
266
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar per sample (e.g. a timestep).

    For input ``t`` of shape (B,), returns (B, dim): ``dim // 2`` sine
    features followed by ``dim // 2`` cosine features, using geometrically
    spaced frequencies from 1 down to 1/10000.

    Edge cases:
        - ``dim`` of 2 or 3 yields a single frequency (1.0) instead of
          dividing by zero.
        - An odd ``dim`` is zero-padded with one trailing feature so the
          output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half == 1; for
        # half >= 2 the value is unchanged.
        step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=t.dtype) * -step)
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        if self.dim % 2 == 1:
            # Keep the promised output width for odd dims.
            emb = F.pad(emb, (0, 1))
        return emb
277
+
278
+
279
class AdaGroupNorm(nn.Module):
    """GroupNorm whose scale and shift are predicted from a conditioning vector.

    The projection is zero-initialized, so the layer starts out as a plain
    affine-free GroupNorm (scale 1, shift 0).
    """

    def __init__(self, ch, cond_dim, groups=8):
        super().__init__()
        n_groups = min(groups, ch)
        self.gn = nn.GroupNorm(n_groups, ch, affine=False)
        self.proj = nn.Linear(cond_dim, 2 * ch)
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        normed = self.gn(x)
        # Append two singleton dims so the modulation broadcasts over H, W.
        modulation = self.proj(cond)[..., None, None]
        scale, shift = modulation.chunk(2, dim=1)
        return normed * (1 + scale) + shift
290
+
291
+
292
class ConvBlock(nn.Module):
    """Conditioned residual conv block.

    Pipeline: 7x7 depthwise conv -> AdaGroupNorm(cond) -> 1x1 expand (4x)
    -> GELU -> 1x1 project back, added to the input.
    """

    def __init__(self, ch, cond_dim):
        super().__init__()
        expanded = ch * 4
        self.dw = nn.Conv2d(ch, ch, 7, padding=3, groups=ch)
        self.norm = AdaGroupNorm(ch, cond_dim)
        self.pw1 = nn.Conv2d(ch, expanded, 1)
        self.pw2 = nn.Conv2d(expanded, ch, 1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        y = self.dw(x)
        y = self.norm(y, cond)
        y = self.pw1(y)
        y = self.act(y)
        return x + self.pw2(y)
305
+
306
+
307
class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
312
+
313
+
314
class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbour resize, then 3x3 conv."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(enlarged)
320
+
321
+
322
+ # ══════════════════════════════════════════════════════════════════
323
+ # GLFM UNET
324
+ # ══════════════════════════════════════════════════════════════════
325
+
326
class GLFMUNet(nn.Module):
    """Geometric Lookup Flow Matching UNet.

    Encoder -> GeometricAddressEncoder -> AddressConditioner
            -> VelocityGenerator -> Decoder

    The UNet bottleneck is replaced by the three-stage GLFM pipeline: the
    encoder output is reduced to a geometric address, fused with timestep and
    class conditioning, and a residual MLP generates the bottleneck feature
    map that the decoder consumes together with the encoder skips.  No
    attention layers are used.

    Args:
        in_ch: Channels of the input (and of the predicted velocity).
        base_ch: Channel width at the first resolution.
        ch_mults: Per-level channel multipliers; one downsample between levels.
        n_classes: Number of class labels for the class embedding.
        cond_dim: Width of the timestep/class conditioning vectors.
        embed_dim, n_anchors, n_phases: Forwarded to GeometricAddressEncoder.
        gen_hidden, gen_depth: Width and depth of the VelocityGenerator.
        img_size: Spatial side length of the (square) input.  The generator's
            output size is derived from it, so inputs must match this size.

    Raises:
        ValueError: If ``img_size`` is not divisible by
            ``2 ** (len(ch_mults) - 1)``; the decoder could then not align
            its upsampled features with the encoder skips.
    """

    def __init__(
        self,
        in_ch=3,
        base_ch=64,
        ch_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        gen_hidden=1024,
        gen_depth=4,
        img_size=32,
    ):
        super().__init__()
        n_levels = len(ch_mults)
        down_factor = 2 ** (n_levels - 1)
        if img_size <= 0 or img_size % down_factor != 0:
            raise ValueError(
                f"img_size ({img_size}) must be a positive multiple of "
                f"{down_factor} for {n_levels} resolution levels")

        self.ch_mults = ch_mults
        self.img_size = img_size

        # Class embedding (shared by encoder, decoder and conditioner)
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Encoder conditioning (for AdaGroupNorm in conv blocks)
        self.enc_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Encoder
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = [base_ch]

        for i, m in enumerate(ch_mults):
            ch_out = base_ch * m
            # When the width changes, a 1x1 conv precedes the first ConvBlock.
            # It is wrapped in nn.Sequential only as a container: forward()
            # calls the two parts by hand because ConvBlock also needs `cond`.
            if ch == ch_out:
                first = ConvBlock(ch, cond_dim)
            else:
                first = nn.Sequential(
                    nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim))
            self.enc.append(nn.ModuleList([first, ConvBlock(ch_out, cond_dim)]))
            ch = ch_out
            enc_channels.append(ch)
            if i < n_levels - 1:
                self.enc_down.append(Downsample(ch))

        # GLFM pipeline (bottleneck)
        mid_ch = ch
        H_mid = img_size // down_factor
        spatial_dim = mid_ch * H_mid * H_mid
        self.mid_spatial = (mid_ch, H_mid, H_mid)

        # Stage 1: Geometric Address Encoder
        self.geo_encoder = GeometricAddressEncoder(
            spatial_channels=mid_ch,
            spatial_size=H_mid,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
        )

        # Stage 2: Address Conditioner
        self.conditioner = AddressConditioner(
            address_dim=self.geo_encoder.address_dim,
            cond_dim=cond_dim,
            output_dim=gen_hidden,
        )

        # Stage 3: Velocity Generator
        self.generator = VelocityGenerator(
            cond_address_dim=gen_hidden,
            spatial_dim=spatial_dim,
            hidden=gen_hidden,
            depth=gen_depth,
        )

        # Decoder
        self.dec_up = nn.ModuleList()
        self.dec_skip = nn.ModuleList()
        self.dec = nn.ModuleList()

        # Decoder conditioning
        self.dec_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        for i in range(n_levels - 1, -1, -1):
            ch_out = base_ch * ch_mults[i]
            skip_ch = enc_channels.pop()
            # 1x1 conv merges [decoder features, encoder skip] down to ch_out.
            self.dec_skip.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)
        # Zero-init the output conv so the initial prediction is exactly zero.
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        x:            (B, in_ch, img_size, img_size) input
        t:            (B,) timestep
        class_labels: (B,) integer class ids
        Returns:      tensor with the same shape as ``x``
        """
        # Conditioning: look the class embedding up once and reuse it.
        cls_emb = self.class_emb(class_labels)
        enc_cond = self.enc_time(t) + cls_emb
        dec_cond = self.dec_time(t) + cls_emb

        h = self.in_conv(x)
        skips = [h]

        # Encoder
        for i in range(len(self.ch_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock):
                    h = block(h, enc_cond)
                elif isinstance(block, nn.Sequential):
                    # (1x1 channel projection, ConvBlock) pair
                    h = block[0](h)
                    h = block[1](h, enc_cond)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # GLFM: Address -> Condition -> Generate
        B = h.shape[0]
        address = self.geo_encoder(h)                      # Stage 1
        cond_addr = self.conditioner(address, t, cls_emb)  # Stage 2
        h = self.generator(cond_addr)                      # Stage 3
        h = h.reshape(B, *self.mid_spatial)

        # Decoder: consume skips from deepest to shallowest.
        for i in range(len(self.ch_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip[i](h)
            for block in self.dec[i]:
                h = block(h, dec_cond)

        return self.out_conv(F.silu(self.out_norm(h)))
476
+
477
+
478
+ # ══════════════════════════════════════════════════════════════════
479
+ # SAMPLING
480
+ # ══════════════════════════════════════════════════════════════════
481
+
482
@torch.no_grad()
def sample(model, n=64, steps=50, cls=None, n_cls=10, shape=(3, 32, 32)):
    """Draw samples by Euler-integrating the predicted velocity from t=1 to t=0.

    Starts from Gaussian noise and applies ``x <- x - v(x, t, labels) * dt``
    for ``steps`` uniform steps.

    Args:
        model: Callable module ``model(x, t, labels) -> velocity``.
        n: Number of samples.
        steps: Number of Euler steps.
        cls: If given, every sample uses this class id; otherwise labels are
            drawn uniformly from ``[0, n_cls)``.
        n_cls: Number of classes used for random labels.
        shape: Per-sample tensor shape (C, H, W).

    Returns:
        ``(images, labels)`` with images clamped to [-1, 1].

    The tensors are created on the device of the model's parameters, and the
    model's train/eval mode is restored on exit.
    """
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to the default accelerator choice.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # bf16 autocast only on CUDA; requesting it elsewhere would just warn.
    use_amp = device.type == "cuda"

    was_training = model.training
    model.eval()
    try:
        x = torch.randn(n, *shape, device=device)
        if cls is not None:
            labels = torch.full((n,), cls, dtype=torch.long, device=device)
        else:
            labels = torch.randint(0, n_cls, (n,), device=device)

        dt = 1.0 / steps
        for s in range(steps):
            t = torch.full((n,), 1.0 - s * dt, device=device)
            with torch.amp.autocast(device.type, dtype=torch.bfloat16,
                                    enabled=use_amp):
                v = model(x, t, labels)
            # Accumulate in float32 regardless of the autocast dtype.
            x = x - v.float() * dt
        return x.clamp(-1, 1), labels
    finally:
        model.train(was_training)
495
+
496
+
497
+ # ══════════════════════════════════════════════════════════════════
498
+ # TRAINING
499
+ # ══════════════════════════════════════════════════════════════════
500
+
501
# Training hyper-parameters
BATCH = 128
EPOCHS = 80
LR = 3e-4
SAMPLE_EVERY = 5

# Reference angle (radians) that the drift diagnostics compare against.
_DRIFT_REF = 0.29154
_RULE = "=" * 70
# Mixed precision is only meaningful on CUDA; disabling it elsewhere avoids
# the "CUDA is not available" warnings without changing behaviour.
_USE_AMP = DEVICE == "cuda"


def _count_params(params):
    """Total number of scalar elements over an iterable of parameters."""
    return sum(p.numel() for p in params)


def _drift_stats(encoder):
    """Return (drift tensor, fraction within 0.05 of the reference, fraction above it)."""
    with torch.no_grad():
        d = encoder.drift().detach()
        near_frac = (d - _DRIFT_REF).abs().lt(0.05).float().mean().item()
        crossed_frac = (d > _DRIFT_REF).float().mean().item()
    return d, near_frac, crossed_frac


print(_RULE)
print("GEOMETRIC LOOKUP FLOW MATCHING (GLFM)")
print(f" Three-stage: Address → Condition → Generate")
print(f" Multi-scale: coarse (global) + fine (per-position)")
print(f" Device: {DEVICE}")
print(_RULE)

# CIFAR-10 scaled to [-1, 1] with random horizontal flips.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
# NOTE(review): worker processes are started from module level with no
# `if __name__ == "__main__":` guard; on spawn-based platforms this is
# presumably a problem -- confirm before running on Windows/macOS.
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)

model = GLFMUNet(
    in_ch=3, base_ch=64, ch_mults=(1, 2, 4),
    n_classes=10, cond_dim=256, embed_dim=256,
    n_anchors=16, n_phases=3,
    gen_hidden=1024, gen_depth=4,
).to(DEVICE)

n_params = _count_params(model.parameters())
n_geo = _count_params(model.geo_encoder.parameters())
n_cond = _count_params(model.conditioner.parameters())
n_gen = _count_params(model.generator.parameters())
n_anchor = _count_params(
    p for name, p in model.named_parameters() if 'anchor' in name)

print(f" Total: {n_params:,}")
print(f" Geo Encoder: {n_geo:,} (Stage 1 — address)")
print(f" Conditioner: {n_cond:,} (Stage 2 — fuse)")
print(f" Generator: {n_gen:,} (Stage 3 — velocity)")
print(f" Anchors: {n_anchor:,}")
print(f" Address dim: {model.geo_encoder.address_dim} "
      f"(coarse {model.geo_encoder.tri_dim} + fine {model.geo_encoder.tri_dim})")
print(f" Compression: {model.generator.spatial_dim} → "
      f"{model.geo_encoder.address_dim} "
      f"({model.generator.spatial_dim / model.geo_encoder.address_dim:.1f}×)")

# Shape check: one dummy forward pass before training starts.
with torch.no_grad():
    d = torch.randn(2, 3, 32, 32, device=DEVICE)
    o = model(d, torch.rand(2, device=DEVICE), torch.randint(0, 10, (2,), device=DEVICE))
    print(f" Shape: {d.shape} → {o.shape} ✓")
print(f" Train: {len(train_ds):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# The scheduler is stepped once per batch, so T_max is in optimizer steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
scaler = torch.amp.GradScaler("cuda", enabled=_USE_AMP)

os.makedirs("samples_glfm", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)

print(f"\n{_RULE}")
print(f"TRAINING — {EPOCHS} epochs")
print(_RULE)

best_loss = float('inf')
bn = model.geo_encoder  # for diagnostics

for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0.0
    n_batches = 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]

        # Flow-matching pair: x_t interpolates data (t=0) and noise (t=1);
        # the regression target is the constant velocity eps - images.
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        v_target = eps - images

        with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=_USE_AMP):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        n_batches += 1
        if n_batches % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n_batches:.4f}",
                             lr=f"{scheduler.get_last_lr()[0]:.1e}")

    elapsed = time.time() - t0
    avg_loss = total_loss / n_batches

    # Checkpoint whenever the epoch-average training loss improves.
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1, 'loss': avg_loss,
        }, 'checkpoints/glfm_best.pt')
        mk = " ★"

    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")

    # Diagnostics
    if (epoch + 1) % 10 == 0:
        drift, near, crossed = _drift_stats(bn)
        print(f" ★ drift: mean={drift.mean():.4f} max={drift.max():.4f} "
              f"near_0.29={near:.1%} crossed={crossed:.1%}")

    # Sample
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        imgs, _ = sample(model, 64, 50)
        save_image(make_grid((imgs + 1) / 2, nrow=8), f'samples_glfm/epoch_{epoch+1:03d}.png')
        print(f" → samples_glfm/epoch_{epoch+1:03d}.png")

    if (epoch + 1) % 20 == 0:
        names = ['plane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        for c, class_name in enumerate(names):
            cs, _ = sample(model, 8, 50, cls=c)
            save_image(make_grid((cs + 1) / 2, nrow=8),
                       f'samples_glfm/epoch_{epoch+1:03d}_{class_name}.png')
        print(f" → per-class samples")

print(f"\n{_RULE}")
print(f"GEOMETRIC LOOKUP FLOW MATCHING — COMPLETE")
print(f" Best loss: {best_loss:.4f}")
print(f" Total: {n_params:,}")
drift, near, crossed = _drift_stats(bn)
print(f" Final drift: mean={drift.mean():.4f} max={drift.max():.4f}")
print(f" Near 0.29: {near:.1%} Crossed: {crossed:.1%}")
print(_RULE)