""" Geometric Transformer — CIFAR-100 Training with CM-Validated Analysis Changes from previous version: - CM gate diagnostics per layer: active anchors, gate_mean, cm_positive_frac - CM quality in geometric residual analysis (replaces blind gate) - Geometric regularization losses (CV target + anchor spread) in training loop - Anchor diagnostics via model.anchor_diagnostics() - CM quality trajectory alongside CV and bridge KL for cooperation analysis TensorBoard logging of every geometric feature element: - CV (coefficient of variation) per layer — the pentachoron band metric - CM gate: active anchors, gate mean, cm_positive_frac, quality per position - Stream agreement/divergence per layer - Anchor utilization, entropy, spread - Patchwork activation statistics (from CM-validated triangulation) - Bridge vs assignment consistency - Triangulation distance distributions - SVD spectrum, entropy, novelty - Quaternion arm norms and composition statistics - Cayley rotation ‖R-I‖ per layer - FiLM gamma/beta deviation from identity - Gate activation statistics - Gradient norms per component type (including cm_gate) - Weight norms per component type - Geometric regularization: CV loss, spread loss per epoch !pip install geolip-core torchvision tqdm tensorboard """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import time, json, math from pathlib import Path from tqdm.auto import tqdm from torch.utils.tensorboard import SummaryWriter device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") if device.type == 'cuda': print(f" GPU: {torch.cuda.get_device_name()}") print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") # ═══════════════════════════════════════════════════════════════════════════════ # IMPORT TRANSFORMER # ═══════════════════════════════════════════════════════════════════════════════ # Try geolip_core installed package first, fall back to local file try: from geolip_core.pipeline.components.geometric_transformer import ( GeometricTransformer, GeometricTransformerLayer, CayleyOrthogonal, QuaternionCompose, FiLMLayer, ContentAttention, GeometricAttention, CMValidatedGate, TorchComponent, BaseTower, anchor_neighborhood_cm, ) print(" Imported from geolip_core (installed)") except ImportError: try: from geometric_transformer import ( GeometricTransformer, GeometricTransformerLayer, CayleyOrthogonal, QuaternionCompose, FiLMLayer, ContentAttention, GeometricAttention, CMValidatedGate, TorchComponent, BaseTower, anchor_neighborhood_cm, ) print(" Imported from local geometric_transformer.py") except ImportError: raise ImportError( "Cannot find geometric_transformer. Place geometric_transformer.py " "in the working directory or install geolip-core.") torch.set_float32_matmul_precision('high') # ═══════════════════════════════════════════════════════════════════════════════ # CONFIG # ═══════════════════════════════════════════════════════════════════════════════ CONFIG = { # Model 'd_model': 256, 'n_heads': 8, 'n_layers': 8, 'n_anchors': 128, 'manifold_dim': 128, 'n_comp': 4, 'd_comp': 16, 'context_dim': 64, 'quat_dim': 32, 'dropout': 0.1, 'cm_neighbors': 3, # CM simplex neighbors # Input stage 'patch_size': 4, 'img_size': 32, 'in_channels': 3, 'conv_channels': 64, 'svd_rank': 16, # Training 'epochs': 100, 'batch_size': 1024, 'lr': 1e-3, 'weight_decay': 0.05, 'warmup_epochs': 5, 'label_smoothing': 0.1, 'num_workers': 8, # Geometric regularization 'cv_target': 0.215, # pentachoron band center 'cv_weight': 0.1, # CV loss weight 'spread_weight': 0.01, # anchor spread loss weight # Augmentation — tuned for CM gate training 'cutmix_alpha': 1.0, # CutMix beta distribution α (1.0 = uniform box sizes) 'cutmix_prob': 0.5, # probability of applying CutMix per batch 'random_erasing_p': 0.25, # probability of erasing per image # InfoNCE memory bank on geometric residual 'nce_bank_size': 4096, # queue size (0 to disable) 'nce_temperature': 0.1, # InfoNCE temperature 'nce_weight': 0.1, # loss weight # Data 'num_classes': 100, # Logging 'log_geo_every': 5, # full geometric analysis every N epochs 'log_grads_every': 10, # gradient norms every N epochs 'log_dir': 'runs/geo_cifar100', } # ═══════════════════════════════════════════════════════════════════════════════ # INPUT STAGE # ═══════════════════════════════════════════════════════════════════════════════ try: from geolip_core.core.input.svd import SVDObserver _HAS_SVD = True except ImportError: _HAS_SVD = False class SVDObserver(nn.Module): """Fallback SVDObserver.""" def __init__(self, in_channels, svd_rank=24): super().__init__() self.svd_rank = svd_rank self.to_svd = nn.Conv2d(in_channels, svd_rank, 1, bias=False) self.register_buffer('ema_s', torch.ones(svd_rank)) self.register_buffer('ema_vh_flat', torch.eye(svd_rank).reshape(-1)) self.ema_momentum = 0.99 def extract_features(self, S, Vh): B, k = S.shape S_safe = S.clamp(min=1e-6) s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8) vh_diag = Vh.diagonal(dim1=-2, dim2=-1) vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0) s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True) out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1) return torch.where(torch.isfinite(out), out, torch.zeros_like(out)) def compute_novelty(self, S): return S - self.ema_s.clone().unsqueeze(0) def forward(self, x): B, C, H, W = x.shape h = self.to_svd(x) h_flat = h.permute(0, 2, 3, 1).reshape(B, H * W, self.svd_rank) with torch.amp.autocast('cuda', enabled=False): with torch.no_grad(): gram = torch.bmm(h_flat.float().transpose(1, 2), h_flat.float()) evals, evecs = torch.linalg.eigh(gram) evals = evals.flip(-1).clamp(min=1e-12) S = evals.sqrt() Vh = evecs.flip(-1).transpose(-2, -1) S = torch.where(torch.isfinite(S), S, torch.ones_like(S)) Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh)) features = self.extract_features(S, Vh) novelty = self.compute_novelty(S) return S, Vh, features, novelty @torch.no_grad() def update_ema(self, S, Vh): m = self.ema_momentum self.ema_s.mul_(m).add_(S.detach().mean(0), alpha=1-m) self.ema_vh_flat.mul_(m).add_(Vh.detach().mean(0).reshape(-1), alpha=1-m) @property def feature_dim(self): return 2 * self.svd_rank + 2 class ConvSVDPatchEmbedding(TorchComponent): """Input stage: conv frontend → SVDObserver → patch tokens.""" def __init__(self, name, img_size=32, patch_size=4, in_channels=3, conv_channels=64, d_model=256, svd_rank=16): super().__init__(name) self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 self.d_model = d_model self.svd_rank = svd_rank self.conv_frontend = nn.Sequential( nn.Conv2d(in_channels, conv_channels, 3, padding=1, bias=False), nn.BatchNorm2d(conv_channels), nn.GELU(), nn.Conv2d(conv_channels, conv_channels, 3, padding=1, bias=False), nn.BatchNorm2d(conv_channels), nn.GELU(), ) self.svd_observer = SVDObserver(conv_channels, svd_rank) self.patch_proj = nn.Conv2d( conv_channels, d_model, kernel_size=patch_size, stride=patch_size, bias=False) self.patch_norm = nn.LayerNorm(d_model) svd_feat_dim = self.svd_observer.feature_dim self.svd_to_gamma = nn.Linear(svd_feat_dim, d_model) self.svd_to_beta = nn.Linear(svd_feat_dim, d_model) nn.init.zeros_(self.svd_to_gamma.weight); nn.init.ones_(self.svd_to_gamma.bias) nn.init.zeros_(self.svd_to_beta.weight); nn.init.zeros_(self.svd_to_beta.bias) self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02) self.pos_embed = nn.Parameter( torch.randn(1, self.n_patches + 1, d_model) * 0.02) def forward(self, x): B = x.shape[0] feat = self.conv_frontend(x) S, Vh, svd_features, novelty = self.svd_observer(feat) tokens = self.patch_proj(feat) tokens = tokens.flatten(2).transpose(1, 2) tokens = self.patch_norm(tokens) gamma = self.svd_to_gamma(svd_features).unsqueeze(1) beta = self.svd_to_beta(svd_features).unsqueeze(1) tokens = gamma * tokens + beta cls = self.cls_token.expand(B, -1, -1) tokens = torch.cat([cls, tokens], dim=1) tokens = tokens + self.pos_embed svd_state = { 'singular_values': S, 'Vh': Vh, 'svd_features': svd_features, 'novelty': novelty, } if self.training: self.svd_observer.update_ema(S, Vh) return tokens, svd_state # ═══════════════════════════════════════════════════════════════════════════════ # CLASSIFIER ( uses GeometricTransformer with CM gates) # ═══════════════════════════════════════════════════════════════════════════════ class GeoViTClassifier(BaseTower): """Geometric Vision Transformer for classification. Wraps ConvSVDPatchEmbedding + GeometricTransformer + task head. Exposes geometric_losses() for regularization during training. """ def __init__(self, name, config): super().__init__(name) self.config = config self.attach('patch_embed', ConvSVDPatchEmbedding( 'patch_embed', img_size=config['img_size'], patch_size=config['patch_size'], in_channels=config['in_channels'], conv_channels=config['conv_channels'], d_model=config['d_model'], svd_rank=config['svd_rank'], )) self.attach('transformer', GeometricTransformer( 'geo_cifar', d_model=config['d_model'], n_heads=config['n_heads'], n_layers=config['n_layers'], n_anchors=config['n_anchors'], manifold_dim=config['manifold_dim'], n_comp=config['n_comp'], d_comp=config['d_comp'], context_dim=config['context_dim'], quat_dim=config['quat_dim'], dropout=config['dropout'], cm_neighbors=config.get('cm_neighbors', 3), nce_bank_size=config.get('nce_bank_size', 4096), nce_temperature=config.get('nce_temperature', 0.1), )) self.attach('head', nn.Sequential( nn.LayerNorm(config['d_model']), nn.Linear(config['d_model'], config['d_model']), nn.GELU(), nn.Dropout(config['dropout']), nn.Linear(config['d_model'], config['num_classes']), )) def forward(self, x, return_geo_state=False): tokens, svd_state = self['patch_embed'](x) if return_geo_state: features, geo_states = self['transformer'](tokens, return_geo_state=True) else: features = self['transformer'](tokens) cls_out = features[:, 0] logits = self['head'](cls_out) if return_geo_state: return logits, geo_states, svd_state return logits def geometric_losses(self): """Delegate to transformer's built-in geometric regularization.""" return self['transformer'].geometric_losses( cv_target=self.config.get('cv_target', 0.215), cv_weight=self.config.get('cv_weight', 0.1), spread_weight=self.config.get('spread_weight', 0.01), ) def infonce_loss(self): """InfoNCE contrastive loss on CLS token's geometric residual. Uses cached residual from last forward pass.""" return self['transformer'].infonce_loss() def update_nce_bank(self): """Enqueue current batch's residuals. Call AFTER backward.""" self['transformer'].update_nce_bank() def anchor_diagnostics(self): """Delegate to transformer's anchor diagnostics.""" return self['transformer'].anchor_diagnostics() # ═══════════════════════════════════════════════════════════════════════════════ # GEOMETRIC ANALYSIS BATTERY ( includes CM diagnostics) # ═══════════════════════════════════════════════════════════════════════════════ @torch.no_grad() def compute_cv(points): """Coefficient of variation on S^(d-1). CV = std(pairwise_cosine_distances) / mean(pairwise_cosine_distances) Pentachoron band: CV ∈ [0.20, 0.23]. """ points = F.normalize(points.float(), dim=-1) cos_sim = points @ points.T n = points.shape[0] idx = torch.triu_indices(n, n, offset=1, device=points.device) pairwise_dist = 1.0 - cos_sim[idx[0], idx[1]] mean_d = pairwise_dist.mean() std_d = pairwise_dist.std() cv = (std_d / (mean_d + 1e-8)).item() return cv, mean_d.item(), std_d.item() @torch.no_grad() def log_geometric_analysis(model, writer, epoch, test_loader, device, config): """Full geometric analysis battery with CM diagnostics.""" model.eval() images, labels = next(iter(test_loader)) images = images[:min(64, images.shape[0])].to(device) labels = labels[:min(64, labels.shape[0])].to(device) logits, geo_states, svd_state = model(images, return_geo_state=True) n_layers = len(geo_states) pred = logits.argmax(1) batch_acc = (pred == labels).float().mean().item() writer.add_scalar('analysis/batch_accuracy', batch_acc, epoch) # ─── SVD Input Stage ─── S = svd_state['singular_values'] s_norm = S / (S.sum(dim=-1, keepdim=True) + 1e-8) s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1) novelty = svd_state['novelty'] writer.add_scalar('svd/entropy_mean', s_ent.mean().item(), epoch) writer.add_scalar('svd/entropy_std', s_ent.std().item(), epoch) writer.add_scalar('svd/novelty_norm', novelty.norm(dim=-1).mean().item(), epoch) writer.add_scalar('svd/top1_ratio', (S[:, 0] / (S.sum(-1) + 1e-8)).mean().item(), epoch) writer.add_scalar('svd/condition_number', (S[:, 0] / (S[:, -1].clamp(min=1e-8))).mean().item(), epoch) for k in range(min(S.shape[1], 5)): writer.add_scalar(f'svd/S_{k}', S[:, k].mean().item(), epoch) # SVD FiLM deviation pe = model['patch_embed'] writer.add_scalar('svd_film/gamma_weight_norm', pe.svd_to_gamma.weight.data.norm().item(), epoch) writer.add_scalar('svd_film/gamma_bias_dev_from_1', (pe.svd_to_gamma.bias.data - 1.0).abs().mean().item(), epoch) writer.add_scalar('svd_film/beta_weight_norm', pe.svd_to_beta.weight.data.norm().item(), epoch) writer.add_scalar('svd_film/beta_bias_norm', pe.svd_to_beta.bias.data.abs().mean().item(), epoch) # ─── Anchor Diagnostics (built-in) ─── anchor_diag = model.anchor_diagnostics() for layer_name, d in anchor_diag.items(): for k, v in d.items(): writer.add_scalar(f'anchor_diag/{layer_name}_{k}', v, epoch) # ─── Per-Layer Geometric Analysis ─── for i, gs in enumerate(geo_states): prefix = f'layer_{i}' # === CV — pentachoron band metric === emb = gs['embedding'] # Anchor CV transformer = model['transformer'] layer = transformer[f'layer_{i}'] anchors = F.normalize( layer['observer'].association.constellation.anchors, dim=-1) cv_anchors, mean_d_anchors, std_d_anchors = compute_cv(anchors) writer.add_scalar(f'{prefix}/cv_anchors', cv_anchors, epoch) writer.add_scalar(f'{prefix}/anchor_mean_dist', mean_d_anchors, epoch) writer.add_scalar(f'{prefix}/anchor_std_dist', std_d_anchors, epoch) # Embedding CV emb_flat = emb.reshape(-1, emb.shape[-1]) n_sample = min(512, emb_flat.shape[0]) idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample] cv_emb, mean_d_emb, std_d_emb = compute_cv(emb_flat[idx]) writer.add_scalar(f'{prefix}/cv_embeddings', cv_emb, epoch) writer.add_scalar(f'{prefix}/embedding_mean_dist', mean_d_emb, epoch) # === CM Gate Diagnostics === gate_info = gs.get('gate_info', {}) gate_values = gs.get('gate_values') cm_quality = gs.get('cm_quality') if gate_info: writer.add_scalar(f'{prefix}/cm_active_anchors', gate_info.get('active', 0), epoch) writer.add_scalar(f'{prefix}/cm_gate_mean', gate_info.get('gate_mean', 0), epoch) writer.add_scalar(f'{prefix}/cm_positive_frac', gate_info.get('cm_positive_frac', 0), epoch) if gate_values is not None: gv = gate_values writer.add_scalar(f'{prefix}/gate_values_min', gv.min().item(), epoch) writer.add_scalar(f'{prefix}/gate_values_max', gv.max().item(), epoch) writer.add_scalar(f'{prefix}/gate_values_std', gv.std().item(), epoch) # Per-anchor gate mean (which anchors are consistently open/closed) gv_per_anchor = gv.mean(dim=0).mean(dim=0) # average over B and L writer.add_scalar(f'{prefix}/gate_anchor_spread', gv_per_anchor.std().item(), epoch) # Fraction of positions with >50% anchors open if gv.dim() == 3: pos_open_frac = (gv.mean(dim=-1) > 0.5).float().mean().item() else: pos_open_frac = (gv > 0.5).float().mean().item() writer.add_scalar(f'{prefix}/gate_positions_open_frac', pos_open_frac, epoch) if cm_quality is not None: writer.add_scalar(f'{prefix}/cm_quality_mean', cm_quality.mean().item(), epoch) writer.add_scalar(f'{prefix}/cm_quality_std', cm_quality.std().item(), epoch) writer.add_scalar(f'{prefix}/cm_quality_min', cm_quality.min().item(), epoch) # === Stream Agreement === content = gs['content'] geometric = gs['geometric'] agreement = F.cosine_similarity( content.reshape(-1, content.shape[-1]), geometric.reshape(-1, geometric.shape[-1]), dim=-1) writer.add_scalar(f'{prefix}/stream_agreement_mean', agreement.mean().item(), epoch) writer.add_scalar(f'{prefix}/stream_agreement_std', agreement.std().item(), epoch) writer.add_scalar(f'{prefix}/content_norm', content.norm(dim=-1).mean().item(), epoch) writer.add_scalar(f'{prefix}/geometric_norm', geometric.norm(dim=-1).mean().item(), epoch) # === Disagreement arm analysis === disagree = content - geometric agree = content * geometric writer.add_scalar(f'{prefix}/disagree_norm', disagree.norm(dim=-1).mean().item(), epoch) writer.add_scalar(f'{prefix}/agree_norm', agree.norm(dim=-1).mean().item(), epoch) # === Anchor Utilization === tri = gs['triangulation'] assignment = gs['assignment'] nearest = gs['nearest'] n_anchors = tri.shape[-1] nearest_flat = nearest.reshape(-1) counts = torch.bincount(nearest_flat, minlength=n_anchors).float() total_assignments = counts.sum() probs = counts / (total_assignments + 1e-8) anchor_entropy = -(probs * torch.log(probs.clamp(min=1e-8))).sum().item() max_entropy = math.log(n_anchors) writer.add_scalar(f'{prefix}/anchor_entropy_normalized', anchor_entropy / (max_entropy + 1e-8), epoch) active = (counts > 0).sum().item() writer.add_scalar(f'{prefix}/anchors_active', active, epoch) writer.add_scalar(f'{prefix}/anchors_active_frac', active / n_anchors, epoch) dead = (counts == 0).sum().item() writer.add_scalar(f'{prefix}/anchors_dead', dead, epoch) # === Triangulation Statistics === writer.add_scalar(f'{prefix}/tri_mean', tri.mean().item(), epoch) writer.add_scalar(f'{prefix}/tri_std', tri.std().item(), epoch) # === Soft Assignment Statistics === assign_ent = -(assignment * torch.log(assignment.clamp(min=1e-8))).sum(-1) writer.add_scalar(f'{prefix}/assignment_entropy_mean', assign_ent.mean().item(), epoch) writer.add_scalar(f'{prefix}/assignment_max_prob', assignment.max(dim=-1).values.mean().item(), epoch) # === Patchwork Statistics (now from CM-validated triangulation) === pw = gs['patchwork'] writer.add_scalar(f'{prefix}/patchwork_norm', pw.norm(dim=-1).mean().item(), epoch) writer.add_scalar(f'{prefix}/patchwork_std', pw.std().item(), epoch) pw_sparsity = (pw.abs() < 0.01).float().mean().item() writer.add_scalar(f'{prefix}/patchwork_sparsity', pw_sparsity, epoch) # === Bridge Consistency === bridge = gs['bridge'] bridge_soft = F.softmax(bridge, dim=-1) bridge_assign_kl = F.kl_div( bridge_soft.log().reshape(-1, n_anchors), assignment.reshape(-1, n_anchors), reduction='batchmean', log_target=False) writer.add_scalar(f'{prefix}/bridge_assignment_kl', bridge_assign_kl.item(), epoch) # === Quaternion Composition === composed = gs['composed'] writer.add_scalar(f'{prefix}/composed_norm', composed.norm(dim=-1).mean().item(), epoch) # === Geo Context === geo_ctx = gs['geo_ctx'] writer.add_scalar(f'{prefix}/geo_ctx_norm', geo_ctx.norm(dim=-1).mean().item(), epoch) # === Geometric Residual Stream (CM-conditioned) === geo_res = gs.get('geo_residual') if geo_res is not None: res_norms = geo_res.norm(dim=-1) writer.add_scalar(f'{prefix}/geo_res_norm', res_norms.mean().item(), epoch) writer.add_scalar(f'{prefix}/geo_res_std', geo_res.std().item(), epoch) writer.add_scalar(f'{prefix}/geo_res_sparsity', (geo_res.abs() < 0.01).float().mean().item(), epoch) # Cross-position consistency geo_res_flat = geo_res.reshape(-1, geo_res.shape[-1]) n_s = min(256, geo_res_flat.shape[0]) idx_s = torch.randperm(geo_res_flat.shape[0], device=geo_res.device)[:n_s] sampled = F.normalize(geo_res_flat[idx_s], dim=-1) cos_mat = sampled @ sampled.T triu = torch.triu_indices(n_s, n_s, offset=1, device=geo_res.device) writer.add_scalar(f'{prefix}/geo_res_consistency', cos_mat[triu[0], triu[1]].mean().item(), epoch) # ─── Cayley Rotation Analysis ─── for name, mod in model.named_modules(): if isinstance(mod, CayleyOrthogonal): R = mod.get_rotation() I = torch.eye(R.shape[0], device=R.device) r_dist = (R - I).norm().item() clean_name = name.replace('.', '_') writer.add_scalar(f'cayley/{clean_name}_R_minus_I', r_dist, epoch) # ─── FiLM Layer Analysis ─── film_idx = 0 for name, mod in model.named_modules(): if isinstance(mod, FiLMLayer): g_b = mod.to_gamma.bias.data b_b = mod.to_beta.bias.data writer.add_scalar(f'film/{film_idx}_gamma_dev', (g_b - 1.0).abs().mean().item(), epoch) writer.add_scalar(f'film/{film_idx}_beta_dev', b_b.abs().mean().item(), epoch) film_idx += 1 # ─── Cross-Layer Trajectories ─── cv_trajectory = [] cm_quality_trajectory = [] res_norms = [] bridge_kls = [] for i, gs in enumerate(geo_states): # CV emb = gs['embedding'] emb_flat = emb.reshape(-1, emb.shape[-1]) n_sample = min(512, emb_flat.shape[0]) idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample] cv, _, _ = compute_cv(emb_flat[idx]) cv_trajectory.append(cv) # CM quality cm_q = gs.get('cm_quality') if cm_q is not None: cm_quality_trajectory.append(cm_q.mean().item()) # Geo residual norms geo_res = gs.get('geo_residual') if geo_res is not None: res_norms.append(geo_res.norm(dim=-1).mean().item()) # Bridge KL n_anchors = gs['assignment'].shape[-1] bridge_soft = F.softmax(gs['bridge'], dim=-1) bkl = F.kl_div( bridge_soft.log().reshape(-1, n_anchors), gs['assignment'].reshape(-1, n_anchors), reduction='batchmean', log_target=False).item() bridge_kls.append(bkl) # CV trajectory writer.add_scalar('cv/trajectory_mean', np.mean(cv_trajectory), epoch) writer.add_scalar('cv/trajectory_std', np.std(cv_trajectory), epoch) in_band = sum(1 for cv in cv_trajectory if 0.20 <= cv <= 0.23) writer.add_scalar('cv/layers_in_pentachoron_band', in_band, epoch) writer.add_scalar('cv/layers_in_band_frac', in_band / len(cv_trajectory), epoch) # CM quality trajectory if cm_quality_trajectory: writer.add_scalar('cm/quality_trajectory_mean', np.mean(cm_quality_trajectory), epoch) writer.add_scalar('cm/quality_trajectory_std', np.std(cm_quality_trajectory), epoch) writer.add_scalar('cm/quality_min_layer', np.min(cm_quality_trajectory), epoch) writer.add_scalar('cm/quality_max_layer', np.max(cm_quality_trajectory), epoch) # Geometric residual trajectory if res_norms: writer.add_scalar('geo_res/trajectory_start', res_norms[0], epoch) writer.add_scalar('geo_res/trajectory_end', res_norms[-1], epoch) writer.add_scalar('geo_res/accumulation_ratio', res_norms[-1] / (res_norms[0] + 1e-8), epoch) growth = [res_norms[j+1] - res_norms[j] for j in range(len(res_norms)-1)] writer.add_scalar('geo_res/growth_mean', np.mean(growth), epoch) writer.add_scalar('geo_res/growth_std', np.std(growth), epoch) # Cooperation analysis (includes CM quality) if len(res_norms) >= 4: cv_corr = float(np.corrcoef(res_norms, cv_trajectory)[0, 1]) bkl_corr = float(np.corrcoef(res_norms, bridge_kls)[0, 1]) writer.add_scalar('cooperation/geo_res_vs_cv', cv_corr, epoch) writer.add_scalar('cooperation/geo_res_vs_bridge_kl', bkl_corr, epoch) if len(cm_quality_trajectory) == len(res_norms): cm_corr = float(np.corrcoef( res_norms, cm_quality_trajectory)[0, 1]) writer.add_scalar('cooperation/geo_res_vs_cm_quality', cm_corr, epoch) # CM vs CV: do layers with better CM quality also have better CV? cm_cv_corr = float(np.corrcoef( cm_quality_trajectory, cv_trajectory)[0, 1]) writer.add_scalar('cooperation/cm_quality_vs_cv', cm_cv_corr, epoch) return { 'batch_acc': batch_acc, 'cv_trajectory': cv_trajectory, 'cm_quality_trajectory': cm_quality_trajectory, 'res_norms': res_norms, 'bridge_kls': bridge_kls, } @torch.no_grad() def log_gradient_norms(model, writer, epoch): """Log gradient norms per component type (includes cm_gate).""" type_grads = {} for name, param in model.named_parameters(): if param.grad is not None: grad_norm = param.grad.norm().item() if 'projection' in name and 'proj' in name: key = 'manifold_proj' elif 'cm_gate' in name: key = 'cm_gate' elif 'observer' in name or 'constellation' in name or 'anchor' in name: key = 'constellation' elif 'context' in name: key = 'geo_context' elif 'content' in name: key = 'content_attn' elif 'geometric' in name and 'film' not in name: key = 'geo_attn' elif 'film' in name: key = 'film' elif 'rotation' in name or 'cayley' in name or 'A_upper' in name: key = 'cayley' elif 'compose' in name or 'quat' in name or 'proj_w' in name: key = 'quaternion' elif 'decode' in name: key = 'decode' elif 'gate' in name: key = 'gate' elif 'conv' in name or 'patch' in name: key = 'input_stage' elif 'head' in name: key = 'head' elif 'svd' in name: key = 'svd' elif 'geo_proj' in name: key = 'geo_residual_proj' else: key = 'other' if key not in type_grads: type_grads[key] = [] type_grads[key].append(grad_norm) for key, norms in type_grads.items(): writer.add_scalar(f'grad_norm/{key}_mean', np.mean(norms), epoch) writer.add_scalar(f'grad_norm/{key}_max', np.max(norms), epoch) total = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None) ** 0.5 writer.add_scalar('grad_norm/total', total, epoch) @torch.no_grad() def log_weight_norms(model, writer, epoch): """Log weight norms per component type.""" for name, param in model.named_parameters(): if 'A_upper' in name: clean = name.replace('.', '_') writer.add_scalar(f'weights/{clean}_norm', param.norm().item(), epoch) # ═══════════════════════════════════════════════════════════════════════════════ # DATA # ═══════════════════════════════════════════════════════════════════════════════ def get_dataloaders(config): import torchvision import torchvision.transforms as T # Augmentation pipeline tuned for geometric transformer: # TrivialAugmentWide: continuous severity spectrum of geometric + photometric # transforms. Exercises CM gate across full quality range — mild distortion # keeps CM high, severe distortion creates partially-degenerate simplices. # RandomErasing: creates degenerate manifold projections (zero-volume CM simplices). # Trains CM gate to close on corrupted regions. # CutMix applied at batch level in train_epoch (not here). train_transform = T.Compose([ T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(), T.TrivialAugmentWide(), T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), T.RandomErasing(p=config.get('random_erasing_p', 0.25), scale=(0.02, 0.33), ratio=(0.3, 3.3)), ]) test_transform = T.Compose([ T.ToTensor(), T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) train_ds = torchvision.datasets.CIFAR100( root='./data', train=True, download=True, transform=train_transform) test_ds = torchvision.datasets.CIFAR100( root='./data', train=False, download=True, transform=test_transform) train_loader = torch.utils.data.DataLoader( train_ds, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'], pin_memory=True, drop_last=True) test_loader = torch.utils.data.DataLoader( test_ds, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'], pin_memory=True) return train_loader, test_loader # ═══════════════════════════════════════════════════════════════════════════════ # CUTMIX — batch-level augmentation for CM gate boundary training # ═══════════════════════════════════════════════════════════════════════════════ def cutmix_batch(images, labels, alpha=1.0): """Apply CutMix to a batch. Returns mixed images + label pairs + lambda. CutMix replaces a rectangular region of image A with image B. Positions inside each region have coherent geometry — valid CM simplices. The boundary between regions has mixed geometric context — the CM gate should learn to suppress these positions. Args: images: (B, C, H, W) batch labels: (B,) integer labels alpha: Beta distribution parameter (1.0 = uniform box sizes) Returns: images: (B, C, H, W) mixed batch (modified in-place) labels_a: (B,) labels for region A labels_b: (B,) labels for region B lam: float, fraction of image A remaining """ lam = np.random.beta(alpha, alpha) B = images.size(0) idx = torch.randperm(B, device=images.device) H, W = images.shape[2], images.shape[3] cut_ratio = (1.0 - lam) ** 0.5 cw = int(W * cut_ratio) ch = int(H * cut_ratio) cx = np.random.randint(W) cy = np.random.randint(H) x1 = max(cx - cw // 2, 0); x2 = min(cx + cw // 2, W) y1 = max(cy - ch // 2, 0); y2 = min(cy + ch // 2, H) images[:, :, y1:y2, x1:x2] = images[idx, :, y1:y2, x1:x2] lam_actual = 1.0 - (x2 - x1) * (y2 - y1) / (W * H) return images, labels, labels[idx], lam_actual # ═══════════════════════════════════════════════════════════════════════════════ # TRAINING (geometric losses + CutMix integrated) # ═══════════════════════════════════════════════════════════════════════════════ def train_epoch(model, loader, optimizer, scheduler, epoch, config, writer): model.train() total_loss = 0 total_geo_loss = 0 total_nce_loss = 0 correct = 0 total = 0 cutmix_alpha = config.get('cutmix_alpha', 1.0) cutmix_prob = config.get('cutmix_prob', 0.5) label_smoothing = config.get('label_smoothing', 0.1) nce_weight = config.get('nce_weight', 0.1) criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing) for batch_idx, (images, labels) in enumerate(loader): images = images.to(device) labels = labels.to(device) # CutMix: applied probabilistically per batch use_cutmix = np.random.rand() < cutmix_prob if use_cutmix: images, labels_a, labels_b, lam = cutmix_batch( images, labels, alpha=cutmix_alpha) logits = model(images) ce_loss = lam * criterion(logits, labels_a) + \ (1.0 - lam) * criterion(logits, labels_b) # Accuracy: count correct if matches either label pred = logits.argmax(1) correct += (lam * (pred == labels_a).float() + (1.0 - lam) * (pred == labels_b).float()).sum().item() else: logits = model(images) ce_loss = criterion(logits, labels) correct += (logits.argmax(1) == labels).sum().item() # Geometric regularization — CV target + anchor spread geo_losses = model.geometric_losses() geo_loss = geo_losses.get('geo_total', torch.tensor(0.0, device=device)) # InfoNCE on geometric residual — discriminative pressure nce_losses = model.infonce_loss() nce_loss = nce_losses.get('nce', torch.tensor(0.0, device=device)) loss = ce_loss + geo_loss + nce_weight * nce_loss optimizer.zero_grad() loss.backward() # Enqueue AFTER backward — detached residuals go into bank model.update_nce_bank() # Log gradient norms periodically if epoch % config['log_grads_every'] == 0 and batch_idx == 0: log_gradient_norms(model, writer, epoch) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() if scheduler is not None: scheduler.step() total_loss += ce_loss.item() * images.size(0) total_geo_loss += geo_loss.item() * images.size(0) total_nce_loss += nce_loss.item() * images.size(0) total += images.size(0) avg_ce = total_loss / total avg_geo = total_geo_loss / total avg_nce = total_nce_loss / total return avg_ce, avg_geo, avg_nce, correct / total @torch.no_grad() def evaluate(model, loader): model.eval() correct = 0 total = 0 for images, labels in loader: images = images.to(device) labels = labels.to(device) logits = model(images) correct += (logits.argmax(1) == labels).sum().item() total += images.size(0) return correct / total def main(): config = CONFIG.copy() print("=" * 60) print(" Geometric Transformer — CIFAR-100 (CM-Validated)") print(f" Input: conv({config['in_channels']}→{config['conv_channels']}) + " f"SVD(rank={config['svd_rank']}) + " f"{config['patch_size']}×{config['patch_size']} patches = " f"{(config['img_size']//config['patch_size'])**2} tokens + CLS") print(f" Model: d={config['d_model']}, heads={config['n_heads']}, " f"layers={config['n_layers']}, anchors={config['n_anchors']}") print(f" CM: neighbors={config['cm_neighbors']}, " f"cv_target={config['cv_target']}, " f"cv_weight={config['cv_weight']}, " f"spread_weight={config['spread_weight']}") print(f" Aug: TrivialAugmentWide + CutMix(α={config['cutmix_alpha']}, " f"p={config['cutmix_prob']}) + " f"RandomErasing(p={config['random_erasing_p']})") print(f" NCE: bank={config['nce_bank_size']}, " f"temp={config['nce_temperature']}, " f"weight={config['nce_weight']}") print("=" * 60) writer = SummaryWriter(config['log_dir']) writer.add_text('config', json.dumps(config, indent=2)) print("\nLoading CIFAR-100...") train_loader, test_loader = get_dataloaders(config) print(f" Train: {len(train_loader.dataset):,} | Test: {len(test_loader.dataset):,}") model = GeoViTClassifier('geo_vit_cifar100', config) if hasattr(model, 'network_to'): model.network_to(device=device, strict=False) else: model = model.to(device) n_params = sum(p.numel() for p in model.parameters()) n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"\n Total params: {n_params:,}") print(f" Trainable params: {n_trainable:,}") for name, module in model.named_children(): n = sum(p.numel() for p in module.parameters()) if n > 0: print(f" {name:<20s}: {n:,}") writer.add_scalar('model/total_params', n_params, 0) # Initial anchor diagnostics print(f"\n Initial anchor diagnostics:") diag = model.anchor_diagnostics() for layer_name, d in diag.items(): print(f" {layer_name}: cv={d['anchor_cv']:.4f}, " f"cm_pos={d['cm_positive_frac']:.3f}, " f"min_dist={d['min_pairwise_dist']:.4f}") # Optimizer + scheduler optimizer = torch.optim.AdamW( model.parameters(), lr=config['lr'], weight_decay=config['weight_decay']) total_steps = config['epochs'] * len(train_loader) warmup_steps = config['warmup_epochs'] * len(train_loader) def lr_lambda(step): if step < warmup_steps: return step / warmup_steps progress = (step - warmup_steps) / (total_steps - warmup_steps) return 0.5 * (1 + np.cos(np.pi * progress)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) print(f"\n{'━'*60}") print(f" Training for {config['epochs']} epochs") print(f" Warmup: {config['warmup_epochs']} epochs, " f"LR: {config['lr']}, WD: {config['weight_decay']}") print(f" Geo reg: cv_w={config['cv_weight']}, spread_w={config['spread_weight']}") print(f" NCE bank: size={config['nce_bank_size']}, " f"temp={config['nce_temperature']}, weight={config['nce_weight']}") print(f" Aug: TrivialAugmentWide + CutMix(p={config['cutmix_prob']}) + " f"RandomErasing(p={config['random_erasing_p']})") print(f" TensorBoard: {config['log_dir']}") print(f" Geo analysis every {config['log_geo_every']} epochs") print(f"{'━'*60}\n") best_acc = 0 save_dir = Path('geo_cifar100'); save_dir.mkdir(exist_ok=True) for epoch in tqdm(range(config['epochs']), desc="Epochs"): t0 = time.time() ce_loss, geo_loss, nce_loss, train_acc = train_epoch( model, train_loader, optimizer, scheduler, epoch, config, writer) test_acc = evaluate(model, test_loader) elapsed = time.time() - t0 lr = optimizer.param_groups[0]['lr'] writer.add_scalar('train/ce_loss', ce_loss, epoch) writer.add_scalar('train/geo_loss', geo_loss, epoch) writer.add_scalar('train/nce_loss', nce_loss, epoch) writer.add_scalar('train/total_loss', ce_loss + geo_loss + nce_loss, epoch) writer.add_scalar('train/accuracy', train_acc, epoch) writer.add_scalar('test/accuracy', test_acc, epoch) writer.add_scalar('train/lr', lr, epoch) writer.add_scalar('train/epoch_time', elapsed, epoch) writer.add_scalar('gap/train_test', train_acc - test_acc, epoch) log_weight_norms(model, writer, epoch) if test_acc > best_acc: best_acc = test_acc torch.save({ 'state_dict': {k: v.cpu() for k, v in model.state_dict().items()}, 'epoch': epoch, 'test_acc': test_acc, 'config': config, }, save_dir / 'best.pt') # Full geometric analysis periodically if epoch % config['log_geo_every'] == 0 or epoch == config['epochs'] - 1: geo_info = log_geometric_analysis( model, writer, epoch, test_loader, device, config) cv_str = ', '.join(f'{cv:.3f}' for cv in geo_info['cv_trajectory']) cm_str = ', '.join(f'{q:.3f}' for q in geo_info.get('cm_quality_trajectory', [])) res_str = ', '.join(f'{r:.3f}' for r in geo_info.get('res_norms', [])) tqdm.write( f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} " f"nce={nce_loss:.4f} " f"train={train_acc:.4f} test={test_acc:.4f} " f"best={best_acc:.4f} {elapsed:.1f}s" f"\n CV=[{cv_str}]" f"\n CM=[{cm_str}]" f"\n GR=[{res_str}]") elif epoch % 5 == 0: tqdm.write( f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} " f"nce={nce_loss:.4f} " f"train={train_acc:.4f} test={test_acc:.4f} " f"best={best_acc:.4f} {elapsed:.1f}s") # Final summary print(f"\n{'═'*60}") print(f" CIFAR-100 RESULTS (CM-Validated)") print(f"{'═'*60}") print(f" Best test accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)") print(f" Parameters: {n_params:,}") print(f" Checkpoint: {save_dir}/best.pt") print(f" TensorBoard: {config['log_dir']}") # Final geometric state + anchor diagnostics print(f"\n Final geometric state:") geo_info = log_geometric_analysis( model, writer, config['epochs'], test_loader, device, config) print(f"\n Final anchor diagnostics:") diag = model.anchor_diagnostics() for layer_name, d in diag.items(): print(f" {layer_name}: cv={d['anchor_cv']:.4f}, " f"cm_pos={d['cm_positive_frac']:.3f}, " f"cm_mean={d['cm_mean']:.4f}") writer.close() print(f"\nDone.") if __name__ == '__main__': main()