AbstractPhil committed on
Commit
f112949
·
verified ·
1 Parent(s): 8e45131

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +1074 -0
trainer.py ADDED
@@ -0,0 +1,1074 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder)
3
+ ===================================================
4
+
5
+ SD 1.5 VAE + Flan-T5 text encoder (default checkpoint: google/flan-t5-xl)
6
+ Dual tower collectives: vision towers + text towers
7
+
8
+ Text prompts for CIFAR-10 classes:
9
+ "a photo of an airplane"
10
+ "a photo of an automobile"
11
+ etc.
12
+
13
+ Requirements:
14
+ pip install transformers diffusers torchvision tqdm
15
+ pip install git+https://github.com/AbstractEyes/geofractal
16
+
17
+ apache license
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import math
23
+ from dataclasses import dataclass
24
+ from typing import Dict, Tuple, Optional, List
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch import Tensor
31
+ from torch.utils.data import DataLoader, Dataset
32
+ from torchvision import datasets, transforms
33
+ from torchvision.utils import make_grid, save_image
34
+ from huggingface_hub import HfApi, upload_file, create_repo
35
+ import json
36
+ from tqdm import tqdm
37
+
38
+ # =============================================================================
39
+ # GEOFRACTAL IMPORTS
40
+ # =============================================================================
41
+
42
+ from geofractal.router.wide_router import WideRouter
43
+ from geofractal.router.prefab.agatha.beatrix_tension_oscillator import (
44
+ BeatrixOscillator,
45
+ ScheduleType,
46
+ )
47
+ from geofractal.router.prefab.geometric_tower_builder import (
48
+ TowerConfig,
49
+ FusionType,
50
+ ConfigurableCollective,
51
+ build_tower_collective,
52
+ preset_pos_neg_pairs,
53
+ )
54
+ from geofractal.router.prefab.geometric_conv_tower_builder import (
55
+ ConvTowerConfig,
56
+ ConvTowerCollective,
57
+ build_conv_collective,
58
+ preset_conv_pos_neg,
59
+ )
60
+
61
+
62
+ # =============================================================================
63
+ # CIFAR-10 CLASS PROMPTS
64
+ # =============================================================================
65
+
66
# CIFAR-10 class names; the prompt list below keeps this order, and the cached
# dataset indexes the encoded prompts by integer label.
_CIFAR10_CLASS_NAMES = (
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# One caption per class: "an" before a leading vowel, "a" otherwise.
CIFAR10_PROMPTS = [
    f"a photo of {'an' if name[0] in 'aeiou' else 'a'} {name}"
    for name in _CIFAR10_CLASS_NAMES
]
78
+
79
+
80
+ # =============================================================================
81
+ # SD 1.5 VAE
82
+ # =============================================================================
83
+
84
class SD15VAE(nn.Module):
    """Thin wrapper around the Stable Diffusion 1.5 ``AutoencoderKL``.

    ``encode`` multiplies sampled latents by ``scale_factor`` and ``decode``
    divides by it before decoding. Both methods run with gradients disabled.

    Args:
        freeze: When true, switch the VAE to eval mode and turn off
            ``requires_grad`` on every VAE parameter.
    """

    def __init__(self, freeze: bool = True):
        super().__init__()
        # Local import: diffusers is only required once a VAE is instantiated.
        from diffusers import AutoencoderKL

        autoencoder = AutoencoderKL.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="vae",
            torch_dtype=torch.float32,
        )
        if freeze:
            autoencoder.eval()
            autoencoder.requires_grad_(False)

        self.vae = autoencoder
        self.scale_factor = 0.18215

    @torch.no_grad()
    def encode(self, x: Tensor) -> Tensor:
        """Return a sample from the VAE posterior of ``x``, scaled by ``scale_factor``."""
        posterior = self.vae.encode(x).latent_dist
        return posterior.sample() * self.scale_factor

    @torch.no_grad()
    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scaling on ``z`` and return the decoded sample."""
        unscaled = z / self.scale_factor
        return self.vae.decode(unscaled).sample
109
+
110
+
111
+ # =============================================================================
112
+ # FLAN-T5 TEXT ENCODER (default: google/flan-t5-xl)
113
+ # =============================================================================
114
+
115
class T5TextEncoder(nn.Module):
    """Flan-T5 encoder with an optional bottleneck projection.

    Args:
        model_name: Hugging Face checkpoint passed to ``from_pretrained``.
        freeze: When true, put the T5 encoder in eval mode and disable
            gradients on its parameters.
        max_length: Prompts are padded / truncated to this many tokens.
        bottleneck_dim: Output width of the bottleneck MLP.
    """

    def __init__(
        self,
        model_name: str = "google/flan-t5-xl",
        freeze: bool = True,
        max_length: int = 77,
        bottleneck_dim: int = 256,
    ):
        super().__init__()
        # Local import: transformers is only required once an encoder is built.
        from transformers import T5EncoderModel, T5Tokenizer

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.raw_dim = self.encoder.config.d_model  # 2048 for XL
        self.output_dim = bottleneck_dim

        # Bottleneck projection: raw_dim -> bottleneck_dim -> bottleneck_dim
        self.bottleneck = nn.Sequential(
            nn.Linear(self.raw_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
        )

        if freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False
            # The bottleneck is left trainable, but forward() runs under
            # no_grad, so no gradients reach it through this module.

    def _encode_tokens(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Tokenize ``texts`` and run the T5 encoder.

        Returns:
            The encoder's last hidden state and the attention mask, both on
            ``device``.
        """
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state, attention_mask

    @staticmethod
    def _masked_mean(sequence: Tensor, attention_mask: Tensor) -> Tensor:
        """Mean of ``sequence`` over dim 1, counting only positions where the mask is 1.

        The token count is clamped to at least 1 so a row with an all-zero
        mask pools to zeros instead of NaN (0 / 0).
        """
        mask_expanded = attention_mask.unsqueeze(-1).float()
        token_count = mask_expanded.sum(dim=1).clamp(min=1.0)
        return (sequence * mask_expanded).sum(dim=1) / token_count

    @torch.no_grad()
    def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts with bottleneck.

        Returns:
            sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
            pooled: [B, bottleneck_dim] - compressed mean pooled embedding
        """
        sequence_raw, attention_mask = self._encode_tokens(texts, device)
        sequence = self.bottleneck(sequence_raw)  # [B, L, bottleneck_dim]
        return sequence, self._masked_mean(sequence, attention_mask)

    @torch.no_grad()
    def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """
        Encode text prompts WITHOUT bottleneck (for caching raw embeddings).

        Returns:
            sequence: [B, L, raw_dim] - raw T5 embeddings
            pooled: [B, raw_dim] - raw mean pooled embedding
        """
        sequence, attention_mask = self._encode_tokens(texts, device)
        return sequence, self._masked_mean(sequence, attention_mask)
207
+
208
+
209
+ # =============================================================================
210
+ # CACHED DATASET (VAE latents + T5 text embeddings per class)
211
+ # =============================================================================
212
+
213
class CachedCIFAR10T5(Dataset):
    """
    Pre-cached CIFAR-10 with VAE latents.
    T5 embeddings are computed per-class (not per-image).

    On construction the cache file is loaded if it exists; otherwise the
    images are encoded with the SD 1.5 VAE, the class prompts with T5, and
    the result is written to ``cache_dir``.

    Args:
        train: Use the training split when true, the test split otherwise.
        image_size: Images are resized to ``image_size x image_size`` before
            VAE encoding; also part of the cache file name.
        cache_dir: Directory holding the ``.pt`` cache files.
        device: Device used while building the cache.
    """

    T5_MODEL = "google/flan-t5-xl"  # Change this to use different T5 variant

    def __init__(
        self,
        train: bool = True,
        image_size: int = 256,
        cache_dir: str = "./cache",
        device: str = "cuda",
    ):
        self.train = train
        # Include T5 model name in cache path
        t5_suffix = self.T5_MODEL.replace("/", "_")
        self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{'train' if train else 'val'}_{image_size}.pt"

        if self.cache_path.exists():
            print(f"Loading cache: {self.cache_path}")
            # The cache holds only tensors and an int, so the restricted
            # unpickler is sufficient; everything is kept on the CPU.
            cache = torch.load(self.cache_path, map_location="cpu", weights_only=True)
            self.latents = cache['latents']
            self.labels = cache['labels']
            self.text_sequence = cache['text_sequence']  # [10, L, dim]
            self.text_pooled = cache['text_pooled']  # [10, dim]
            self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1])
        else:
            print(f"Building cache for {'train' if train else 'val'} set...")
            self._build_cache(image_size, device)

    def _build_cache(self, image_size: int, device: str):
        """Encode prompts and images, keep the results on ``self``, and save them."""
        # Load encoders
        print(" Loading VAE...")
        vae = SD15VAE(freeze=True).to(device)
        print(f" Loading T5 ({self.T5_MODEL})...")
        t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device)

        # Encode class prompts - save RAW embeddings (bottleneck is in model)
        print(f" Encoding text prompts (T5 raw_dim={t5.raw_dim})...")
        text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device)
        self.text_sequence = text_seq.cpu()  # [10, L, raw_dim]
        self.text_pooled = text_pool.cpu()  # [10, raw_dim]
        self.text_dim = t5.raw_dim  # Store raw dim for bottleneck sizing

        # Encode images; Normalize maps pixel values from [0, 1] to [-1, 1].
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = datasets.CIFAR10('./data', train=self.train, download=True, transform=transform)
        loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

        all_latents, all_labels = [], []

        print(" Encoding images...")
        with torch.no_grad():
            for images, labels in tqdm(loader, desc=" Caching", leave=False):
                images = images.to(device)
                all_latents.append(vae.encode(images).cpu())
                all_labels.append(labels)

        self.latents = torch.cat(all_latents, dim=0)
        self.labels = torch.cat(all_labels, dim=0)

        del vae, t5
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Save to a temporary file first and rename afterwards, so an
        # interrupted run never leaves a truncated file at cache_path.
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        torch.save({
            'latents': self.latents,
            'labels': self.labels,
            'text_sequence': self.text_sequence,
            'text_pooled': self.text_pooled,
            'text_dim': self.text_dim,
        }, tmp_path)
        tmp_path.replace(self.cache_path)
        print(f" Saved: {self.cache_path}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(latent, text_sequence, text_pooled, label)`` for sample ``idx``."""
        label = self.labels[idx]
        return (
            self.latents[idx],
            self.text_sequence[label],  # [L, raw_dim]
            self.text_pooled[label],  # [raw_dim]
            label,
        )
