AbstractPhil committed on
Commit
1daa7e1
·
verified ·
1 Parent(s): 3b2c00f

Create prototype1_trainer.py

Browse files
Files changed (1) hide show
  1. prototype1_trainer.py +579 -0
prototype1_trainer.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Twin Stereo Diffusion — Fresnel × Johanna Spectral Denoising
3
+ ==============================================================
4
+ Fresnel sees the clean image. Johanna sees the noise.
5
+ Procrustes alignment between their spectral bases IS the noise.
6
+
7
+ Training:
8
+ clean image ──→ Fresnel ──→ (U_f, S_f, Vt_f) target
9
+ noised image ──→ Johanna ──→ (U_j, S_j, Vt_j) input
10
+ R = Procrustes(U_j → U_f) rotation = noise signature
11
+ Denoiser(S_j, R, t, labels) → S_f predict clean magnitudes
12
+
13
+ Inference:
14
+ x_t ──→ Johanna ──→ S_j ──→ Denoiser ──→ S_pred
15
+ decode(U_j, S_pred, Vt_j) ──→ x̂_0
16
+ flow step: x_{t-dt}
17
+ final pass: x_0 ──→ Fresnel encode/decode ──→ crisp output
18
+ """
19
+
20
+ import os
21
+ import math
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torchvision
26
+ import torchvision.transforms as T
27
+ import numpy as np
28
+ from tqdm import tqdm
29
+
30
# Best-effort Hugging Face login for Google Colab sessions.  Anything that
# goes wrong here (not running in Colab, secret missing, hub package absent)
# is deliberately ignored so the script can still run elsewhere.
try:
    from google.colab import userdata
    os.environ.update(HF_TOKEN=userdata.get('HF_TOKEN'))
    from huggingface_hub import login
    login(token=os.getenv("HF_TOKEN"))
except Exception:
    pass
37
+
38
+
39
+ # ═══════════════════════════════════════════════════════════════
40
+ # FROZEN TWINS
41
+ # ═══════════════════════════════════════════════════════════════
42
+
43
def _load_frozen_twin(hf_version, label, device):
    """Load one SVAE checkpoint, put it in eval mode and freeze every weight."""
    from geolip_svae import load_model

    # load_model returns (model, config); the config is not needed here.
    model, _ = load_model(hf_version=hf_version, device=device)
    model.eval()
    n_params = 0
    for weight in model.parameters():
        weight.requires_grad = False
        n_params += weight.numel()
    print(f" {label} loaded: {n_params:,} params (frozen)")
    return model


def load_twins(device='cuda'):
    """Load both frozen SVAE twins (described upstream as 128×128 models).

    Args:
        device: device passed straight through to ``geolip_svae.load_model``.

    Returns:
        (fresnel, johanna): the ``v12_imagenet128`` and ``v16_johanna_omega``
        checkpoints, both in eval mode with ``requires_grad`` disabled on
        every parameter.  Fresnel is loaded first, then Johanna.
    """
    fresnel = _load_frozen_twin('v12_imagenet128', 'Fresnel-small', device)
    johanna = _load_frozen_twin('v16_johanna_omega', 'Johanna-small', device)
    return fresnel, johanna
60
+
61
+
62
+ # ═══════════════════════════════════════════════════════════════
63
+ # PROCRUSTES ALIGNMENT
64
+ # ═══════════════════════════════════════════════════════════════
65
+
66
def batched_procrustes(A, B):
    """Solve the orthogonal Procrustes problem for a batch of matrix pairs.

    For every pair, finds the orthogonal R that minimises ||A @ R - B||_F,
    i.e. A @ R ≈ B.

    Args:
        A: (..., M, D) source matrices (Johanna's U).
        B: (..., M, D) target matrices (Fresnel's U).
           Any number of leading batch dimensions is accepted.

    Returns:
        R: (..., D, D) orthogonal matrices.  No determinant correction is
           applied, so an individual R may be a reflection (det = -1)
           rather than a proper rotation.
    """
    # cross = B^T A = (A^T B)^T.  Writing cross = U S V^T gives
    # A^T B = V S U^T, whose orthogonal polar factor V U^T is the minimiser.
    cross = B.transpose(-2, -1) @ A  # (..., D, D)
    U, _, Vt = torch.linalg.svd(cross)
    return Vt.transpose(-2, -1) @ U.transpose(-2, -1)
79
+
80
+
81
def compute_procrustes_features(U_j, U_f, D=16):
    """Per-patch Procrustes alignment of Johanna's basis onto Fresnel's.

    Args:
        U_j: (B, N, V, D) — Johanna's left singular vectors.
        U_f: (B, N, V, D) — Fresnel's left singular vectors; it is reshaped
             using U_j's dimensions.
        D:   kept for interface compatibility only.  The value passed in is
             ignored because D is re-read from ``U_j.shape``.

    Returns:
        R:      (B, N, D, D) — alignment matrix for every patch.
        R_feat: (B, N, D*D)  — the same matrices flattened per patch.
    """
    n_batch, n_patch, n_rows, D = U_j.shape
    n_pairs = n_batch * n_patch

    # Fold batch and patch axes together so each patch is one Procrustes problem.
    per_patch = batched_procrustes(
        U_j.reshape(n_pairs, n_rows, D),
        U_f.reshape(n_pairs, n_rows, D),
    )  # (B*N, D, D)

    R = per_patch.reshape(n_batch, n_patch, D, D)
    return R, R.reshape(n_batch, n_patch, D * D)
99
+
100
+
101
+ # ═══════════════════════════════════════════════════════════════
102
+ # TILED CIFAR-10 DATASET
103
+ # ═══════════════════════════════════════════════════════════════
104
+
105
# Per-channel statistics used to normalise (and later de-normalise) CIFAR-10.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2470, 0.2435, 0.2616)


class TiledCIFAR(torch.utils.data.Dataset):
    """Four CIFAR-10 images (resized 32→64) tiled 2×2 into one 128×128 image.

    Every item is a fresh random draw: the index passed to ``__getitem__`` is
    ignored, and ``n_samples`` only sets the nominal epoch length.
    """

    def __init__(self, train=True, n_samples=50000):
        self.n_samples = n_samples
        upscale_and_normalise = T.Compose([
            T.Resize(64, interpolation=T.InterpolationMode.BILINEAR),
            T.ToTensor(),
            T.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
        self.cifar = torchvision.datasets.CIFAR10(
            root='./data', train=train, download=True,
            transform=upscale_and_normalise)
        self.n = len(self.cifar)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # idx is intentionally unused: four source images are drawn at random.
        picks = torch.randint(0, self.n, (4,)).tolist()
        tiles, tile_labels = zip(*(self.cifar[k] for k in picks))

        # Layout: tiles 0,1 form the top row, tiles 2,3 the bottom row.
        top_row = torch.cat(tiles[:2], dim=2)
        bottom_row = torch.cat(tiles[2:], dim=2)
        mosaic = torch.cat([top_row, bottom_row], dim=1)
        return mosaic, torch.tensor(list(tile_labels), dtype=torch.long)
136
+
137
+
138
+ # ═══════════════════════════════════════════════════════════════
139
+ # NOISE SCHEDULE
140
+ # ═══════════════════════════════════════════════════════════════
141
+
142
def add_noise(x0, t):
    """Linear flow-matching interpolation: x_t = (1 - t) * x0 + t * eps.

    Args:
        x0: (B, ...) clean samples, e.g. (B, 3, 128, 128) images.  Any rank
            of at least 1 is accepted.
        t:  (B,) timesteps in [0, 1]; t = 0 returns x0, t = 1 returns eps.

    Returns:
        x_t: noised samples, same shape as x0.
        eps: the standard-normal noise that was mixed in.
    """
    eps = torch.randn_like(x0)
    # One coefficient per sample, broadcast over every trailing axis of x0.
    # reshape (not view) so a non-contiguous t is accepted as well.
    t_exp = t.reshape(-1, *([1] * (x0.dim() - 1)))
    x_t = (1 - t_exp) * x0 + t_exp * eps
    return x_t, eps
157
+
158
+
159
+ # ═══════════════════════════════════════════════════════════════
160
+ # SPECTRAL DENOISER
161
+ # ═══════════════════════════════════════════════════════════════
162
+
163
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a scalar per sample (used for the timestep).

    Output layout is ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with
    ``h = dim // 2`` and frequencies spaced geometrically from 1 down to
    1/10000.  For odd ``dim`` a trailing zero column is appended so the
    output width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed ``t`` of shape (B,) into a (B, dim) tensor."""
        half = self.dim // 2
        # max(..., 1) avoids a division by zero when dim is 2 or 3 (half == 1).
        log_step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=torch.float) * -log_step)
        angles = t.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([angles.sin(), angles.cos()], dim=1)
        if self.dim % 2:
            # Odd width: pad one zero column so downstream layers see `dim`.
            emb = F.pad(emb, (0, 1))
        return emb
