AbstractPhil commited on
Commit
1efed72
Β·
verified Β·
1 Parent(s): ea1ca83

Create prototype2_trainer.py

Browse files
Files changed (1) hide show
  1. prototype2_trainer.py +569 -0
prototype2_trainer.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Twin Stereo Diffusion β€” Fresnel Γ— Johanna Spectral Denoising
3
+ ==============================================================
4
+ Fresnel sees the clean image. Johanna sees the noise.
5
+ Procrustes alignment between their spectral bases IS the noise.
6
+
7
+ Training:
8
+ clean image ──→ Fresnel ──→ (U_f, S_f, Vt_f) target
9
+ noised image ──→ Johanna ──→ (U_j, S_j, Vt_j) input
10
+ R = Procrustes(U_j β†’ U_f) rotation = noise signature
11
+ Denoiser(S_j, R, t, labels) β†’ S_f predict clean magnitudes
12
+
13
+ Inference:
14
+ x_t ──→ Johanna ──→ S_j ──→ Denoiser ──→ S_pred
15
+ decode(U_j, S_pred, Vt_j) ──→ xΜ‚_0
16
+ flow step: x_{t-dt}
17
+ final pass: x_0 ──→ Fresnel encode/decode ──→ crisp output
18
+ """
19
+
20
+ import os
21
+ import math
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ import torchvision
26
+ import torchvision.transforms as T
27
+ import numpy as np
28
+ from tqdm import tqdm
29
+
30
+ try:
31
+ from google.colab import userdata
32
+ os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
33
+ from huggingface_hub import login
34
+ login(token=os.environ["HF_TOKEN"])
35
+ except Exception:
36
+ pass
37
+
38
+
39
+ # ═══════════════════════════════════════════════════════════════
40
+ # FROZEN TWINS
41
+ # ═══════════════════════════════════════════════════════════════
42
+
43
+ def load_twins(device='cuda'):
44
+ """Load both frozen SVAE twins at 128Γ—128."""
45
+ from geolip_svae import load_model
46
+
47
+ fresnel, f_cfg = load_model(hf_version='v12_imagenet128', device=device)
48
+ fresnel.eval()
49
+ for p in fresnel.parameters():
50
+ p.requires_grad = False
51
+ print(f" Fresnel-small loaded: {sum(p.numel() for p in fresnel.parameters()):,} params (frozen)")
52
+
53
+ johanna, j_cfg = load_model(hf_version='v16_johanna_omega', device=device)
54
+ johanna.eval()
55
+ for p in johanna.parameters():
56
+ p.requires_grad = False
57
+ print(f" Johanna-small loaded: {sum(p.numel() for p in johanna.parameters()):,} params (frozen)")
58
+
59
+ return fresnel, johanna
60
+
61
+
62
+ # ═══════════════════════════════════════════════════════════════
63
+ # PROCRUSTES ALIGNMENT
64
+ # ═══════════════════════════════════════════════════════════════
65
+
66
+ def batched_procrustes(A, B):
67
+ """Find orthogonal R such that A @ R β‰ˆ B.
68
+
69
+ Args:
70
+ A: (batch, M, D) β€” source (Johanna's U)
71
+ B: (batch, M, D) β€” target (Fresnel's U)
72
+
73
+ Returns:
74
+ R: (batch, D, D) β€” orthogonal rotation
75
+ """
76
+ M = torch.bmm(B.transpose(-2, -1), A) # (batch, D, D)
77
+ U, S, Vt = torch.linalg.svd(M)
78
+ return torch.bmm(Vt.transpose(-2, -1), U.transpose(-2, -1))
79
+
80
+
81
+ def compute_procrustes_features(U_j, U_f, D=16):
82
+ """Compute per-patch Procrustes rotation and extract features.
83
+
84
+ Args:
85
+ U_j: (B, N, V, D) β€” Johanna's left singular vectors
86
+ U_f: (B, N, V, D) β€” Fresnel's left singular vectors
87
+
88
+ Returns:
89
+ R: (B, N, D, D) β€” rotation matrices
90
+ R_feat: (B, N, D*D) β€” flattened rotation for projection
91
+ """
92
+ B, N, V, D = U_j.shape
93
+ Uj = U_j.reshape(B * N, V, D)
94
+ Uf = U_f.reshape(B * N, V, D)
95
+ R = batched_procrustes(Uj, Uf) # (B*N, D, D)
96
+ R = R.reshape(B, N, D, D)
97
+ R_feat = R.reshape(B, N, D * D)
98
+ return R, R_feat
99
+
100
+
101
+ # ═══════════════════════════════════════════════════════════════
102
+ # TINY IMAGENET DATASET (64β†’128)
103
+ # ═══════════════════════════════════════════════════════════════
104
+
105
+ IMG_MEAN = (0.4802, 0.4481, 0.3975)
106
+ IMG_STD = (0.2770, 0.2691, 0.2821)
107
+
108
+
109
+ class TinyImageNet128(torch.utils.data.Dataset):
110
+ """TinyImageNet (200 classes, 64Γ—64) upscaled to 128Γ—128."""
111
+
112
+ def __init__(self, split='train'):
113
+ from datasets import load_dataset
114
+ self.ds = load_dataset('zh-plus/tiny-imagenet', split=split)
115
+ self.transform = T.Compose([
116
+ T.Resize(128, interpolation=T.InterpolationMode.BILINEAR),
117
+ T.ToTensor(),
118
+ T.Normalize(IMG_MEAN, IMG_STD),
119
+ ])
120
+
121
+ def __len__(self):
122
+ return len(self.ds)
123
+
124
+ def __getitem__(self, idx):
125
+ item = self.ds[idx]
126
+ img = item['image']
127
+ if img.mode != 'RGB':
128
+ img = img.convert('RGB')
129
+ return self.transform(img), item['label']
130
+
131
+
132
+ # ═════��═════════════════════════════════════════════════════════
133
+ # NOISE SCHEDULE
134
+ # ═══════════════════════════════════════════════════════════════
135
+
136
+ def add_noise(x0, t):
137
+ """Linear flow-matching interpolation: x_t = (1-t)*x0 + t*Ξ΅.
138
+
139
+ Args:
140
+ x0: (B, 3, 128, 128) clean images
141
+ t: (B,) timesteps in [0, 1]
142
+
143
+ Returns:
144
+ x_t: noised images
145
+ eps: the noise that was added
146
+ """
147
+ eps = torch.randn_like(x0)
148
+ t_exp = t.view(-1, 1, 1, 1)
149
+ x_t = (1 - t_exp) * x0 + t_exp * eps
150
+ return x_t, eps
151
+
152
+
153
+ # ═══════════════════════════════════════════════════════════════
154
+ # SPECTRAL DENOISER
155
+ # ═══════════════════════════════════════════════════════════════
156
+
157
+ class SinusoidalPosEmb(nn.Module):
158
+ def __init__(self, dim):
159
+ super().__init__()
160
+ self.dim = dim
161
+
162
+ def forward(self, t):
163
+ half = self.dim // 2
164
+ emb = math.log(10000) / (half - 1)
165
+ emb = torch.exp(torch.arange(half, device=t.device, dtype=torch.float) * -emb)
166
+ emb = t.unsqueeze(1) * emb.unsqueeze(0)
167
+ return torch.cat([emb.sin(), emb.cos()], dim=1)
168
+
169
+
170
+ class AdaLN(nn.Module):
171
+ def __init__(self, dim, cond_dim):
172
+ super().__init__()
173
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
174
+ self.proj = nn.Linear(cond_dim, dim * 2)
175
+ nn.init.zeros_(self.proj.weight)
176
+ nn.init.zeros_(self.proj.bias)
177
+
178
+ def forward(self, x, cond):
179
+ s = self.proj(cond).unsqueeze(1).chunk(2, dim=-1)
180
+ return self.norm(x) * (1 + s[0]) + s[1]
181
+
182
+
183
+ class StereoBlock(nn.Module):
184
+ """Transformer block with AdaLN and Procrustes-conditioned cross-path."""
185
+
186
+ def __init__(self, dim, n_heads, cond_dim):
187
+ super().__init__()
188
+ self.adaln1 = AdaLN(dim, cond_dim)
189
+ self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
190
+ self.adaln2 = AdaLN(dim, cond_dim)
191
+ self.ff = nn.Sequential(
192
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
193
+
194
+ def forward(self, x, cond):
195
+ h = self.adaln1(x, cond)
196
+ h, _ = self.attn(h, h, h)
197
+ x = x + h
198
+ return x + self.ff(self.adaln2(x, cond))
199
+
200
+
201
+ class StereoDenoiser(nn.Module):
202
+ """Predicts clean Fresnel omega tokens from noisy Johanna observations.
203
+
204
+ Input: S_j (B, N, D) β€” Johanna's singular values
205
+ R_feat (B, N, DΒ²) β€” Procrustes rotation features
206
+ t (B,) β€” noise level
207
+ labels (B,) β€” class labels
208
+
209
+ Output: S_f_pred (B, N, D) β€” predicted clean Fresnel singular values
210
+ """
211
+
212
+ def __init__(self, n_patches=64, omega_dim=16, hidden=256,
213
+ depth=8, n_heads=8, n_classes=200):
214
+ super().__init__()
215
+ self.omega_dim = omega_dim
216
+ D2 = omega_dim * omega_dim
217
+
218
+ # Input: omega tokens + Procrustes features
219
+ self.input_proj = nn.Linear(omega_dim + D2, hidden)
220
+ self.input_proj_no_R = nn.Linear(omega_dim, hidden)
221
+
222
+ # Positional embedding
223
+ self.pos_emb = nn.Parameter(torch.randn(1, n_patches, hidden) * 0.02)
224
+
225
+ # Timestep embedding
226
+ self.time_emb = nn.Sequential(
227
+ SinusoidalPosEmb(hidden),
228
+ nn.Linear(hidden, hidden), nn.GELU(),
229
+ nn.Linear(hidden, hidden))
230
+
231
+ # Class embedding: single label β†’ hidden
232
+ self.class_emb = nn.Embedding(n_classes, hidden)
233
+
234
+ # Transformer blocks
235
+ self.blocks = nn.ModuleList([
236
+ StereoBlock(hidden, n_heads, hidden) for _ in range(depth)])
237
+
238
+ # Output
239
+ self.out_norm = nn.LayerNorm(hidden)
240
+ self.out_proj = nn.Linear(hidden, omega_dim)
241
+ nn.init.zeros_(self.out_proj.weight)
242
+ nn.init.zeros_(self.out_proj.bias)
243
+
244
+ def forward(self, S_j, t, labels, R_feat=None):
245
+ B = S_j.shape[0]
246
+
247
+ # Project input (with or without Procrustes features)
248
+ if R_feat is not None:
249
+ h = self.input_proj(torch.cat([S_j, R_feat], dim=-1))
250
+ else:
251
+ h = self.input_proj_no_R(S_j)
252
+ h = h + self.pos_emb
253
+
254
+ # Conditioning
255
+ t_emb = self.time_emb(t)
256
+ c_emb = self.class_emb(labels) # (B, hidden)
257
+ cond = t_emb + c_emb
258
+
259
+ # Transformer
260
+ for block in self.blocks:
261
+ h = block(h, cond)
262
+
263
+ # Predict residual: S_f β‰ˆ S_j + correction
264
+ return S_j + self.out_proj(self.out_norm(h))
265
+
266
+
267
+ # ═══════════════════════════════��═══════════════════════════════
268
+ # TRAINING
269
+ # ═══════════════════════════════════════════════════════════════
270
+
271
+ def train(epochs=100, batch_size=64, lr=3e-4, hidden=256, depth=8,
272
+ n_heads=8, device='cuda'):
273
+
274
+ device = torch.device(device if torch.cuda.is_available() else 'cpu')
275
+
276
+ print("\n" + "=" * 70)
277
+ print("TWIN STEREO DIFFUSION β€” Fresnel Γ— Johanna")
278
+ print("=" * 70)
279
+
280
+ # ── Load frozen twins ──
281
+ fresnel, johanna = load_twins(device)
282
+
283
+ # ── Data ──
284
+ print("\n Loading TinyImageNet...")
285
+ train_ds = TinyImageNet128(split='train')
286
+ val_ds = TinyImageNet128(split='valid')
287
+ train_loader = torch.utils.data.DataLoader(
288
+ train_ds, batch_size=batch_size, shuffle=True,
289
+ num_workers=4, pin_memory=True, drop_last=True)
290
+ val_loader = torch.utils.data.DataLoader(
291
+ val_ds, batch_size=batch_size, shuffle=False,
292
+ num_workers=4, pin_memory=True)
293
+
294
+ # ── Denoiser ──
295
+ denoiser = StereoDenoiser(
296
+ n_patches=64, omega_dim=16, hidden=hidden,
297
+ depth=depth, n_heads=n_heads).to(device)
298
+
299
+ n_params = sum(p.numel() for p in denoiser.parameters())
300
+ print(f"\n StereoDenoiser: {n_params:,} params")
301
+ print(f" Hidden={hidden}, Depth={depth}, Heads={n_heads}")
302
+ print(f" Dataset: TinyImageNet 200 classes, {len(train_ds)} train, {len(val_ds)} val")
303
+ print(f" Pipeline: Johanna(noised) + Procrustes β†’ predict Fresnel(clean)")
304
+ print("=" * 70)
305
+
306
+ opt = torch.optim.AdamW(denoiser.parameters(), lr=lr, weight_decay=0.01)
307
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
308
+
309
+ save_dir = '/content/stereo_checkpoints'
310
+ os.makedirs(save_dir, exist_ok=True)
311
+ best_val = float('inf')
312
+
313
+ for epoch in range(1, epochs + 1):
314
+ denoiser.train()
315
+ total_loss, total_r_norm, n = 0, 0, 0
316
+
317
+ pbar = tqdm(train_loader, desc=f"Ep {epoch}/{epochs}",
318
+ bar_format='{l_bar}{bar:20}{r_bar}')
319
+ for images, labels in pbar:
320
+ images = images.to(device)
321
+ labels = labels.to(device)
322
+ B = images.shape[0]
323
+
324
+ # ── Sample timestep ──
325
+ t = torch.rand(B, device=device)
326
+
327
+ # ── Noise the image ──
328
+ x_noised, eps = add_noise(images, t)
329
+
330
+ # ── Encode through both twins ──
331
+ with torch.no_grad():
332
+ f_out = fresnel(images) # clean
333
+ j_out = johanna(x_noised) # noised
334
+
335
+ S_f = f_out['svd']['S'] # target: (B, 64, 16)
336
+ S_j = j_out['svd']['S'] # input: (B, 64, 16)
337
+
338
+ # ── Procrustes alignment ──
339
+ with torch.no_grad():
340
+ R, R_feat = compute_procrustes_features(
341
+ j_out['svd']['U'], f_out['svd']['U'])
342
+
343
+ # ── Predict clean omega tokens ──
344
+ # R dropout: 20% of batches train without R (for inference path)
345
+ use_R = torch.rand(1).item() > 0.2
346
+ S_pred = denoiser(S_j, t, labels, R_feat if use_R else None)
347
+ loss = F.mse_loss(S_pred, S_f)
348
+
349
+ opt.zero_grad()
350
+ loss.backward()
351
+ torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=1.0)
352
+ opt.step()
353
+
354
+ total_loss += loss.item() * B
355
+ with torch.no_grad():
356
+ total_r_norm += (R - torch.eye(16, device=device)).norm(dim=(-2, -1)).mean().item() * B
357
+ n += B
358
+ pbar.set_postfix_str(f"loss={loss.item():.6f}")
359
+
360
+ sched.step()
361
+
362
+ # ── Validation ──
363
+ denoiser.eval()
364
+ val_loss, val_n = 0, 0
365
+ with torch.no_grad():
366
+ for images, labels in val_loader:
367
+ images, labels = images.to(device), labels.to(device)
368
+ B = images.shape[0]
369
+ t = torch.rand(B, device=device)
370
+ x_noised, _ = add_noise(images, t)
371
+ f_out = fresnel(images)
372
+ j_out = johanna(x_noised)
373
+ _, R_feat = compute_procrustes_features(
374
+ j_out['svd']['U'], f_out['svd']['U'])
375
+ S_pred = denoiser(j_out['svd']['S'], t, labels, R_feat)
376
+ val_loss += F.mse_loss(S_pred, f_out['svd']['S']).item() * B
377
+ val_n += B
378
+
379
+ train_l = total_loss / n
380
+ val_l = val_loss / val_n
381
+ r_norm = total_r_norm / n
382
+
383
+ if val_l < best_val:
384
+ best_val = val_l
385
+ torch.save({
386
+ 'epoch': epoch, 'val_loss': val_l,
387
+ 'model_state_dict': denoiser.state_dict(),
388
+ 'config': {'hidden': hidden, 'depth': depth, 'n_heads': n_heads},
389
+ }, os.path.join(save_dir, 'best.pt'))
390
+
391
+ print(f" ep{epoch:3d} | loss={train_l:.6f} val={val_l:.6f} "
392
+ f"best={best_val:.6f} ||R-I||={r_norm:.3f}")
393
+
394
+ # ── Sample every epoch ──
395
+ sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir)
396
+
397
+ print(f"\n TRAINING COMPLETE β€” best val: {best_val:.6f}")
398
+ return denoiser
399
+
400
+
401
+ # ═══════════════════════════════════════════════════════════════
402
+ # SAMPLING β€” ITERATIVE STEREO DENOISING
403
+ # ═══════════════════════════════════════════════════════════════
404
+
405
+ @torch.no_grad()
406
+ def sample_stereo(denoiser, fresnel, johanna, device, epoch, save_dir,
407
+ n_samples=4, n_steps=50):
408
+ """Generate samples using iterative twin denoising.
409
+
410
+ 1. Start from pure noise x_T
411
+ 2. At each step:
412
+ a. Johanna encodes x_t β†’ (U_j, S_j, Vt_j)
413
+ b. Denoiser predicts clean S_f from S_j
414
+ c. Decode through Johanna's basis β†’ xΜ‚_0 estimate
415
+ d. Flow step toward xΜ‚_0
416
+ 3. Final pass: encode through Fresnel β†’ decode with clean basis
417
+ """
418
+ from geolip_svae.model import stitch_patches
419
+
420
+ denoiser.eval()
421
+
422
+ labels = torch.randint(0, 200, (n_samples,), device=device)
423
+
424
+ # Start from noise
425
+ x = torch.randn(n_samples, 3, 128, 128, device=device)
426
+
427
+ for step in range(n_steps):
428
+ t_val = 1.0 - step / n_steps
429
+ t = torch.full((n_samples,), t_val, device=device)
430
+
431
+ # Johanna sees current state
432
+ j_out = johanna(x)
433
+ S_j = j_out['svd']['S']
434
+
435
+ # Denoiser predicts clean omega tokens (no R at inference)
436
+ S_pred = denoiser(S_j, t, labels, R_feat=None)
437
+
438
+ # Decode through Johanna's basis
439
+ decoded = johanna.decode_patches(
440
+ j_out['svd']['U'], S_pred, j_out['svd']['Vt'])
441
+ ps = johanna.patch_size
442
+ gh = gw = int(math.sqrt(S_j.shape[1]))
443
+ x_hat_0 = johanna.boundary_smooth(stitch_patches(decoded, gh, gw, ps))
444
+
445
+ # Flow step toward clean estimate
446
+ if step < n_steps - 1:
447
+ dt = 1.0 / n_steps
448
+ velocity = (x_hat_0 - x) / (t_val + 1e-4)
449
+ x = x + dt * velocity
450
+ else:
451
+ x = x_hat_0
452
+
453
+ # ── Final Fresnel polish ──
454
+ f_out = fresnel(x)
455
+ f_decoded = fresnel.decode_patches(
456
+ f_out['svd']['U'], f_out['svd']['S'], f_out['svd']['Vt'])
457
+ x_final = fresnel.boundary_smooth(stitch_patches(f_decoded, gh, gw, ps))
458
+
459
+ # ── Denormalize and save ──
460
+ mean = torch.tensor(IMG_MEAN).reshape(1, 3, 1, 1).to(device)
461
+ std = torch.tensor(IMG_STD).reshape(1, 3, 1, 1).to(device)
462
+
463
+ x_johanna = (x * std + mean).clamp(0, 1).cpu()
464
+ x_fresnel = (x_final * std + mean).clamp(0, 1).cpu()
465
+
466
+ import matplotlib
467
+ matplotlib.use('Agg')
468
+ import matplotlib.pyplot as plt
469
+
470
+ fig, axes = plt.subplots(n_samples, 2, figsize=(8, n_samples * 3))
471
+ if n_samples == 1:
472
+ axes = axes.reshape(1, -1)
473
+ for i in range(n_samples):
474
+ cls = labels[i].item()
475
+ axes[i, 0].imshow(x_johanna[i].permute(1, 2, 0).numpy())
476
+ axes[i, 0].set_title(f"Johanna decode: class {cls}", fontsize=8)
477
+ axes[i, 0].axis('off')
478
+ axes[i, 1].imshow(x_fresnel[i].permute(1, 2, 0).numpy())
479
+ axes[i, 1].set_title(f"Fresnel polish: class {cls}", fontsize=8)
480
+ axes[i, 1].axis('off')
481
+ plt.suptitle(f"Twin Stereo Diffusion β€” Epoch {epoch}", fontsize=10)
482
+ plt.tight_layout()
483
+ fname = os.path.join(save_dir, f'stereo_ep{epoch:03d}.png')
484
+ plt.savefig(fname, dpi=150, bbox_inches='tight')
485
+ plt.close()
486
+ print(f" Samples saved: {fname}")
487
+ print(f" Labels: {labels.cpu().tolist()}")
488
+
489
+
490
+ # ═══════════════════════════════════════════════════════════════
491
+ # ADVANCED SAMPLING β€” DUAL-ENCODE REFINEMENT
492
+ # ═══════════════════════════════════════════════════════════════
493
+
494
+ @torch.no_grad()
495
+ def sample_stereo_refined(denoiser, fresnel, johanna, labels, device,
496
+ n_steps=50):
497
+ """Two-pass refinement: use Fresnel to estimate R at inference.
498
+
499
+ At each step:
500
+ 1. Johanna(x_t) β†’ (U_j, S_j, Vt_j)
501
+ 2. Pass 1: Denoiser(S_j, t, labels) β†’ S_pred (no R)
502
+ 3. Decode β†’ xΜ‚_0, encode through Fresnel β†’ U_f_est
503
+ 4. R_est = Procrustes(U_j, U_f_est)
504
+ 5. Pass 2: Denoiser(S_j, t, labels, R_est) β†’ S_refined
505
+ 6. Decode through Fresnel's estimated basis β†’ x_{t-1}
506
+ """
507
+ from geolip_svae.model import stitch_patches
508
+
509
+ B = labels.shape[0]
510
+ x = torch.randn(B, 3, 128, 128, device=device)
511
+ ps = johanna.patch_size
512
+
513
+ for step in range(n_steps):
514
+ t_val = 1.0 - step / n_steps
515
+ t = torch.full((B,), t_val, device=device)
516
+
517
+ # Johanna encodes current state
518
+ j_out = johanna(x)
519
+ S_j = j_out['svd']['S']
520
+ gh = gw = int(math.sqrt(S_j.shape[1]))
521
+
522
+ # Pass 1: predict without R
523
+ S_pred_1 = denoiser(S_j, t, labels, R_feat=None)
524
+
525
+ # Decode pass 1 through Johanna
526
+ dec_1 = johanna.decode_patches(j_out['svd']['U'], S_pred_1, j_out['svd']['Vt'])
527
+ x_est = johanna.boundary_smooth(stitch_patches(dec_1, gh, gw, ps))
528
+
529
+ # Fresnel sees the estimate β†’ get clean-style basis
530
+ f_est = fresnel(x_est)
531
+
532
+ # Procrustes: how far is Johanna's basis from Fresnel's?
533
+ _, R_feat = compute_procrustes_features(
534
+ j_out['svd']['U'], f_est['svd']['U'])
535
+
536
+ # Pass 2: predict WITH R conditioning
537
+ S_pred_2 = denoiser(S_j, t, labels, R_feat)
538
+
539
+ # Decode through Fresnel's estimated basis
540
+ dec_2 = fresnel.decode_patches(
541
+ f_est['svd']['U'], S_pred_2, f_est['svd']['Vt'])
542
+ x_clean = fresnel.boundary_smooth(stitch_patches(dec_2, gh, gw, ps))
543
+
544
+ # Flow step
545
+ if step < n_steps - 1:
546
+ dt = 1.0 / n_steps
547
+ velocity = (x_clean - x) / (t_val + 1e-4)
548
+ x = x + dt * velocity
549
+ else:
550
+ x = x_clean
551
+
552
+ return x
553
+
554
+
555
+ # ═══════════════════════════════════════════════════════════════
556
+ # CLI
557
+ # ═══════════════════════════════════════════════════════════════
558
+
559
+ if __name__ == "__main__":
560
+ torch.set_float32_matmul_precision('high')
561
+
562
+ train(
563
+ epochs=100,
564
+ batch_size=64,
565
+ lr=3e-4,
566
+ hidden=256,
567
+ depth=8,
568
+ n_heads=8,
569
+ )