306
+
307
+
308
+ # =============================================================================
309
+ # SINUSOIDAL EMBEDDING
310
+ # =============================================================================
311
+
312
class SinusoidalEmbed(nn.Module):
    """Sinusoidal embedding of a scalar input such as a timestep.

    The output's last dimension is the concatenation of ``dim // 2`` cosine
    features followed by ``dim // 2`` sine features. For an odd ``dim`` a
    single zero column is appended so the output width is always ``dim``.

    Args:
        dim: Width of the returned embedding.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: Tensor) -> Tensor:
        """Embed ``t`` (any shape) into a tensor of shape ``t.shape + (dim,)``."""
        half = self.dim // 2
        # Geometric frequency ladder from 1 down towards 1/10000.
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t.unsqueeze(-1) * freqs
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # cos/sin halves only cover dim - 1 features when dim is odd.
            emb = F.pad(emb, (0, 1))
        return emb
322
+
323
+
324
+ # =============================================================================
325
+ # CONFIG
326
+ # =============================================================================
327
+
328
@dataclass
class FlowConfig:
    """Hyperparameters for the Beatrix flow model and its trainer."""

    # --- data / latent geometry ---
    image_size: int = 256
    num_classes: int = 10
    latent_channels: int = 4
    latent_size: int = 32

    # --- text encoder (T5) ---
    text_raw_dim: int = 2048  # raw T5-XL width; the trainer overwrites it from the dataset
    text_seq_len: int = 77
    bottleneck_dim: int = 256  # width of the compressed text features

    # --- transformer-based tower collective ---
    tower_dim: int = 256
    tower_depth: int = 2
    num_heads: int = 8
    geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')

    # --- convolutional tower collective ---
    conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
    conv_spatial_size: int = 8  # spatial size handed to the conv towers

    # --- oscillator ---
    manifold_dim: int = 1024  # projected manifold, smaller than the flattened latent
    num_tower_pairs: int = 16  # 32 towers / 2
    osc_steps: int = 50  # used for sampling only
    fingerprint_dim: int = 64

    # --- flow ---
    num_flow_steps: int = 50
    sigma_min: float = 0.001

    # --- optimisation ---
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100

    # --- runtime / paths ---
    cache_dir: str = "./cache"
    device: str = "cuda"
    output_dir: str = "./beatrix_cifar_t5"

    @property
    def latent_flat_dim(self) -> int:
        """Size of one flattened latent: channels * size * size (4096 with the defaults)."""
        return self.latent_channels * self.latent_size ** 2
374
+
375
+
376
+ # =============================================================================
377
+ # BEATRIX FLOW MODEL (Vision + Text Towers)
378
+ # =============================================================================
379
+
380
class BeatrixFlowT5(WideRouter):
    """
    Flow model with dual tower collectives per modality:

    Vision side:
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    Text side (mirrored):
    - Geometric towers (transformer): cantor, beatrix, helix, simplex (pos/neg)
    - Conv towers: wide_resnet, frequency, bottleneck, squeeze_excite (pos/neg)

    All towers output opinions that combine for velocity prediction.

    The velocity is predicted in a projected manifold of size
    ``cfg.manifold_dim`` as a learned blend of a spring force (towards a
    text/time reference point) and the oscillator's tower force, then
    projected back to the flattened latent.
    """

    def __init__(self, cfg: FlowConfig):
        super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
        self.objects['cfg'] = cfg

        # =================================================================
        # TEXT BOTTLENECK (trainable)
        # =================================================================
        # Separate MLPs for the token sequence and for the pooled embedding.
        self.attach('text_bottleneck_seq', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))
        self.attach('text_bottleneck_pool', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))

        # =================================================================
        # VISION GEOMETRIC TOWERS (pos/neg pairs)
        # =================================================================
        vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        vision_geo_collective = build_tower_collective(
            configs=vision_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='vision_geo',
        )
        self.attach('vision_geo', vision_geo_collective)

        # =================================================================
        # VISION CONV TOWERS (pos/neg pairs)
        # =================================================================
        vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        vision_conv_collective = build_conv_collective(
            configs=vision_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='vision_conv',
        )
        self.attach('vision_conv', vision_conv_collective)

        # =================================================================
        # TEXT GEOMETRIC TOWERS (pos/neg pairs) - MIRRORED
        # =================================================================
        text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        text_geo_collective = build_tower_collective(
            configs=text_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='text_geo',
        )
        self.attach('text_geo', text_geo_collective)

        # =================================================================
        # TEXT CONV TOWERS (pos/neg pairs) - MIRRORED
        # =================================================================
        text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        text_conv_collective = build_conv_collective(
            configs=text_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='text_conv',
        )
        self.attach('text_conv', text_conv_collective)

        # =================================================================
        # PROJECTIONS
        # =================================================================
        # Latent patchifier: non-overlapping patch_size x patch_size patches.
        patch_size = 4
        num_patches = (cfg.latent_size // patch_size) ** 2
        patch_dim = cfg.latent_channels * patch_size * patch_size

        self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
        self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
        self.objects['patch_size'] = patch_size
        self.objects['num_patches'] = num_patches

        # Text already at bottleneck_dim (256) = tower_dim, no extra projection needed
        # NOTE(review): this relies on cfg.bottleneck_dim == cfg.tower_dim; the
        # text towers receive bottlenecked text directly - confirm if either changes.

        # =================================================================
        # OSCILLATOR (for sampling)
        # =================================================================
        # Total towers: (4 geo + 4 conv) × pos/neg × 2 modalities = 32 towers
        num_geo_towers = len(vision_geo_configs)
        num_conv_towers = len(vision_conv_configs)
        total_towers = (num_geo_towers + num_conv_towers) * 2  # × 2 for vision + text

        oscillator = BeatrixOscillator(
            name='oscillator',
            manifold_dim=cfg.manifold_dim,
            tower_dim=cfg.tower_dim,
            num_tower_pairs=total_towers // 2,
            num_theta_probes=4,
            fingerprint_dim=cfg.fingerprint_dim,
            kappa_schedule=ScheduleType.TESLA_369,
            use_intrinsic_tension=True,
        )
        self.attach('oscillator', oscillator)

        # =================================================================
        # CONDITIONING
        # =================================================================
        # Time embedding
        time_embed = nn.Sequential(
            SinusoidalEmbed(256),
            nn.Linear(256, cfg.tower_dim),
            nn.GELU(),
            nn.Linear(cfg.tower_dim, cfg.tower_dim),
        )
        self.attach('time_embed', time_embed)

        # Bottlenecked text -> reference anchor
        self.attach('text_to_ref', nn.Sequential(
            nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
            nn.GELU(),
            nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
        ))

        # Time modulation for reference
        self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))

        # =================================================================
        # LATENT PROJECTION (4096 <-> manifold_dim)
        # =================================================================
        self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
        self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))

        # Learnable velocity mixing (passed through a sigmoid before use)
        self.velocity_mix = nn.Parameter(torch.tensor(0.5))

    def patchify(self, z: Tensor) -> Tensor:
        """[B, 4, 32, 32] -> [B, num_patches, tower_dim]"""
        B, C, H, W = z.shape
        p = self.objects['patch_size']

        # Cut H and W into p-sized tiles, then flatten each tile with its channels.
        z = z.unfold(2, p, p).unfold(3, p, p)
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        z = z.view(B, -1, C * p * p)

        return self['patch_proj'](z) + self.patch_pos_embed

    def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
        """
        Run all four tower collectives.
        Returns list of tower opinions [B, tower_dim] (32 total).

        Args:
            z: Latent batch, patchified for the vision towers.
            text_seq: Raw text sequence embeddings; the sequence bottleneck
                is applied here before the text towers.
        """
        patches = self.patchify(z)
        text_bottlenecked = self['text_bottleneck_seq'](text_seq)

        # Run all collectives; the fused conv outputs are not used here.
        vision_geo = self['vision_geo'](patches)
        vision_conv_fused, vision_conv_ops = self['vision_conv'](patches)
        text_geo = self['text_geo'](text_bottlenecked)
        text_conv_fused, text_conv_ops = self['text_conv'](text_bottlenecked)

        # Order: vision geo, vision conv, text geo, text conv.
        return (
            [op.opinion for op in vision_geo.opinions.values()] +
            list(vision_conv_ops.values()) +
            [op.opinion for op in text_geo.opinions.values()] +
            list(text_conv_ops.values())
        )

    def forward(
        self,
        z_0: Tensor,
        text_seq: Tensor,
        text_pooled: Tensor,
        labels: Tensor,
        t: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Training forward - single step velocity prediction.

        Args:
            z_0: Clean latents, [B, 4, 32, 32].
            text_seq: Raw text sequence embeddings for each sample.
            text_pooled: Raw pooled text embedding for each sample.
            labels: Class labels; accepted for interface symmetry with the
                dataset tuple but not used in the computation.
            t: Optional per-sample times; drawn uniformly from [0, 1) when
                omitted.

        Returns:
            Dict with ``'loss'`` (MSE between predicted and target velocity)
            and ``'tau'`` (detached sigmoid of the velocity mixing parameter).
        """
        B = z_0.shape[0]
        device = z_0.device

        if t is None:
            t = torch.rand(B, device=device)

        # Flatten latent [B, 4, 32, 32] -> [B, 4096]
        z_0_flat = z_0.flatten(1)

        # Noise + interpolate in full latent space: t=0 is data, t=1 is noise.
        eps = torch.randn_like(z_0)
        eps_flat = eps.flatten(1)
        t_exp = t.view(B, 1, 1, 1)
        z_t = (1 - t_exp) * z_0 + t_exp * eps
        z_t_flat = z_t.flatten(1)

        # Target velocity (in full latent space)
        v_target = eps_flat - z_0_flat

        # === PROJECT TO SMALLER MANIFOLD ===
        z_t_proj = self['latent_down'](z_t_flat)  # [B, 4096] -> [B, manifold_dim]

        # Bottleneck pooled text for reference
        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)

        # Reference from bottlenecked text + time (in manifold space)
        time_emb = self['time_embed'](t)
        x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)

        # Get all tower outputs (text_seq bottlenecked inside get_tower_outputs)
        tower_outputs = self.get_tower_outputs(z_t, text_seq)

        # Compute forces in manifold space
        osc = self['oscillator']
        tower_force, _ = osc.force_generator(z_t_proj, tower_outputs, state_fingerprint=None)
        spring_force = x_ref - z_t_proj

        # Velocity prediction in manifold space
        tau = torch.sigmoid(self.velocity_mix)
        v_pred_proj = (1 - tau) * spring_force + tau * tower_force

        # === PROJECT BACK TO FULL LATENT ===
        v_pred = self['latent_up'](v_pred_proj)  # [B, manifold_dim] -> [B, 4096]

        loss = F.mse_loss(v_pred, v_target)

        return {'loss': loss, 'tau': tau.detach()}

    @torch.no_grad()
    def sample(
        self,
        text_seq: Tensor,
        text_pooled: Tensor,
        vae: SD15VAE,
        num_steps: Optional[int] = None,
    ) -> Tensor:
        """Generate samples from text conditioning.

        Integrates the predicted velocity with fixed-step Euler updates from
        t=1 (Gaussian noise) towards t=0 and decodes the result with ``vae``.

        Args:
            text_seq: Raw text sequence embeddings, one row per sample.
            text_pooled: Raw pooled text embeddings, one row per sample.
            vae: VAE whose ``decode`` turns the final latents into images.
            num_steps: Number of Euler steps; falls back to
                ``cfg.num_flow_steps`` when ``None`` (or 0).

        Returns:
            The output of ``vae.decode`` on the final latents.
        """
        cfg = self.objects['cfg']
        B = text_seq.shape[0]
        device = text_seq.device
        num_steps = num_steps or cfg.num_flow_steps

        # Step-invariant quantities, computed once instead of on every step:
        # the text part of the reference, the mixing weight and the oscillator.
        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)
        text_ref = self['text_to_ref'](text_pooled_bn)
        tau = torch.sigmoid(self.velocity_mix)
        osc = self['oscillator']

        # Start from noise
        latent_shape = (B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)
        z = torch.randn(*latent_shape, device=device)

        dt = 1.0 / num_steps

        for step in range(num_steps):
            t_val = 1.0 - step * dt
            t = torch.full((B,), t_val, device=device)

            time_emb = self['time_embed'](t)
            x_ref = text_ref + self['time_to_ref'](time_emb)

            z_flat = z.flatten(1)

            # Project to manifold
            z_proj = self['latent_down'](z_flat)

            tower_outputs = self.get_tower_outputs(z, text_seq)

            tower_force, _ = osc.force_generator(z_proj, tower_outputs, state_fingerprint=None)
            spring_force = x_ref - z_proj

            v_pred_proj = (1 - tau) * spring_force + tau * tower_force

            # Project back and step against the velocity (time runs 1 -> 0).
            v_pred = self['latent_up'](v_pred_proj)
            z = (z_flat - dt * v_pred).view(latent_shape)

        return vae.decode(z)