174
+
175
+
176
class AdaLN(nn.Module):
    """Adaptive LayerNorm: a conditioning vector supplies scale and shift.

    The projection is zero-initialised, so a fresh module reduces to a plain
    affine-free LayerNorm (scale = 1, shift = 0).
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.proj = nn.Linear(cond_dim, 2 * dim)
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, x, cond):
        """Normalise ``x`` over its last axis, then modulate it with ``cond``.

        The projected conditioning gets a new axis at position 1 so it
        broadcasts over the token axis of ``x``.
        """
        modulation = self.proj(cond).unsqueeze(1)
        scale, shift = torch.chunk(modulation, 2, dim=-1)
        normed = self.norm(x)
        return normed * (1 + scale) + shift
187
+
188
+
189
class StereoBlock(nn.Module):
    """Pre-norm transformer block whose two norms are AdaLN-conditioned.

    Layout: AdaLN → self-attention → residual, then AdaLN → 4×-wide GELU
    MLP → residual.  Conditioning enters only through the two AdaLN layers.
    """

    def __init__(self, dim, n_heads, cond_dim):
        super().__init__()
        self.adaln1 = AdaLN(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.adaln2 = AdaLN(dim, cond_dim)
        wide = dim * 4
        self.ff = nn.Sequential(
            nn.Linear(dim, wide),
            nn.GELU(),
            nn.Linear(wide, dim),
        )

    def forward(self, x, cond):
        """Apply attention and feed-forward sub-layers to ``x`` given ``cond``."""
        tokens = self.adaln1(x, cond)
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.attn(tokens, tokens, tokens)[0]
        x = x + attended
        mlp_out = self.ff(self.adaln2(x, cond))
        return x + mlp_out
205
+
206
+
207
class StereoDenoiser(nn.Module):
    """Predicts clean Fresnel omega tokens from noisy Johanna observations.

    forward(S_j, t, labels, R_feat=None):
        S_j     (B, N, D)   — Johanna's singular values
        t       (B,)        — noise level
        labels  (B, n_tiles) — one class id per tile
        R_feat  (B, N, D²)  — optional Procrustes rotation features; when
                              omitted a separate R-free input projection
                              (``input_proj_no_R``) is used instead

    Output: S_f_pred (B, N, D) — predicted clean Fresnel singular values,
            computed as ``S_j`` plus a learned correction.
    """

    def __init__(self, n_patches=64, omega_dim=16, hidden=256,
                 depth=8, n_heads=8, n_classes=10, n_tiles=4):
        super().__init__()
        self.omega_dim = omega_dim
        D2 = omega_dim * omega_dim

        # Input: omega tokens, with or without Procrustes features
        self.input_proj = nn.Linear(omega_dim + D2, hidden)
        self.input_proj_no_R = nn.Linear(omega_dim, hidden)

        # Learned positional embedding, one vector per patch
        self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)

        # Timestep embedding
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(hidden),
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, hidden))

        # Class embedding: each tile gets hidden // n_tiles features and the
        # per-tile embeddings are concatenated.  The projection input width
        # is derived from that product so `hidden` need not be divisible by
        # n_tiles (it equals `hidden` whenever it is divisible).
        tile_dim = hidden // n_tiles
        self.class_emb = nn.Embedding(n_classes, tile_dim)
        self.class_proj = nn.Linear(tile_dim * n_tiles, hidden)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            StereoBlock(hidden, n_heads, hidden) for _ in range(depth)])

        # Output head, zero-initialised so the initial correction is zero
        # and the untrained model returns S_j unchanged.
        self.out_norm = nn.LayerNorm(hidden)
        self.out_proj = nn.Linear(hidden, omega_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)

    def forward(self, S_j, t, labels, R_feat=None):
        B = S_j.shape[0]

        # Project input (with or without Procrustes features)
        if R_feat is not None:
            h = self.input_proj(torch.cat([S_j, R_feat], dim=-1))
        else:
            h = self.input_proj_no_R(S_j)
        h = h + self.pos_emb

        # Conditioning: timestep embedding plus projected tile-class embedding
        t_emb = self.time_emb(t)
        c_emb = self.class_proj(self.class_emb(labels).reshape(B, -1))
        cond = t_emb + c_emb

        # Transformer
        for block in self.blocks:
            h = block(h, cond)

        # Predict residual: S_f ≈ S_j + correction
        return S_j + self.out_proj(self.out_norm(h))
272
+
273
+
274
+ # ═══════════════════════════════════════════════════════════════
275
+ # TRAINING
276
+ # ═══════════════════════════════════════════════════════════════
277
+
278
def train(epochs=100, batch_size=64, lr=3e-4, hidden=256, depth=8,
          n_heads=8, n_train=50000, device='cuda', r_drop_prob=0.5,
          save_dir='/content/stereo_checkpoints'):
    """Train the StereoDenoiser against the two frozen SVAE twins.

    Each step noises a clean tiled image, encodes the clean image with
    Fresnel and the noised one with Johanna, computes the per-patch
    Procrustes alignment between their U bases, and regresses Fresnel's
    singular values from Johanna's with an MSE loss.

    Args:
        epochs, batch_size, lr: optimisation settings (AdamW + cosine decay).
        hidden, depth, n_heads: StereoDenoiser size.
        n_train: nominal number of training samples per epoch.
        device: requested device; falls back to CPU when CUDA is unavailable.
        r_drop_prob: probability, per batch, of training WITHOUT the
            Procrustes features.  Sampling calls the denoiser with
            ``R_feat=None``, which goes through ``input_proj_no_R``; if
            training always supplies R that layer never receives a gradient
            and inference runs on a random projection.  Set to 0.0 to
            reproduce always-with-R training.
        save_dir: directory for ``best.pt`` and the sample grids.

    Returns:
        The trained denoiser (last-epoch weights; the best-validation
        weights are in ``<save_dir>/best.pt``).

    Raises:
        ValueError: if ``r_drop_prob`` is outside [0, 1].
    """
    if not 0.0 <= r_drop_prob <= 1.0:
        raise ValueError(f"r_drop_prob must be in [0, 1], got {r_drop_prob}")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')

    print("\n" + "=" * 70)
    print("TWIN STEREO DIFFUSION β€” Fresnel Γ— Johanna")
    print("=" * 70)

    # ── Load frozen twins ──
    fresnel, johanna = load_twins(device)

    # ── Data ──
    train_ds = TiledCIFAR(train=True, n_samples=n_train)
    val_ds = TiledCIFAR(train=False, n_samples=5000)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # ── Denoiser ──
    denoiser = StereoDenoiser(
        n_patches=64, omega_dim=16, hidden=hidden,
        depth=depth, n_heads=n_heads).to(device)

    n_params = sum(p.numel() for p in denoiser.parameters())
    print(f"\n StereoDenoiser: {n_params:,} params")
    print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
    print(f" Training: {n_train} samples, batch={batch_size}")
    print(f" Pipeline: Johanna(noised) + Procrustes β†’ predict Fresnel(clean)")
    print("=" * 70)

    opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    os.makedirs(save_dir, exist_ok=True)
    best_val = float('inf')

    for epoch in range(1, epochs + 1):
        denoiser.train()
        total_loss, total_r_norm, n = 0, 0, 0

        pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
                    bar_format='{l_bar}{bar:20}{r_bar}')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)
            B = images.shape[0]

            # ── Sample timestep and noise the image ──
            t = torch.rand(B, device=device)
            x_noised, _ = add_noise(images, t)

            # ── Encode through both twins, then align their bases ──
            with torch.no_grad():
                f_out = fresnel(images)       # clean
                j_out = johanna(x_noised)     # noised
                S_f = f_out['svd']['S']       # regression target
                S_j = j_out['svd']['S']       # denoiser input
                R, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])

            # ── Predict clean omega tokens ──
            # Randomly hide R so the R-free input path used at sampling
            # time is trained as well.
            use_R = torch.rand(()).item() >= r_drop_prob
            S_pred = denoiser(S_j, t, labels, R_feat if use_R else None)
            loss = F.mse_loss(S_pred, S_f)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
            opt.step()

            total_loss += loss.item() * B
            with torch.no_grad():
                # Mean Frobenius distance of R from identity (a diagnostic);
                # the identity size follows R instead of a hard-coded 16.
                eye = torch.eye(R.shape[-1], device=device)
                total_r_norm += (R - eye).norm(dim=(-2, -1)).mean().item() * B
            n += B
            pbar.set_postfix_str(f"loss={loss.item():.6f}")

        sched.step()

        # ── Validation (always with R, same metric as before) ──
        denoiser.eval()
        val_loss, val_n = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                B = images.shape[0]
                t = torch.rand(B, device=device)
                x_noised, _ = add_noise(images, t)
                f_out = fresnel(images)
                j_out = johanna(x_noised)
                _, R_feat = compute_procrustes_features(
                    j_out['svd']['U'], f_out['svd']['U'])
                S_pred = denoiser(j_out['svd']['S'], t, labels, R_feat)
                val_loss += F.mse_loss(S_pred, f_out['svd']['S']).item() * B
                val_n += B

        train_l = total_loss / n
        val_l = val_loss / val_n
        r_norm = total_r_norm / n

        if val_l < best_val:
            best_val = val_l
            torch.save({
                'epoch': epoch, 'val_loss': val_l,
                'model_state_dict': denoiser.state_dict(),
                'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
            }, os.path.join(save_dir, 'best.pt'))

        if epoch % 5 == 0 or epoch <= 5:
            print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} "
                  f"best={best_val:.6f} ||R-I||={r_norm:.3f}")

        # ── Sample ──
        if epoch % 25 == 0:
            sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir)

    print(f"\n TRAINING COMPLETE β€” best val: {best_val:.6f}")
    return denoiser
405
+
406
+
407
+ # ═══════════════════════════════════════════════════════════════
408
+ # SAMPLING — ITERATIVE STEREO DENOISING
409
+ # ═══════════════════════════════════════════════════════════════
410
+
411
@torch.no_grad()
def sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir,
                  n_samples=4, n_steps=50):
    """Generate samples with iterative twin denoising and save a PNG grid.

    1. Start from pure noise x_T.
    2. At each step:
       a. Johanna encodes x_t -> (U_j, S_j, Vt_j)
       b. Denoiser predicts clean S_f from S_j (no Procrustes features)
       c. Decode through Johanna's basis -> x_hat_0 estimate
       d. Flow step toward x_hat_0
    3. Final pass: encode through Fresnel and decode with its basis.

    Args:
        denoiser: trained StereoDenoiser; switched to eval mode here.
        fresnel, johanna: the frozen twins.
        device: device for the random labels and the initial noise.
        epoch: used in the figure title and output file name.
        save_dir: directory the PNG is written to.
        n_samples: number of images (figure rows).
        n_steps: number of flow steps.

    Side effects: writes ``stereo_ep{epoch:03d}.png`` into ``save_dir`` and
    prints the file name and the sampled labels.
    """
    from geolip_svae.model import stitch_patches

    denoiser.eval()

    labels = torch.randint(0, 10, (n_samples, 4), device=device)
    class_names = ['plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

    # Start from noise
    x = torch.randn(n_samples, 3, 128, 128, device=device)
    ps = johanna.patch_size

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((n_samples,), t_val, device=device)

        # Johanna sees current state
        j_out = johanna(x)
        S_j = j_out['svd']['S']

        # Denoiser predicts clean omega tokens (no R at inference)
        S_pred = denoiser(S_j, t, labels, R_feat=None)

        # Decode through Johanna's basis
        decoded = johanna.decode_patches(
            j_out['svd']['U'], S_pred, j_out['svd']['Vt'])
        gh = gw = int(math.sqrt(S_j.shape[1]))
        x_hat_0 = johanna.boundary_smooth(stitch_patches(decoded, gh, gw, ps))

        # Flow step toward clean estimate.  With x_t = (1-t)*x0 + t*eps,
        # stepping from t to t-dt gives x_{t-dt} = x_t + (dt/t)*(x0 - x_t),
        # so the update ADDS the velocity pointing at x_hat_0.
        if step < n_steps - 1:
            dt = 1.0 / n_steps
            velocity = (x_hat_0 - x) / (t_val + 1e-4)
            x = x + dt * velocity
        else:
            x = x_hat_0

    # ── Final Fresnel polish ──
    # Encode through Fresnel to get clean basis, re-decode.  The patch grid
    # is read from Fresnel's own output rather than leaked loop variables;
    # Johanna's patch size is reused, as in the original code.
    f_out = fresnel(x)
    f_decoded = fresnel.decode_patches(
        f_out['svd']['U'], f_out['svd']['S'], f_out['svd']['Vt'])
    f_gh = f_gw = int(math.sqrt(f_out['svd']['S'].shape[1]))
    x_final = fresnel.boundary_smooth(stitch_patches(f_decoded, f_gh, f_gw, ps))

    # ── Denormalize and save ──
    mean = torch.tensor(CIFAR_MEAN).reshape(1, 3, 1, 1).to(device)
    std = torch.tensor(CIFAR_STD).reshape(1, 3, 1, 1).to(device)

    x_johanna = (x * std + mean).clamp(0, 1).cpu()
    x_fresnel = (x_final * std + mean).clamp(0, 1).cpu()

    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    # squeeze=False always returns a 2-D array of Axes, so axes[i, j] also
    # works for n_samples == 1 (NumPy arrays have no .unsqueeze()).
    fig, axes = plt.subplots(n_samples, 2, figsize=(8, n_samples * 3),
                             squeeze=False)
    for i in range(n_samples):
        tile_labels = [class_names[l] for l in labels[i].cpu().tolist()]
        axes[i, 0].imshow(x_johanna[i].permute(1, 2, 0).numpy())
        axes[i, 0].set_title(f"Johanna decode: {tile_labels}", fontsize=7)
        axes[i, 0].axis('off')
        axes[i, 1].imshow(x_fresnel[i].permute(1, 2, 0).numpy())
        axes[i, 1].set_title(f"Fresnel polish: {tile_labels}", fontsize=7)
        axes[i, 1].axis('off')
    plt.suptitle(f"Twin Stereo Diffusion β€” Epoch {epoch}", fontsize=10)
    plt.tight_layout()
    fname = os.path.join(save_dir, f'stereo_ep{epoch:03d}.png')
    plt.savefig(fname, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f" Samples saved: {fname}")
    print(f" Labels: {labels.cpu().tolist()}")
497
+
498
+
499
+ # ═══════════════════════════════════════════════════════════════
500
+ # ADVANCED SAMPLING — DUAL-ENCODE REFINEMENT
501
+ # ═══════════════════════════════════════════════════════════════
502
+
503
@torch.no_grad()
def sample_stereo_refined(denoiser, fresnel, johanna, labels, device,
                          n_steps=50):
    """Two-pass refinement: use Fresnel to estimate R at inference.

    At each step:
    1. Johanna(x_t) -> (U_j, S_j, Vt_j)
    2. Pass 1: Denoiser(S_j, t, labels) -> S_pred (no R)
    3. Decode -> x_hat_0, encode through Fresnel -> U_f_est
    4. R_est = Procrustes(U_j, U_f_est)
    5. Pass 2: Denoiser(S_j, t, labels, R_est) -> S_refined
    6. Decode through Fresnel's estimated basis, then take a flow step

    Args:
        denoiser: trained StereoDenoiser.
        fresnel, johanna: the frozen twins.
        labels: tile class ids; the batch size is ``labels.shape[0]``.
        device: device for the initial noise.
        n_steps: number of flow steps.

    Returns:
        x: the final decoded batch, still in normalised space.
    """
    from geolip_svae.model import stitch_patches

    B = labels.shape[0]
    x = torch.randn(B, 3, 128, 128, device=device)
    ps = johanna.patch_size

    for step in range(n_steps):
        t_val = 1.0 - step / n_steps
        t = torch.full((B,), t_val, device=device)

        # Johanna encodes current state
        j_out = johanna(x)
        S_j = j_out['svd']['S']
        gh = gw = int(math.sqrt(S_j.shape[1]))

        # Pass 1: predict without R
        S_pred_1 = denoiser(S_j, t, labels, R_feat=None)

        # Decode pass 1 through Johanna
        dec_1 = johanna.decode_patches(j_out['svd']['U'], S_pred_1, j_out['svd']['Vt'])
        x_est = johanna.boundary_smooth(stitch_patches(dec_1, gh, gw, ps))

        # Fresnel sees the estimate -> get clean-style basis
        f_est = fresnel(x_est)

        # Procrustes: how far is Johanna's basis from Fresnel's?
        _, R_feat = compute_procrustes_features(
            j_out['svd']['U'], f_est['svd']['U'])

        # Pass 2: predict WITH R conditioning
        S_pred_2 = denoiser(S_j, t, labels, R_feat)

        # Decode through Fresnel's estimated basis
        dec_2 = fresnel.decode_patches(
            f_est['svd']['U'], S_pred_2, f_est['svd']['Vt'])
        x_clean = fresnel.boundary_smooth(stitch_patches(dec_2, gh, gw, ps))

        # Flow step.  With x_t = (1-t)*x0 + t*eps, stepping from t to t-dt
        # gives x_{t-dt} = x_t + (dt/t)*(x0 - x_t): the velocity pointing at
        # the clean estimate must be ADDED, otherwise x drifts away from it.
        if step < n_steps - 1:
            dt = 1.0 / n_steps
            velocity = (x_clean - x) / (t_val + 1e-4)
            x = x + dt * velocity
        else:
            x = x_clean

    return x
562
+
563
+
564
+ # ═══════════════════════════════════════════════════════════════
565
+ # CLI
566
+ # ═══════════════════════════════════════════════════════════════
567
+
568
if __name__ == "__main__":
    # Script entry point: set the float32 matmul precision mode to 'high',
    # then launch a training run.
    torch.set_float32_matmul_precision('high')

    run_config = dict(
        epochs=100,
        batch_size=64,  # 2 VAE forwards per batch, keep it moderate
        lr=3e-4,
        hidden=256,
        depth=8,
        n_heads=8,
        n_train=50000,
    )
    train(**run_config)