AbstractPhil committed
Commit 394b68b · verified · 1 Parent(s): 3f1f097

Create stage1_analysis_trainer.py

Files changed (1): stage1_analysis_trainer.py (+1049 -0)
stage1_analysis_trainer.py ADDED
@@ -0,0 +1,1049 @@
"""
Geometric Transformer — CIFAR-100 Training with CM-Validated Analysis

Changes from previous version:
- CM gate diagnostics per layer: active anchors, gate_mean, cm_positive_frac
- CM quality in geometric residual analysis (replaces blind gate)
- Geometric regularization losses (CV target + anchor spread) in training loop
- Anchor diagnostics via model.anchor_diagnostics()
- CM quality trajectory alongside CV and bridge KL for cooperation analysis

TensorBoard logging of every geometric feature element:
- CV (coefficient of variation) per layer — the pentachoron band metric
- CM gate: active anchors, gate mean, cm_positive_frac, quality per position
- Stream agreement/divergence per layer
- Anchor utilization, entropy, spread
- Patchwork activation statistics (from CM-validated triangulation)
- Bridge vs assignment consistency
- Triangulation distance distributions
- SVD spectrum, entropy, novelty
- Quaternion arm norms and composition statistics
- Cayley rotation ‖R-I‖ per layer
- FiLM gamma/beta deviation from identity
- Gate activation statistics
- Gradient norms per component type (including cm_gate)
- Weight norms per component type
- Geometric regularization: CV loss, spread loss per epoch

!pip install geolip-core torchvision tqdm tensorboard
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time, json, math
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    print(f"  GPU: {torch.cuda.get_device_name()}")
    print(f"  VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ═══════════════════════════════════════════════════════════════════════════════
# IMPORT TRANSFORMER
# ═══════════════════════════════════════════════════════════════════════════════

# Try the installed geolip_core package first, fall back to a local file
try:
    from geolip_core.pipeline.components.geometric_transformer import (
        GeometricTransformer, GeometricTransformerLayer,
        CayleyOrthogonal, QuaternionCompose, FiLMLayer,
        ContentAttention, GeometricAttention, CMValidatedGate,
        TorchComponent, BaseTower,
        anchor_neighborhood_cm,
    )
    print("Imported from geolip_core (installed)")
except ImportError:
    try:
        from geometric_transformer import (
            GeometricTransformer, GeometricTransformerLayer,
            CayleyOrthogonal, QuaternionCompose, FiLMLayer,
            ContentAttention, GeometricAttention, CMValidatedGate,
            TorchComponent, BaseTower,
            anchor_neighborhood_cm,
        )
        print("Imported from local geometric_transformer.py")
    except ImportError:
        raise ImportError(
            "Cannot find geometric_transformer. Place geometric_transformer.py "
            "in the working directory or install geolip-core.")

torch.set_float32_matmul_precision('high')

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════════════════════

CONFIG = {
    # Model
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 8,
    'n_anchors': 128,
    'manifold_dim': 128,
    'n_comp': 4,
    'd_comp': 16,
    'context_dim': 64,
    'quat_dim': 32,
    'dropout': 0.1,
    'cm_neighbors': 3,            # CM simplex neighbors

    # Input stage
    'patch_size': 4,
    'img_size': 32,
    'in_channels': 3,
    'conv_channels': 64,
    'svd_rank': 16,

    # Training
    'epochs': 100,
    'batch_size': 1024,
    'lr': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 5,
    'label_smoothing': 0.1,
    'num_workers': 8,

    # Geometric regularization
    'cv_target': 0.215,           # pentachoron band center
    'cv_weight': 0.1,             # CV loss weight
    'spread_weight': 0.01,        # anchor spread loss weight

    # Augmentation — tuned for CM gate training
    'cutmix_alpha': 1.0,          # CutMix beta distribution α (1.0 = uniform box sizes)
    'cutmix_prob': 0.5,           # probability of applying CutMix per batch
    'random_erasing_p': 0.25,     # probability of erasing per image

    # InfoNCE memory bank on geometric residual
    'nce_bank_size': 4096,        # queue size (0 to disable)
    'nce_temperature': 0.1,       # InfoNCE temperature
    'nce_weight': 0.1,            # loss weight

    # Data
    'num_classes': 100,

    # Logging
    'log_geo_every': 5,           # full geometric analysis every N epochs
    'log_grads_every': 10,        # gradient norms every N epochs
    'log_dir': 'runs/geo_cifar100',
}

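The training loop that consumes `warmup_epochs` and `epochs` is outside this chunk, so the exact schedule is not shown. A minimal sketch of a linear-warmup + cosine-decay multiplier matching those config keys (the name `lr_scale` and the schedule shape are assumptions, not taken from the source):

```python
import math

def lr_scale(epoch, warmup_epochs=5, total_epochs=100):
    """Linear warmup then cosine decay to zero, as a multiplier on the base lr.

    Assumed schedule — the actual scheduler is not shown in this file chunk.
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

In PyTorch this multiplier would typically be wrapped with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_scale)`.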
# ═══════════════════════════════════════════════════════════════════════════════
# INPUT STAGE
# ═══════════════════════════════════════════════════════════════════════════════

try:
    from geolip_core.core.input.svd import SVDObserver
    _HAS_SVD = True
except ImportError:
    _HAS_SVD = False

    class SVDObserver(nn.Module):
        """Fallback SVDObserver (used only when geolip_core is unavailable)."""

        def __init__(self, in_channels, svd_rank=24):
            super().__init__()
            self.svd_rank = svd_rank
            self.to_svd = nn.Conv2d(in_channels, svd_rank, 1, bias=False)
            self.register_buffer('ema_s', torch.ones(svd_rank))
            self.register_buffer('ema_vh_flat', torch.eye(svd_rank).reshape(-1))
            self.ema_momentum = 0.99

        def extract_features(self, S, Vh):
            B, k = S.shape
            S_safe = S.clamp(min=1e-6)
            s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)
            vh_diag = Vh.diagonal(dim1=-2, dim2=-1)
            vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
            s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)
            out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
            return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

        def compute_novelty(self, S):
            return S - self.ema_s.clone().unsqueeze(0)

        def forward(self, x):
            B, C, H, W = x.shape
            h = self.to_svd(x)
            h_flat = h.permute(0, 2, 3, 1).reshape(B, H * W, self.svd_rank)
            # SVD via eigendecomposition of the k×k Gram matrix — cheaper than
            # a full SVD of the (H·W)×k matrix. Done in fp32 outside autocast.
            with torch.amp.autocast('cuda', enabled=False):
                with torch.no_grad():
                    gram = torch.bmm(h_flat.float().transpose(1, 2), h_flat.float())
                    evals, evecs = torch.linalg.eigh(gram)
                    evals = evals.flip(-1).clamp(min=1e-12)
                    S = evals.sqrt()
                    Vh = evecs.flip(-1).transpose(-2, -1)
                    S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                    Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
            features = self.extract_features(S, Vh)
            novelty = self.compute_novelty(S)
            return S, Vh, features, novelty

        @torch.no_grad()
        def update_ema(self, S, Vh):
            m = self.ema_momentum
            self.ema_s.mul_(m).add_(S.detach().mean(0), alpha=1 - m)
            self.ema_vh_flat.mul_(m).add_(Vh.detach().mean(0).reshape(-1), alpha=1 - m)

        @property
        def feature_dim(self):
            # s_norm (k) + vh_diag (k) + vh_offdiag (1) + s_ent (1)
            return 2 * self.svd_rank + 2

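The fallback observer's trick above is that for any matrix H, the singular values of H equal the square roots of the eigenvalues of the Gram matrix HᵀH, so an eigendecomposition of a small k×k matrix replaces a full SVD. A NumPy sanity check of that identity (a sketch; not part of the trainer):

```python
import numpy as np

# H stands in for one batch element's (H*W, svd_rank) feature matrix.
rng = np.random.default_rng(0)
H = rng.standard_normal((100, 8))

gram = H.T @ H
evals = np.linalg.eigvalsh(gram)                       # ascending eigenvalues
S_from_gram = np.sqrt(np.clip(evals[::-1], 0, None))   # descending, as in the observer

S_direct = np.linalg.svd(H, compute_uv=False)          # reference singular values

assert np.allclose(S_from_gram, S_direct, atol=1e-6)
```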
class ConvSVDPatchEmbedding(TorchComponent):
    """Input stage: conv frontend → SVDObserver → patch tokens."""

    def __init__(self, name, img_size=32, patch_size=4, in_channels=3,
                 conv_channels=64, d_model=256, svd_rank=16):
        super().__init__(name)
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.d_model = d_model
        self.svd_rank = svd_rank

        self.conv_frontend = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
            nn.Conv2d(conv_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
        )
        self.svd_observer = SVDObserver(conv_channels, svd_rank)
        self.patch_proj = nn.Conv2d(
            conv_channels, d_model, kernel_size=patch_size,
            stride=patch_size, bias=False)
        self.patch_norm = nn.LayerNorm(d_model)

        svd_feat_dim = self.svd_observer.feature_dim
        self.svd_to_gamma = nn.Linear(svd_feat_dim, d_model)
        self.svd_to_beta = nn.Linear(svd_feat_dim, d_model)
        # FiLM initialized at identity: gamma = 1, beta = 0
        nn.init.zeros_(self.svd_to_gamma.weight); nn.init.ones_(self.svd_to_gamma.bias)
        nn.init.zeros_(self.svd_to_beta.weight); nn.init.zeros_(self.svd_to_beta.bias)

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, d_model) * 0.02)

    def forward(self, x):
        B = x.shape[0]
        feat = self.conv_frontend(x)
        S, Vh, svd_features, novelty = self.svd_observer(feat)
        tokens = self.patch_proj(feat)
        tokens = tokens.flatten(2).transpose(1, 2)   # (B, n_patches, d_model)
        tokens = self.patch_norm(tokens)
        gamma = self.svd_to_gamma(svd_features).unsqueeze(1)
        beta = self.svd_to_beta(svd_features).unsqueeze(1)
        tokens = gamma * tokens + beta
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embed
        svd_state = {
            'singular_values': S, 'Vh': Vh,
            'svd_features': svd_features, 'novelty': novelty,
        }
        if self.training:
            self.svd_observer.update_ema(S, Vh)
        return tokens, svd_state

# ═══════════════════════════════════════════════════════════════════════════════
# CLASSIFIER (uses GeometricTransformer with CM gates)
# ═══════════════════════════════════════════════════════════════════════════════

class GeoViTClassifier(BaseTower):
    """Geometric Vision Transformer for classification.

    Wraps ConvSVDPatchEmbedding + GeometricTransformer + task head.
    Exposes geometric_losses() for regularization during training.
    """

    def __init__(self, name, config):
        super().__init__(name)
        self.config = config

        self.attach('patch_embed', ConvSVDPatchEmbedding(
            'patch_embed', img_size=config['img_size'],
            patch_size=config['patch_size'], in_channels=config['in_channels'],
            conv_channels=config['conv_channels'], d_model=config['d_model'],
            svd_rank=config['svd_rank'],
        ))
        self.attach('transformer', GeometricTransformer(
            'geo_cifar', d_model=config['d_model'], n_heads=config['n_heads'],
            n_layers=config['n_layers'], n_anchors=config['n_anchors'],
            manifold_dim=config['manifold_dim'], n_comp=config['n_comp'],
            d_comp=config['d_comp'], context_dim=config['context_dim'],
            quat_dim=config['quat_dim'], dropout=config['dropout'],
            cm_neighbors=config.get('cm_neighbors', 3),
            nce_bank_size=config.get('nce_bank_size', 4096),
            nce_temperature=config.get('nce_temperature', 0.1),
        ))
        self.attach('head', nn.Sequential(
            nn.LayerNorm(config['d_model']),
            nn.Linear(config['d_model'], config['d_model']),
            nn.GELU(), nn.Dropout(config['dropout']),
            nn.Linear(config['d_model'], config['num_classes']),
        ))

    def forward(self, x, return_geo_state=False):
        tokens, svd_state = self['patch_embed'](x)
        if return_geo_state:
            features, geo_states = self['transformer'](tokens, return_geo_state=True)
        else:
            features = self['transformer'](tokens)
        cls_out = features[:, 0]
        logits = self['head'](cls_out)
        if return_geo_state:
            return logits, geo_states, svd_state
        return logits

    def geometric_losses(self):
        """Delegate to transformer's built-in geometric regularization."""
        return self['transformer'].geometric_losses(
            cv_target=self.config.get('cv_target', 0.215),
            cv_weight=self.config.get('cv_weight', 0.1),
            spread_weight=self.config.get('spread_weight', 0.01),
        )

    def infonce_loss(self):
        """InfoNCE contrastive loss on the CLS token's geometric residual.
        Uses the cached residual from the last forward pass."""
        return self['transformer'].infonce_loss()

    def update_nce_bank(self):
        """Enqueue the current batch's residuals. Call AFTER backward."""
        self['transformer'].update_nce_bank()

    def anchor_diagnostics(self):
        """Delegate to transformer's anchor diagnostics."""
        return self['transformer'].anchor_diagnostics()

# ═══════════════════════════════════════════════════════════════════════════════
# GEOMETRIC ANALYSIS BATTERY (includes CM diagnostics)
# ═══════════════════════════════════════════════════════════════════════════════

@torch.no_grad()
def compute_cv(points):
    """Coefficient of variation on S^(d-1).

    CV = std(pairwise_cosine_distances) / mean(pairwise_cosine_distances)
    Pentachoron band: CV ∈ [0.20, 0.23].
    """
    points = F.normalize(points.float(), dim=-1)
    cos_sim = points @ points.T
    n = points.shape[0]
    idx = torch.triu_indices(n, n, offset=1, device=points.device)
    pairwise_dist = 1.0 - cos_sim[idx[0], idx[1]]
    mean_d = pairwise_dist.mean()
    std_d = pairwise_dist.std()
    cv = (std_d / (mean_d + 1e-8)).item()
    return cv, mean_d.item(), std_d.item()

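For intuition, compute_cv can be restated without torch: normalize the rows, take the upper-triangular cosine distances, and divide their standard deviation by their mean. A perfectly regular configuration (all pairwise distances equal, e.g. the vertices of a regular simplex) gives CV = 0, which is why small CV indicates a tight, regular arrangement. A dependency-free sketch:

```python
import numpy as np

def cv_cosine(points, eps=1e-8):
    """CV of pairwise cosine distances (rows are unit-normalized first).

    NumPy restatement of compute_cv; uses ddof=1 to match torch's default std.
    """
    p = points / (np.linalg.norm(points, axis=-1, keepdims=True) + eps)
    cos = p @ p.T
    iu = np.triu_indices(len(p), k=1)
    d = 1.0 - cos[iu]
    return d.std(ddof=1) / (d.mean() + eps)

# Three unit vectors 120° apart: every pairwise distance is 1.5, so CV = 0.
ang = np.array([0.0, 2 * np.pi / 3, 4 * np.pi / 3])
tri = np.stack([np.cos(ang), np.sin(ang)], axis=1)
assert cv_cosine(tri) < 1e-6
```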
@torch.no_grad()
def log_geometric_analysis(model, writer, epoch, test_loader, device, config):
    """Full geometric analysis battery with CM diagnostics."""
    model.eval()

    images, labels = next(iter(test_loader))
    images = images[:min(64, images.shape[0])].to(device)
    labels = labels[:min(64, labels.shape[0])].to(device)

    logits, geo_states, svd_state = model(images, return_geo_state=True)

    n_layers = len(geo_states)
    pred = logits.argmax(1)
    batch_acc = (pred == labels).float().mean().item()
    writer.add_scalar('analysis/batch_accuracy', batch_acc, epoch)

    # ─── SVD Input Stage ───
    S = svd_state['singular_values']
    s_norm = S / (S.sum(dim=-1, keepdim=True) + 1e-8)
    s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1)
    novelty = svd_state['novelty']

    writer.add_scalar('svd/entropy_mean', s_ent.mean().item(), epoch)
    writer.add_scalar('svd/entropy_std', s_ent.std().item(), epoch)
    writer.add_scalar('svd/novelty_norm', novelty.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar('svd/top1_ratio', (S[:, 0] / (S.sum(-1) + 1e-8)).mean().item(), epoch)
    writer.add_scalar('svd/condition_number',
                      (S[:, 0] / (S[:, -1].clamp(min=1e-8))).mean().item(), epoch)
    for k in range(min(S.shape[1], 5)):
        writer.add_scalar(f'svd/S_{k}', S[:, k].mean().item(), epoch)

    # SVD FiLM deviation
    pe = model['patch_embed']
    writer.add_scalar('svd_film/gamma_weight_norm', pe.svd_to_gamma.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/gamma_bias_dev_from_1',
                      (pe.svd_to_gamma.bias.data - 1.0).abs().mean().item(), epoch)
    writer.add_scalar('svd_film/beta_weight_norm', pe.svd_to_beta.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/beta_bias_norm', pe.svd_to_beta.bias.data.abs().mean().item(), epoch)

    # ─── Anchor Diagnostics (built-in) ───
    anchor_diag = model.anchor_diagnostics()
    for layer_name, d in anchor_diag.items():
        for k, v in d.items():
            writer.add_scalar(f'anchor_diag/{layer_name}_{k}', v, epoch)

    # ─── Per-Layer Geometric Analysis ───
    for i, gs in enumerate(geo_states):
        prefix = f'layer_{i}'

        # === CV — pentachoron band metric ===
        emb = gs['embedding']
        # Anchor CV
        transformer = model['transformer']
        layer = transformer[f'layer_{i}']
        anchors = F.normalize(
            layer['observer'].association.constellation.anchors, dim=-1)
        cv_anchors, mean_d_anchors, std_d_anchors = compute_cv(anchors)
        writer.add_scalar(f'{prefix}/cv_anchors', cv_anchors, epoch)
        writer.add_scalar(f'{prefix}/anchor_mean_dist', mean_d_anchors, epoch)
        writer.add_scalar(f'{prefix}/anchor_std_dist', std_d_anchors, epoch)

        # Embedding CV
        emb_flat = emb.reshape(-1, emb.shape[-1])
        n_sample = min(512, emb_flat.shape[0])
        idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
        cv_emb, mean_d_emb, std_d_emb = compute_cv(emb_flat[idx])
        writer.add_scalar(f'{prefix}/cv_embeddings', cv_emb, epoch)
        writer.add_scalar(f'{prefix}/embedding_mean_dist', mean_d_emb, epoch)

        # === CM Gate Diagnostics ===
        gate_info = gs.get('gate_info', {})
        gate_values = gs.get('gate_values')
        cm_quality = gs.get('cm_quality')

        if gate_info:
            writer.add_scalar(f'{prefix}/cm_active_anchors',
                              gate_info.get('active', 0), epoch)
            writer.add_scalar(f'{prefix}/cm_gate_mean',
                              gate_info.get('gate_mean', 0), epoch)
            writer.add_scalar(f'{prefix}/cm_positive_frac',
                              gate_info.get('cm_positive_frac', 0), epoch)

        if gate_values is not None:
            gv = gate_values
            writer.add_scalar(f'{prefix}/gate_values_min', gv.min().item(), epoch)
            writer.add_scalar(f'{prefix}/gate_values_max', gv.max().item(), epoch)
            writer.add_scalar(f'{prefix}/gate_values_std', gv.std().item(), epoch)
            # Per-anchor gate mean (which anchors are consistently open/closed)
            gv_per_anchor = gv.mean(dim=0).mean(dim=0)  # average over B and L
            writer.add_scalar(f'{prefix}/gate_anchor_spread',
                              gv_per_anchor.std().item(), epoch)
            # Fraction of positions with >50% anchors open
            if gv.dim() == 3:
                pos_open_frac = (gv.mean(dim=-1) > 0.5).float().mean().item()
            else:
                pos_open_frac = (gv > 0.5).float().mean().item()
            writer.add_scalar(f'{prefix}/gate_positions_open_frac', pos_open_frac, epoch)

        if cm_quality is not None:
            writer.add_scalar(f'{prefix}/cm_quality_mean', cm_quality.mean().item(), epoch)
            writer.add_scalar(f'{prefix}/cm_quality_std', cm_quality.std().item(), epoch)
            writer.add_scalar(f'{prefix}/cm_quality_min', cm_quality.min().item(), epoch)

        # === Stream Agreement ===
        content = gs['content']
        geometric = gs['geometric']
        agreement = F.cosine_similarity(
            content.reshape(-1, content.shape[-1]),
            geometric.reshape(-1, geometric.shape[-1]), dim=-1)
        writer.add_scalar(f'{prefix}/stream_agreement_mean', agreement.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/stream_agreement_std', agreement.std().item(), epoch)

        writer.add_scalar(f'{prefix}/content_norm', content.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/geometric_norm', geometric.norm(dim=-1).mean().item(), epoch)

        # === Disagreement arm analysis ===
        disagree = content - geometric
        agree = content * geometric
        writer.add_scalar(f'{prefix}/disagree_norm', disagree.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/agree_norm', agree.norm(dim=-1).mean().item(), epoch)

        # === Anchor Utilization ===
        tri = gs['triangulation']
        assignment = gs['assignment']
        nearest = gs['nearest']
        n_anchors = tri.shape[-1]

        nearest_flat = nearest.reshape(-1)
        counts = torch.bincount(nearest_flat, minlength=n_anchors).float()
        total_assignments = counts.sum()

        probs = counts / (total_assignments + 1e-8)
        anchor_entropy = -(probs * torch.log(probs.clamp(min=1e-8))).sum().item()
        max_entropy = math.log(n_anchors)
        writer.add_scalar(f'{prefix}/anchor_entropy_normalized',
                          anchor_entropy / (max_entropy + 1e-8), epoch)
        active = (counts > 0).sum().item()
        writer.add_scalar(f'{prefix}/anchors_active', active, epoch)
        writer.add_scalar(f'{prefix}/anchors_active_frac', active / n_anchors, epoch)
        dead = (counts == 0).sum().item()
        writer.add_scalar(f'{prefix}/anchors_dead', dead, epoch)

        # === Triangulation Statistics ===
        writer.add_scalar(f'{prefix}/tri_mean', tri.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/tri_std', tri.std().item(), epoch)

        # === Soft Assignment Statistics ===
        assign_ent = -(assignment * torch.log(assignment.clamp(min=1e-8))).sum(-1)
        writer.add_scalar(f'{prefix}/assignment_entropy_mean', assign_ent.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/assignment_max_prob',
                          assignment.max(dim=-1).values.mean().item(), epoch)

        # === Patchwork Statistics (now from CM-validated triangulation) ===
        pw = gs['patchwork']
        writer.add_scalar(f'{prefix}/patchwork_norm', pw.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/patchwork_std', pw.std().item(), epoch)
        pw_sparsity = (pw.abs() < 0.01).float().mean().item()
        writer.add_scalar(f'{prefix}/patchwork_sparsity', pw_sparsity, epoch)

        # === Bridge Consistency ===
        bridge = gs['bridge']
        bridge_soft = F.softmax(bridge, dim=-1)
        bridge_assign_kl = F.kl_div(
            bridge_soft.log().reshape(-1, n_anchors),
            assignment.reshape(-1, n_anchors),
            reduction='batchmean', log_target=False)
        writer.add_scalar(f'{prefix}/bridge_assignment_kl', bridge_assign_kl.item(), epoch)

        # === Quaternion Composition ===
        composed = gs['composed']
        writer.add_scalar(f'{prefix}/composed_norm', composed.norm(dim=-1).mean().item(), epoch)

        # === Geo Context ===
        geo_ctx = gs['geo_ctx']
        writer.add_scalar(f'{prefix}/geo_ctx_norm', geo_ctx.norm(dim=-1).mean().item(), epoch)

        # === Geometric Residual Stream (CM-conditioned) ===
        geo_res = gs.get('geo_residual')
        if geo_res is not None:
            res_norms = geo_res.norm(dim=-1)
            writer.add_scalar(f'{prefix}/geo_res_norm', res_norms.mean().item(), epoch)
            writer.add_scalar(f'{prefix}/geo_res_std', geo_res.std().item(), epoch)
            writer.add_scalar(f'{prefix}/geo_res_sparsity',
                              (geo_res.abs() < 0.01).float().mean().item(), epoch)
            # Cross-position consistency
            geo_res_flat = geo_res.reshape(-1, geo_res.shape[-1])
            n_s = min(256, geo_res_flat.shape[0])
            idx_s = torch.randperm(geo_res_flat.shape[0], device=geo_res.device)[:n_s]
            sampled = F.normalize(geo_res_flat[idx_s], dim=-1)
            cos_mat = sampled @ sampled.T
            triu = torch.triu_indices(n_s, n_s, offset=1, device=geo_res.device)
            writer.add_scalar(f'{prefix}/geo_res_consistency',
                              cos_mat[triu[0], triu[1]].mean().item(), epoch)

    # ─── Cayley Rotation Analysis ───
    for name, mod in model.named_modules():
        if isinstance(mod, CayleyOrthogonal):
            R = mod.get_rotation()
            I = torch.eye(R.shape[0], device=R.device)
            r_dist = (R - I).norm().item()
            clean_name = name.replace('.', '_')
            writer.add_scalar(f'cayley/{clean_name}_R_minus_I', r_dist, epoch)

    # ─── FiLM Layer Analysis ───
    film_idx = 0
    for name, mod in model.named_modules():
        if isinstance(mod, FiLMLayer):
            g_b = mod.to_gamma.bias.data
            b_b = mod.to_beta.bias.data
            writer.add_scalar(f'film/{film_idx}_gamma_dev',
                              (g_b - 1.0).abs().mean().item(), epoch)
            writer.add_scalar(f'film/{film_idx}_beta_dev',
                              b_b.abs().mean().item(), epoch)
            film_idx += 1

    # ─── Cross-Layer Trajectories ───
    cv_trajectory = []
    cm_quality_trajectory = []
    res_norms = []
    bridge_kls = []

    for i, gs in enumerate(geo_states):
        # CV
        emb = gs['embedding']
        emb_flat = emb.reshape(-1, emb.shape[-1])
        n_sample = min(512, emb_flat.shape[0])
        idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
        cv, _, _ = compute_cv(emb_flat[idx])
        cv_trajectory.append(cv)

        # CM quality
        cm_q = gs.get('cm_quality')
        if cm_q is not None:
            cm_quality_trajectory.append(cm_q.mean().item())

        # Geo residual norms
        geo_res = gs.get('geo_residual')
        if geo_res is not None:
            res_norms.append(geo_res.norm(dim=-1).mean().item())

        # Bridge KL
        n_anchors = gs['assignment'].shape[-1]
        bridge_soft = F.softmax(gs['bridge'], dim=-1)
        bkl = F.kl_div(
            bridge_soft.log().reshape(-1, n_anchors),
            gs['assignment'].reshape(-1, n_anchors),
            reduction='batchmean', log_target=False).item()
        bridge_kls.append(bkl)

    # CV trajectory
    writer.add_scalar('cv/trajectory_mean', np.mean(cv_trajectory), epoch)
    writer.add_scalar('cv/trajectory_std', np.std(cv_trajectory), epoch)
    in_band = sum(1 for cv in cv_trajectory if 0.20 <= cv <= 0.23)
    writer.add_scalar('cv/layers_in_pentachoron_band', in_band, epoch)
    writer.add_scalar('cv/layers_in_band_frac', in_band / len(cv_trajectory), epoch)

    # CM quality trajectory
    if cm_quality_trajectory:
        writer.add_scalar('cm/quality_trajectory_mean',
                          np.mean(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_trajectory_std',
                          np.std(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_min_layer',
                          np.min(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_max_layer',
                          np.max(cm_quality_trajectory), epoch)

    # Geometric residual trajectory
    if res_norms:
        writer.add_scalar('geo_res/trajectory_start', res_norms[0], epoch)
        writer.add_scalar('geo_res/trajectory_end', res_norms[-1], epoch)
        writer.add_scalar('geo_res/accumulation_ratio',
                          res_norms[-1] / (res_norms[0] + 1e-8), epoch)
        growth = [res_norms[j + 1] - res_norms[j] for j in range(len(res_norms) - 1)]
        writer.add_scalar('geo_res/growth_mean', np.mean(growth), epoch)
        writer.add_scalar('geo_res/growth_std', np.std(growth), epoch)

    # Cooperation analysis (includes CM quality)
    if len(res_norms) >= 4:
        cv_corr = float(np.corrcoef(res_norms, cv_trajectory)[0, 1])
        bkl_corr = float(np.corrcoef(res_norms, bridge_kls)[0, 1])
        writer.add_scalar('cooperation/geo_res_vs_cv', cv_corr, epoch)
        writer.add_scalar('cooperation/geo_res_vs_bridge_kl', bkl_corr, epoch)

        if len(cm_quality_trajectory) == len(res_norms):
            cm_corr = float(np.corrcoef(
                res_norms, cm_quality_trajectory)[0, 1])
            writer.add_scalar('cooperation/geo_res_vs_cm_quality', cm_corr, epoch)
            # CM vs CV: do layers with better CM quality also have better CV?
            cm_cv_corr = float(np.corrcoef(
                cm_quality_trajectory, cv_trajectory)[0, 1])
            writer.add_scalar('cooperation/cm_quality_vs_cv', cm_cv_corr, epoch)

    return {
        'batch_acc': batch_acc,
        'cv_trajectory': cv_trajectory,
        'cm_quality_trajectory': cm_quality_trajectory,
        'res_norms': res_norms,
        'bridge_kls': bridge_kls,
    }

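The anchor-utilization block in log_geometric_analysis normalizes the usage entropy by log(n_anchors), so perfectly uniform anchor usage scores 1.0 and total collapse onto one anchor scores 0. A dependency-free check of that normalization (a sketch; `normalized_anchor_entropy` is an illustrative name, not a function from the trainer):

```python
import numpy as np

def normalized_anchor_entropy(counts, eps=1e-8):
    """Entropy of anchor usage counts, scaled so uniform usage = 1.0."""
    counts = np.asarray(counts, dtype=float)
    probs = counts / (counts.sum() + eps)
    ent = -(probs * np.log(np.clip(probs, eps, None))).sum()
    return ent / np.log(len(counts))

# Uniform usage across 4 anchors → normalized entropy 1.0.
assert abs(normalized_anchor_entropy([10, 10, 10, 10]) - 1.0) < 1e-6
# All assignments on one anchor → normalized entropy ~0 (dead anchors).
assert normalized_anchor_entropy([40, 0, 0, 0]) < 1e-6
```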
@torch.no_grad()
def log_gradient_norms(model, writer, epoch):
    """Log gradient norms per component type (includes cm_gate)."""
    type_grads = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm().item()
            if 'projection' in name and 'proj' in name:
                key = 'manifold_proj'
            elif 'cm_gate' in name:
                key = 'cm_gate'
            elif 'observer' in name or 'constellation' in name or 'anchor' in name:
                key = 'constellation'
            elif 'context' in name:
                key = 'geo_context'
            elif 'content' in name:
                key = 'content_attn'
            elif 'geometric' in name and 'film' not in name:
                key = 'geo_attn'
            elif 'film' in name:
                key = 'film'
            elif 'rotation' in name or 'cayley' in name or 'A_upper' in name:
                key = 'cayley'
            elif 'compose' in name or 'quat' in name or 'proj_w' in name:
                key = 'quaternion'
            elif 'decode' in name:
                key = 'decode'
            elif 'gate' in name:
                key = 'gate'
            elif 'conv' in name or 'patch' in name:
                key = 'input_stage'
            elif 'head' in name:
                key = 'head'
            elif 'svd' in name:
                key = 'svd'
            elif 'geo_proj' in name:
                key = 'geo_residual_proj'
            else:
                key = 'other'

            type_grads.setdefault(key, []).append(grad_norm)

    for key, norms in type_grads.items():
        writer.add_scalar(f'grad_norm/{key}_mean', np.mean(norms), epoch)
        writer.add_scalar(f'grad_norm/{key}_max', np.max(norms), epoch)

    total = sum(p.grad.norm().item() ** 2
                for p in model.parameters() if p.grad is not None) ** 0.5
    writer.add_scalar('grad_norm/total', total, epoch)


@torch.no_grad()
def log_weight_norms(model, writer, epoch):
    """Log weight norms per component type."""
    for name, param in model.named_parameters():
        if 'A_upper' in name:
            clean = name.replace('.', '_')
            writer.add_scalar(f'weights/{clean}_norm', param.norm().item(), epoch)

# ═══════════════════════════════════════════════════════════════════════════════
# DATA
# ═══════════════════════════════════════════════════════════════════════════════

def get_dataloaders(config):
    import torchvision
    import torchvision.transforms as T

    # Augmentation pipeline tuned for the geometric transformer:
    # TrivialAugmentWide: continuous severity spectrum of geometric + photometric
    #   transforms. Exercises the CM gate across the full quality range — mild
    #   distortion keeps CM high; severe distortion creates partially-degenerate
    #   simplices.
    # RandomErasing: creates degenerate manifold projections (zero-volume CM
    #   simplices), training the CM gate to close on corrupted regions.
    # CutMix is applied at batch level in train_epoch (not here).
    train_transform = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.TrivialAugmentWide(),
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        T.RandomErasing(p=config.get('random_erasing_p', 0.25),
                        scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    train_ds = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=train_transform)
    test_ds = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=test_transform)

    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=config['batch_size'], shuffle=True,
        num_workers=config['num_workers'], pin_memory=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=config['batch_size'], shuffle=False,
        num_workers=config['num_workers'], pin_memory=True)

    return train_loader, test_loader

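`T.Normalize` is a per-channel affine map, `(x - mean) / std`, so it is lossless and trivially invertible. A self-contained sketch with the statistics used above (the random tensor is an illustrative stand-in for a `ToTensor()` output):

```python
import torch

mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
std = torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1)

x = torch.rand(3, 32, 32)       # stand-in for a ToTensor() image in [0, 1]
normed = (x - mean) / std       # what T.Normalize does, channel by channel
restored = normed * std + mean  # invert the affine map

assert torch.allclose(restored, x, atol=1e-6)
```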

# ═══════════════════════════════════════════════════════════════════════════════
# CUTMIX — batch-level augmentation for CM gate boundary training
# ═══════════════════════════════════════════════════════════════════════════════

def cutmix_batch(images, labels, alpha=1.0):
    """Apply CutMix to a batch. Returns mixed images + label pairs + lambda.

    CutMix replaces a rectangular region of image A with the corresponding
    region of image B. Positions inside each region have coherent geometry —
    valid CM simplices. The boundary between regions has mixed geometric
    context — the CM gate should learn to suppress these positions.

    Args:
        images: (B, C, H, W) batch
        labels: (B,) integer labels
        alpha: Beta distribution parameter (1.0 = uniform box sizes)

    Returns:
        images: (B, C, H, W) mixed batch (modified in-place)
        labels_a: (B,) labels for region A
        labels_b: (B,) labels for region B
        lam: float, fraction of image A remaining after border clipping
    """
    lam = np.random.beta(alpha, alpha)
    B = images.size(0)
    idx = torch.randperm(B, device=images.device)

    H, W = images.shape[2], images.shape[3]
    cut_ratio = (1.0 - lam) ** 0.5
    cw = int(W * cut_ratio)
    ch = int(H * cut_ratio)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    x1 = max(cx - cw // 2, 0); x2 = min(cx + cw // 2, W)
    y1 = max(cy - ch // 2, 0); y2 = min(cy + ch // 2, H)

    images[:, :, y1:y2, x1:x2] = images[idx, :, y1:y2, x1:x2]
    # Recompute lambda from the clipped box: the sampled lam is only exact
    # when the box fits entirely inside the image.
    lam_actual = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
    return images, labels, labels[idx], lam_actual

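Note that `cutmix_batch` returns a recomputed `lam_actual` rather than the sampled `lam`: when the box center falls near a border the box gets clipped, so the true surviving fraction of image A differs from the Beta draw. A standalone numpy sketch of that correction (the 32x32 size is illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
H = W = 32
lam = rng.beta(1.0, 1.0)                  # sampled mixing fraction
cut_ratio = (1.0 - lam) ** 0.5            # box side as a fraction of the image
cw, ch = int(W * cut_ratio), int(H * cut_ratio)
cx, cy = rng.integers(W), rng.integers(H)

# Clip the box to the image bounds; this is what makes the sampled lam inexact.
x1, x2 = max(cx - cw // 2, 0), min(cx + cw // 2, W)
y1, y2 = max(cy - ch // 2, 0), min(cy + ch // 2, H)

lam_actual = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
assert 0.0 <= lam_actual <= 1.0
# Clipping can only shrink the box, so lam_actual >= the unclipped estimate.
assert lam_actual >= 1.0 - cw * ch / (W * H) - 1e-9
```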

# ═══════════════════════════════════════════════════════════════════════════════
# TRAINING (geometric losses + CutMix integrated)
# ═══════════════════════════════════════════════════════════════════════════════

def train_epoch(model, loader, optimizer, scheduler, epoch, config, writer):
    model.train()
    total_loss = 0
    total_geo_loss = 0
    total_nce_loss = 0
    correct = 0
    total = 0

    cutmix_alpha = config.get('cutmix_alpha', 1.0)
    cutmix_prob = config.get('cutmix_prob', 0.5)
    label_smoothing = config.get('label_smoothing', 0.1)
    nce_weight = config.get('nce_weight', 0.1)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # CutMix: applied probabilistically per batch
        use_cutmix = np.random.rand() < cutmix_prob
        if use_cutmix:
            images, labels_a, labels_b, lam = cutmix_batch(
                images, labels, alpha=cutmix_alpha)
            logits = model(images)
            ce_loss = lam * criterion(logits, labels_a) + \
                      (1.0 - lam) * criterion(logits, labels_b)
            # Soft accuracy: a hit on either label counts, weighted by its
            # mixing fraction
            pred = logits.argmax(1)
            correct += (lam * (pred == labels_a).float() +
                        (1.0 - lam) * (pred == labels_b).float()).sum().item()
        else:
            logits = model(images)
            ce_loss = criterion(logits, labels)
            correct += (logits.argmax(1) == labels).sum().item()

        # Geometric regularization — CV target + anchor spread
        geo_losses = model.geometric_losses()
        geo_loss = geo_losses.get('geo_total', torch.tensor(0.0, device=device))

        # InfoNCE on the geometric residual — discriminative pressure
        nce_losses = model.infonce_loss()
        nce_loss = nce_losses.get('nce', torch.tensor(0.0, device=device))

        loss = ce_loss + geo_loss + nce_weight * nce_loss

        optimizer.zero_grad()
        loss.backward()

        # Enqueue AFTER backward — detached residuals go into the bank
        model.update_nce_bank()

        # Log gradient norms periodically (before clipping)
        if epoch % config['log_grads_every'] == 0 and batch_idx == 0:
            log_gradient_norms(model, writer, epoch)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += ce_loss.item() * images.size(0)
        total_geo_loss += geo_loss.item() * images.size(0)
        total_nce_loss += nce_loss.item() * images.size(0)
        total += images.size(0)

    avg_ce = total_loss / total
    avg_geo = total_geo_loss / total
    avg_nce = total_nce_loss / total
    return avg_ce, avg_geo, avg_nce, correct / total


@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        correct += (logits.argmax(1) == labels).sum().item()
        total += images.size(0)
    return correct / total


def main():
    config = CONFIG.copy()

    print("=" * 60)
    print(" Geometric Transformer — CIFAR-100 (CM-Validated)")
    print(f" Input: conv({config['in_channels']}→{config['conv_channels']}) + "
          f"SVD(rank={config['svd_rank']}) + "
          f"{config['patch_size']}×{config['patch_size']} patches = "
          f"{(config['img_size']//config['patch_size'])**2} tokens + CLS")
    print(f" Model: d={config['d_model']}, heads={config['n_heads']}, "
          f"layers={config['n_layers']}, anchors={config['n_anchors']}")
    print(f" CM: neighbors={config['cm_neighbors']}, "
          f"cv_target={config['cv_target']}, "
          f"cv_weight={config['cv_weight']}, "
          f"spread_weight={config['spread_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(α={config['cutmix_alpha']}, "
          f"p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" NCE: bank={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, "
          f"weight={config['nce_weight']}")
    print("=" * 60)

    writer = SummaryWriter(config['log_dir'])
    writer.add_text('config', json.dumps(config, indent=2))

    print("\nLoading CIFAR-100...")
    train_loader, test_loader = get_dataloaders(config)
    print(f" Train: {len(train_loader.dataset):,} | Test: {len(test_loader.dataset):,}")

    model = GeoViTClassifier('geo_vit_cifar100', config)
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)

    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n Total params: {n_params:,}")
    print(f" Trainable params: {n_trainable:,}")

    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        if n > 0:
            print(f"   {name:<20s}: {n:,}")

    writer.add_scalar('model/total_params', n_params, 0)

    # Initial anchor diagnostics
    print("\n Initial anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f"   {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"min_dist={d['min_pairwise_dist']:.4f}")

    # Optimizer + scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = config['warmup_epochs'] * len(train_loader)

    def lr_lambda(step):
        # Linear warmup, then cosine decay. The max(1, ...) guards avoid a
        # zero division when warmup_epochs is 0 or equals epochs.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

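The warmup-then-cosine shape of `lr_lambda` can be checked in isolation; the step counts below are illustrative stand-ins for the trainer's real `total_steps` and `warmup_steps`:

```python
import numpy as np

warmup_steps, total_steps = 100, 1000    # illustrative, not the trainer's values

def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps       # linear warmup from 0 to 1
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + np.cos(np.pi * progress))  # cosine decay from 1 to 0

assert lr_lambda(0) == 0.0                          # cold start
assert abs(lr_lambda(warmup_steps) - 1.0) < 1e-12   # peak right after warmup
assert abs(lr_lambda(total_steps)) < 1e-12          # fully decayed
mid = lr_lambda((warmup_steps + total_steps) // 2)
assert abs(mid - 0.5) < 1e-12                       # cosine midpoint
```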

    print(f"\n{'━'*60}")
    print(f" Training for {config['epochs']} epochs")
    print(f" Warmup: {config['warmup_epochs']} epochs, "
          f"LR: {config['lr']}, WD: {config['weight_decay']}")
    print(f" Geo reg: cv_w={config['cv_weight']}, spread_w={config['spread_weight']}")
    print(f" NCE bank: size={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, weight={config['nce_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" TensorBoard: {config['log_dir']}")
    print(f" Geo analysis every {config['log_geo_every']} epochs")
    print(f"{'━'*60}\n")

    best_acc = 0
    save_dir = Path('geo_cifar100')
    save_dir.mkdir(exist_ok=True)

    for epoch in tqdm(range(config['epochs']), desc="Epochs"):
        t0 = time.time()

        ce_loss, geo_loss, nce_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scheduler, epoch, config, writer)

        test_acc = evaluate(model, test_loader)
        elapsed = time.time() - t0

        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/ce_loss', ce_loss, epoch)
        writer.add_scalar('train/geo_loss', geo_loss, epoch)
        writer.add_scalar('train/nce_loss', nce_loss, epoch)
        # Match the optimization objective: NCE enters the total with its weight.
        writer.add_scalar('train/total_loss',
                          ce_loss + geo_loss + config['nce_weight'] * nce_loss,
                          epoch)
        writer.add_scalar('train/accuracy', train_acc, epoch)
        writer.add_scalar('test/accuracy', test_acc, epoch)
        writer.add_scalar('train/lr', lr, epoch)
        writer.add_scalar('train/epoch_time', elapsed, epoch)
        writer.add_scalar('gap/train_test', train_acc - test_acc, epoch)

        log_weight_norms(model, writer, epoch)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                'epoch': epoch,
                'test_acc': test_acc,
                'config': config,
            }, save_dir / 'best.pt')

        # Full geometric analysis periodically
        if epoch % config['log_geo_every'] == 0 or epoch == config['epochs'] - 1:
            geo_info = log_geometric_analysis(
                model, writer, epoch, test_loader, device, config)

            cv_str = ', '.join(f'{cv:.3f}' for cv in geo_info['cv_trajectory'])
            cm_str = ', '.join(f'{q:.3f}' for q in geo_info.get('cm_quality_trajectory', []))
            res_str = ', '.join(f'{r:.3f}' for r in geo_info.get('res_norms', []))
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s"
                f"\n      CV=[{cv_str}]"
                f"\n      CM=[{cm_str}]"
                f"\n      GR=[{res_str}]")
        elif epoch % 5 == 0:
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s")

    # Final summary
    print(f"\n{'═'*60}")
    print(" CIFAR-100 RESULTS (CM-Validated)")
    print(f"{'═'*60}")
    print(f" Best test accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
    print(f" Parameters: {n_params:,}")
    print(f" Checkpoint: {save_dir}/best.pt")
    print(f" TensorBoard: {config['log_dir']}")

    # Final geometric state + anchor diagnostics
    print("\n Final geometric state:")
    log_geometric_analysis(
        model, writer, config['epochs'], test_loader, device, config)

    print("\n Final anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f"   {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"cm_mean={d['cm_mean']:.4f}")

    writer.close()
    print("\nDone.")


if __name__ == '__main__':
    main()