685
+
686
+
687
+ # =============================================================================
688
+ # TRAINER
689
+ # =============================================================================
690
+
691
+ class Trainer:
692
def __init__(self, cfg: FlowConfig):
    """Build datasets, model, optimizer and scheduler, then try to resume.

    Falls back to the CPU when CUDA is unavailable; the resolved device is
    used for cache building, the model and the stored text embeddings.
    ``cfg.text_raw_dim`` is overwritten with the dimension reported by the
    training dataset before the model is constructed.

    Args:
        cfg: Run configuration (mutated: ``text_raw_dim`` is updated).
    """
    self.cfg = cfg
    cuda_available = torch.cuda.is_available()
    self.device = torch.device(cfg.device if cuda_available else "cpu")
    on_cuda = self.device.type == "cuda"
    self.output_dir = Path(cfg.output_dir)
    self.output_dir.mkdir(parents=True, exist_ok=True)

    if cuda_available:
        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Gradient scaling only applies on CUDA; keep the scaler inert otherwise.
    self.scaler = torch.amp.GradScaler('cuda', enabled=on_cuda)

    # Dataset (cache building runs on the resolved device, not cfg.device,
    # so a CPU-only host does not try to move encoders to "cuda").
    print("\n=== Building Cached Datasets ===")
    dataset_device = str(self.device)
    self.train_dataset = CachedCIFAR10T5(train=True, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=dataset_device)
    self.val_dataset = CachedCIFAR10T5(train=False, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=dataset_device)

    # Update config with actual T5 raw dimension from cache
    cfg.text_raw_dim = self.train_dataset.text_dim
    print(f"T5 raw dimension: {cfg.text_raw_dim} → bottleneck: {cfg.bottleneck_dim}")

    self.train_loader = DataLoader(self.train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0, pin_memory=on_cuda, drop_last=True)
    self.val_loader = DataLoader(self.val_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0, pin_memory=on_cuda)

    # Store raw text embeddings for sampling (bottleneck applied in model)
    self.text_sequence = self.train_dataset.text_sequence.to(self.device)  # [10, L, raw_dim]
    self.text_pooled = self.train_dataset.text_pooled.to(self.device)  # [10, raw_dim]

    # Model
    print("\n=== Building Model (Vision + Text Towers) ===")
    self.model = BeatrixFlowT5(cfg).to(self.device)

    # Compile
    if hasattr(torch, 'compile'):
        print("Compiling with WideRouter.prepare_and_compile()...")
        self.model = self.model.prepare_and_compile(
            mode="reduce-overhead",
            fullgraph=False,
        )

    # Count only parameters that will actually receive gradients.
    num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {num_params:,}")

    self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # The cosine schedule spans every optimizer step of the whole run.
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=cfg.num_epochs * len(self.train_loader))

    # Load most recent checkpoint if exists
    self.start_epoch = 0
    self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
    self._load_latest_checkpoint()

    self._vae = None

    # HuggingFace Hub setup
    self._setup_hf_repo()
748
+
749
+ def _setup_hf_repo(self):
750
+ """Create HF repo if needed and save initial config."""
751
+ try:
752
+ self.hf_api = HfApi()
753
+ create_repo(self.hf_repo, exist_ok=True, repo_type="model")
754
+ print(f"HF repo: {self.hf_repo}")
755
+
756
+ # Save config
757
+ config_dict = {
758
+ 'image_size': self.cfg.image_size,
759
+ 'num_classes': self.cfg.num_classes,
760
+ 'latent_channels': self.cfg.latent_channels,
761
+ 'latent_size': self.cfg.latent_size,
762
+ 'text_raw_dim': self.cfg.text_raw_dim,
763
+ 'bottleneck_dim': self.cfg.bottleneck_dim,
764
+ 'tower_dim': self.cfg.tower_dim,
765
+ 'tower_depth': self.cfg.tower_depth,
766
+ 'num_heads': self.cfg.num_heads,
767
+ 'geometric_types': self.cfg.geometric_types,
768
+ 'conv_types': self.cfg.conv_types,
769
+ 'conv_spatial_size': self.cfg.conv_spatial_size,
770
+ 'manifold_dim': self.cfg.manifold_dim,
771
+ 'fingerprint_dim': self.cfg.fingerprint_dim,
772
+ 'num_flow_steps': self.cfg.num_flow_steps,
773
+ }
774
+ config_path = self.output_dir / "config.json"
775
+ with open(config_path, 'w') as f:
776
+ json.dump(config_dict, f, indent=2)
777
+
778
+ upload_file(
779
+ path_or_fileobj=str(config_path),
780
+ path_in_repo="config.json",
781
+ repo_id=self.hf_repo,
782
+ )
783
+ except Exception as e:
784
+ print(f"HF setup warning: {e}")
785
+ self.hf_api = None
786
+
787
+ def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: dict = None):
788
+ """Upload checkpoint, samples, and metrics to HuggingFace."""
789
+ if self.hf_api is None:
790
+ return
791
+
792
+ try:
793
+ # Upload checkpoint
794
+ ckpt_path = self.output_dir / "ckpt_latest.pt"
795
+ if ckpt_path.exists():
796
+ upload_file(
797
+ path_or_fileobj=str(ckpt_path),
798
+ path_in_repo="ckpt_latest.pt",
799
+ repo_id=self.hf_repo,
800
+ )
801
+
802
+ # Upload samples
803
+ if sample_path.exists():
804
+ upload_file(
805
+ path_or_fileobj=str(sample_path),
806
+ path_in_repo=f"samples/epoch_{epoch:03d}.png",
807
+ repo_id=self.hf_repo,
808
+ )
809
+ # Also as latest
810
+ upload_file(
811
+ path_or_fileobj=str(sample_path),
812
+ path_in_repo="samples/latest.png",
813
+ repo_id=self.hf_repo,
814
+ )
815
+
816
+ # Upload metrics log
817
+ if metrics:
818
+ metrics_path = self.output_dir / "metrics.jsonl"
819
+ with open(metrics_path, 'a') as f:
820
+ f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')
821
+ upload_file(
822
+ path_or_fileobj=str(metrics_path),
823
+ path_in_repo="metrics.jsonl",
824
+ repo_id=self.hf_repo,
825
+ )
826
+
827
+ print(f" → Uploaded to HF")
828
+ except Exception as e:
829
+ print(f" → HF upload failed: {e}")
830
+
831
+ def _load_latest_checkpoint(self):
832
+ """Load most recent checkpoint if available (local or HF)."""
833
+ latest_path = self.output_dir / "ckpt_latest.pt"
834
+
835
+ # Try local first
836
+ if latest_path.exists():
837
+ print(f"Resuming from local ckpt_latest.pt...")
838
+ ckpt = torch.load(latest_path, weights_only=False)
839
+ else:
840
+ # Fall back to numbered checkpoints
841
+ ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
842
+ if ckpts:
843
+ latest_path = ckpts[-1]
844
+ print(f"Resuming from {latest_path.name}...")
845
+ ckpt = torch.load(latest_path, weights_only=False)
846
+ else:
847
+ # Try downloading from HuggingFace
848
+ try:
849
+ from huggingface_hub import hf_hub_download
850
+ print(f"Checking HF for checkpoint...")
851
+ hf_path = hf_hub_download(
852
+ repo_id=self.hf_repo,
853
+ filename="ckpt_latest.pt",
854
+ local_dir=str(self.output_dir),
855
+ )
856
+ print(f"Downloaded checkpoint from HF")
857
+ ckpt = torch.load(hf_path, weights_only=False)
858
+ except Exception as e:
859
+ print(f"No checkpoint found (local or HF): {e}")
860
+ return
861
+
862
+ self.model.load_state_dict(ckpt['model'])
863
+ self.optimizer.load_state_dict(ckpt['optimizer'])
864
+ self.scheduler.load_state_dict(ckpt['scheduler'])
865
+ self.start_epoch = ckpt['epoch']
866
+ print(f" Resumed at epoch {self.start_epoch}")
867
+
868
def _load_vae(self):
    """Create a frozen SD15VAE on ``self.device`` for use while sampling."""
    print("Loading VAE for sampling...")
    frozen_vae = SD15VAE(freeze=True)
    return frozen_vae.to(self.device)
872
+
873
+ def _unload_vae(self, vae):
874
+ """Unload VAE after sampling."""
875
+ del vae
876
+ torch.cuda.empty_cache()
877
+
878
def train_epoch(self, epoch: int) -> Dict[str, float]:
    """Run one pass over ``self.train_loader`` and return averaged metrics.

    Each batch is run under CUDA autocast; the loss is scaled through
    ``self.scaler``, gradients are clipped to a max norm of 1.0, and the
    scheduler is stepped once per batch.

    Args:
        epoch: Zero-based epoch index (shown as ``epoch + 1`` in the progress bar).

    Returns:
        ``{'loss': ..., 'tau': ...}`` averaged over batches. Both values are
        NaN when the loader yields no batches.
    """
    self.model.train()
    total_loss, total_tau, n = 0.0, 0.0, 0

    pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
    for latents, text_seq, text_pooled, labels in pbar:
        latents = latents.to(self.device)
        text_seq = text_seq.to(self.device)
        text_pooled = text_pooled.to(self.device)
        labels = labels.to(self.device)

        with torch.amp.autocast('cuda'):
            out = self.model(latents, text_seq, text_pooled, labels)
            loss = out['loss']

        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to true gradients.
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.scheduler.step()

        # Read each scalar once; every .item() forces a device sync.
        loss_value = loss.item()
        tau_value = out['tau'].item()
        total_loss += loss_value
        total_tau += tau_value
        n += 1

        pbar.set_postfix(loss=f"{loss_value:.4f}", τ=f"{tau_value:.2f}")

    if n == 0:
        # No batches: report NaN rather than dividing by zero.
        return {'loss': float('nan'), 'tau': float('nan')}
    return {'loss': total_loss / n, 'tau': total_tau / n}
908
+
909
@torch.no_grad()
def validate(self) -> Dict[str, float]:
    """Compute the mean per-batch loss over ``self.val_loader``.

    The model is switched to eval mode and each batch is run under CUDA
    autocast with gradients disabled.

    Returns:
        ``{'val_loss': mean_loss}``; the value is NaN when the loader yields
        no batches (instead of raising ZeroDivisionError).
    """
    self.model.eval()
    total_loss, n = 0.0, 0

    for latents, text_seq, text_pooled, labels in self.val_loader:
        latents = latents.to(self.device)
        text_seq = text_seq.to(self.device)
        text_pooled = text_pooled.to(self.device)
        labels = labels.to(self.device)

        with torch.amp.autocast('cuda'):
            out = self.model(latents, text_seq, text_pooled, labels)
        total_loss += out['loss'].item()
        n += 1

    if n == 0:
        # An empty validation set must not abort the training loop.
        return {'val_loss': float('nan')}
    return {'val_loss': total_loss / n}
926
+
927
@torch.no_grad()
def sample_images(self, n_per_class: int = 10) -> Tensor:
    """Generate ``n_per_class`` samples for each of the 10 classes.

    A VAE is loaded for the duration of the call and always unloaded again,
    even if sampling raises. Samples are produced in batches of at most 10,
    collected on the CPU, then concatenated and moved to ``self.device``.

    Args:
        n_per_class: Number of images to generate per class.

    Returns:
        Tensor of all samples, class by class, mapped from the model's
        output via ``(x + 1) / 2`` and clamped to ``[0, 1]``.
    """
    self.model.eval()
    torch.cuda.empty_cache()

    # Load VAE temporarily
    vae = self._load_vae()

    all_samples = []
    batch_size = 10  # Generate 10 images at a time

    try:
        for class_idx in range(10):
            # Generate n_per_class images for this class
            for batch_start in range(0, n_per_class, batch_size):
                batch_n = min(batch_size, n_per_class - batch_start)

                # Repeat this class's text conditioning across the batch.
                text_seq = self.text_sequence[class_idx:class_idx+1].expand(batch_n, -1, -1)
                text_pooled = self.text_pooled[class_idx:class_idx+1].expand(batch_n, -1)

                with torch.amp.autocast('cuda'):
                    samples = self.model.sample(text_seq, text_pooled, vae)

                # Park results on the CPU so they do not accumulate on the device.
                all_samples.append(samples.cpu())
    finally:
        # Unload the VAE even when sampling fails, so it is never left loaded.
        self._unload_vae(vae)

    samples = torch.cat(all_samples, dim=0).to(self.device)
    return ((samples + 1) / 2).clamp(0, 1)
957
+
958
def save_checkpoint(self, epoch: int, milestone: bool = False):
    """Write the training state to ``self.output_dir``.

    ``ckpt_latest.pt`` is always (re)written; it is the file used for
    resuming. When ``milestone`` is true an additional
    ``ckpt_epochXXX.pt`` copy is kept.

    Args:
        epoch: Epoch number stored in the checkpoint and used in the
            milestone file name.
        milestone: Whether to also keep a numbered checkpoint.
    """
    ckpt = {
        'epoch': epoch,
        'model': self.model.state_dict(),
        'optimizer': self.optimizer.state_dict(),
        'scheduler': self.scheduler.state_dict(),
    }
    # Keep the AMP loss-scale state so a resumed run does not restart it.
    # Loaders that do not know this key can ignore it.
    scaler = getattr(self, 'scaler', None)
    if scaler is not None:
        ckpt['scaler'] = scaler.state_dict()

    def _atomic_save(target):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint behind.
        # The ".tmp" suffix keeps strays out of the "ckpt_epoch*.pt" glob.
        tmp_path = target.with_name(target.name + ".tmp")
        torch.save(ckpt, tmp_path)
        tmp_path.replace(target)

    # Always save latest (for resume)
    _atomic_save(self.output_dir / "ckpt_latest.pt")
    # Save milestone checkpoints
    if milestone:
        _atomic_save(self.output_dir / f"ckpt_epoch{epoch:03d}.pt")
970
+
971
def train(self):
    """Run the full training loop from ``self.start_epoch`` to ``cfg.num_epochs``.

    Prints a configuration banner, then for every epoch: trains, validates,
    prints a summary line, saves a sample grid, writes a checkpoint (a
    milestone copy every 10th epoch) and uploads the artifacts. After the
    loop a final sample grid and milestone checkpoint are written and
    uploaded.
    """
    cfg = self.cfg

    # Every tower type comes as a pos/neg pair, and each tower exists once
    # per modality, hence the two factors of 2.
    geo_towers = 2 * len(cfg.geometric_types)
    conv_towers = 2 * len(cfg.conv_types)
    total_towers = 2 * (geo_towers + conv_towers)

    rule = '=' * 60
    print(f"\n{rule}")
    print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
    print(f"{rule}")
    print(f"Device: {self.device}")
    print(f"Geometric towers: {cfg.geometric_types} (pos/neg)")
    print(f"Conv towers: {cfg.conv_types} (pos/neg)")
    print(f"Tower dim: {cfg.tower_dim}")
    print(f"T5 raw → bottleneck: {cfg.text_raw_dim} → {cfg.bottleneck_dim}")
    print(f"Latent → manifold: {cfg.latent_flat_dim} → {cfg.manifold_dim}")
    print(f"Total towers: {total_towers}")
    print(f"Batch size: {cfg.batch_size}")
    print(f"Epochs: {self.start_epoch}/{cfg.num_epochs}")
    print(f"{rule}\n")

    def _write_grid(target):
        # Sample 10 images per class and save them as one 10-column grid.
        images = self.sample_images(10)
        save_image(make_grid(images, nrow=10, padding=2), target)

    for epoch in range(self.start_epoch, cfg.num_epochs):
        epoch_no = epoch + 1  # one-based number used for display and file names

        train_stats = self.train_epoch(epoch)
        val_stats = self.validate()
        lr = self.scheduler.get_last_lr()[0]
        print(f"Epoch {epoch_no:3d} │ loss={train_stats['loss']:.4f} │ val={val_stats['val_loss']:.4f} │ τ={train_stats['tau']:.2f} │ lr={lr:.2e}")

        # A sample grid is produced every epoch to track progress.
        sample_path = self.output_dir / f"samples_epoch{epoch_no:03d}.png"
        _write_grid(sample_path)
        print(f" → Saved samples")

        # The latest checkpoint is refreshed every epoch; every 10th is kept.
        is_milestone = epoch_no % 10 == 0
        self.save_checkpoint(epoch_no, milestone=is_milestone)

        # Push checkpoint, samples and metrics to HuggingFace.
        epoch_metrics = dict(
            loss=train_stats['loss'],
            val_loss=val_stats['val_loss'],
            tau=train_stats['tau'],
            lr=lr,
        )
        self._upload_to_hf(epoch_no, sample_path, epoch_metrics)

    # Final artifacts once all epochs are done.
    final_path = self.output_dir / "samples_final.png"
    _write_grid(final_path)
    self.save_checkpoint(cfg.num_epochs, milestone=True)
    self._upload_to_hf(cfg.num_epochs, final_path)
    print("\nTraining complete!")
1023
+
1024
+
1025
+ # =============================================================================
1026
+ # MAIN
1027
+ # =============================================================================
1028
+
1029
def main():
    """Train with the lightweight configuration (16 towers instead of 32)."""
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        # 2 types × pos/neg = 4 towers per modality
        geometric_types=('cantor', 'beatrix'),
        # 2 types × pos/neg = 4 towers per modality
        conv_types=('wide_resnet', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=512,  # smaller manifold than the full configuration
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    Trainer(FlowConfig(**settings)).train()
1049
+
1050
+
1051
def main_full():
    """Train with the full 32-tower configuration."""
    # NOTE(review): output_dir is the same as in main(); presumably a run here
    # could pick up a checkpoint written by the lightweight config -- verify.
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
        conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
        conv_spatial_size=8,
        bottleneck_dim=256,
        manifold_dim=1024,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
        output_dir="./beatrix_cifar_t5",
    )
    Trainer(FlowConfig(**settings)).train()
1071
+
1072
+
1073
# Script entry point: running this file directly trains the lightweight configuration.
if __name__ == "__main__":
    main()