| """ |
Geometric Transformer — CIFAR-100 Training with CM-Validated Analysis
| |
| Changes from previous version: |
| - CM gate diagnostics per layer: active anchors, gate_mean, cm_positive_frac |
| - CM quality in geometric residual analysis (replaces blind gate) |
| - Geometric regularization losses (CV target + anchor spread) in training loop |
| - Anchor diagnostics via model.anchor_diagnostics() |
| - CM quality trajectory alongside CV and bridge KL for cooperation analysis |
| |
| TensorBoard logging of every geometric feature element: |
- CV (coefficient of variation) per layer — the pentachoron band metric
| - CM gate: active anchors, gate mean, cm_positive_frac, quality per position |
| - Stream agreement/divergence per layer |
| - Anchor utilization, entropy, spread |
| - Patchwork activation statistics (from CM-validated triangulation) |
| - Bridge vs assignment consistency |
| - Triangulation distance distributions |
| - SVD spectrum, entropy, novelty |
| - Quaternion arm norms and composition statistics |
- Cayley rotation ||R - I|| per layer
| - FiLM gamma/beta deviation from identity |
| - Gate activation statistics |
| - Gradient norms per component type (including cm_gate) |
| - Weight norms per component type |
| - Geometric regularization: CV loss, spread loss per epoch |
| |
| !pip install geolip-core torchvision tqdm tensorboard |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import time, json, math |
| from pathlib import Path |
| from tqdm.auto import tqdm |
| from torch.utils.tensorboard import SummaryWriter |
|
|
# Select the compute device once at import time; every helper below reads
# this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Report GPU name and total memory for run provenance.
    print(f" GPU: {torch.cuda.get_device_name()}")
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
|
| |
| |
| |
|
|
| |
| try: |
| from geolip_core.pipeline.components.geometric_transformer import ( |
| GeometricTransformer, GeometricTransformerLayer, |
| CayleyOrthogonal, QuaternionCompose, FiLMLayer, |
| ContentAttention, GeometricAttention, CMValidatedGate, |
| TorchComponent, BaseTower, |
| anchor_neighborhood_cm, |
| ) |
| print(" Imported from geolip_core (installed)") |
| except ImportError: |
| try: |
| from geometric_transformer import ( |
| GeometricTransformer, GeometricTransformerLayer, |
| CayleyOrthogonal, QuaternionCompose, FiLMLayer, |
| ContentAttention, GeometricAttention, CMValidatedGate, |
| TorchComponent, BaseTower, |
| anchor_neighborhood_cm, |
| ) |
| print(" Imported from local geometric_transformer.py") |
| except ImportError: |
| raise ImportError( |
| "Cannot find geometric_transformer. Place geometric_transformer.py " |
| "in the working directory or install geolip-core.") |
|
|
# Allow TF32 matmul kernels on Ampere+ GPUs: faster training at slightly
# reduced float32 precision.
torch.set_float32_matmul_precision('high')
|
|
|
|
| |
| |
| |
|
|
# All hyperparameters for the model, data pipeline, training loop, and logging.
CONFIG = {
    # --- Geometric transformer architecture ---
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 8,
    'n_anchors': 128,
    'manifold_dim': 128,
    'n_comp': 4,
    'd_comp': 16,
    'context_dim': 64,
    'quat_dim': 32,
    'dropout': 0.1,
    'cm_neighbors': 3,

    # --- Input stage (conv frontend + SVD observer + patchify) ---
    'patch_size': 4,
    'img_size': 32,
    'in_channels': 3,
    'conv_channels': 64,
    'svd_rank': 16,

    # --- Optimization ---
    'epochs': 100,
    'batch_size': 1024,
    'lr': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 5,
    'label_smoothing': 0.1,
    'num_workers': 8,

    # --- Geometric regularization (CV band target + anchor spread) ---
    'cv_target': 0.215,
    'cv_weight': 0.1,
    'spread_weight': 0.01,

    # --- Augmentation ---
    'cutmix_alpha': 1.0,
    'cutmix_prob': 0.5,
    'random_erasing_p': 0.25,

    # --- InfoNCE contrastive loss on the geometric residual ---
    'nce_bank_size': 4096,
    'nce_temperature': 0.1,
    'nce_weight': 0.1,

    # --- Task ---
    'num_classes': 100,

    # --- Logging cadence / destination ---
    'log_geo_every': 5,
    'log_grads_every': 10,
    'log_dir': 'runs/geo_cifar100',
}
|
|
|
|
| |
| |
| |
|
|
# Prefer the library's SVDObserver when geolip_core is installed.
# NOTE(review): the fallback SVDObserver class defined just below is
# unconditional, so it shadows a successful import here and `_HAS_SVD` is
# never consulted — confirm whether the class should be guarded by
# `if not _HAS_SVD:`.
try:
    from geolip_core.core.input.svd import SVDObserver
    _HAS_SVD = True
except ImportError:
    _HAS_SVD = False
|
|
class SVDObserver(nn.Module):
    """Fallback SVDObserver.

    Projects a feature map to `svd_rank` channels with a 1x1 conv, computes a
    per-sample spectral decomposition of the spatial Gram matrix (eigh on
    h^T h, so singular values come from eigenvalue square roots), and exposes
    spectrum-derived features plus a novelty signal against an EMA spectrum.

    NOTE(review): this class is defined unconditionally, so it shadows any
    SVDObserver imported from geolip_core above — confirm intent.
    """
    def __init__(self, in_channels, svd_rank=24):
        super().__init__()
        self.svd_rank = svd_rank
        # 1x1 projection to the rank-`svd_rank` observation space.
        self.to_svd = nn.Conv2d(in_channels, svd_rank, 1, bias=False)
        # EMA of singular values / flattened Vh, updated only in training.
        self.register_buffer('ema_s', torch.ones(svd_rank))
        self.register_buffer('ema_vh_flat', torch.eye(svd_rank).reshape(-1))
        self.ema_momentum = 0.99

    def extract_features(self, S, Vh):
        """Build a (B, 2*rank + 2) feature vector from spectrum S and basis Vh.

        Layout: normalized spectrum (rank) | diag(Vh) (rank) |
        off-diagonal energy of Vh (1) | spectral entropy (1).
        Non-finite entries are zeroed defensively.
        """
        B, k = S.shape  # shapes only; values unused below
        S_safe = S.clamp(min=1e-6)
        s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)
        vh_diag = Vh.diagonal(dim1=-2, dim2=-1)
        # Total squared energy minus diagonal energy = off-diagonal mass.
        vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
        s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)
        out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
        return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

    def compute_novelty(self, S):
        """Deviation of the current spectrum from the EMA spectrum."""
        return S - self.ema_s.clone().unsqueeze(0)

    def forward(self, x):
        """Return (S, Vh, features, novelty) for feature map x of shape (B, C, H, W).

        The eigendecomposition runs in float32 with autocast disabled and
        under no_grad, so `to_svd` receives no gradient through this path —
        presumably intentional (observer, not a trained bottleneck); confirm.
        """
        B, C, H, W = x.shape
        h = self.to_svd(x)
        # (B, H*W, rank): spatial positions as rows.
        h_flat = h.permute(0, 2, 3, 1).reshape(B, H * W, self.svd_rank)
        with torch.amp.autocast('cuda', enabled=False):
            with torch.no_grad():
                # Gram matrix h^T h; eigh gives ascending eigenvalues, so flip
                # to descending before taking square roots.
                gram = torch.bmm(h_flat.float().transpose(1, 2), h_flat.float())
                evals, evecs = torch.linalg.eigh(gram)
                evals = evals.flip(-1).clamp(min=1e-12)
                S = evals.sqrt()
                Vh = evecs.flip(-1).transpose(-2, -1)
                # Sanitize any numerical blow-ups from eigh.
                S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
                Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
        features = self.extract_features(S, Vh)
        novelty = self.compute_novelty(S)
        return S, Vh, features, novelty

    @torch.no_grad()
    def update_ema(self, S, Vh):
        """EMA-update the reference spectrum/basis buffers (training only)."""
        m = self.ema_momentum
        self.ema_s.mul_(m).add_(S.detach().mean(0), alpha=1-m)
        self.ema_vh_flat.mul_(m).add_(Vh.detach().mean(0).reshape(-1), alpha=1-m)

    @property
    def feature_dim(self):
        # Matches extract_features layout: rank + rank + 1 + 1.
        return 2 * self.svd_rank + 2
|
|
|
|
class ConvSVDPatchEmbedding(TorchComponent):
    """Input stage: conv frontend -> SVDObserver -> FiLM-modulated patch tokens.

    Produces a (B, n_patches + 1, d_model) token sequence (CLS prepended, plus
    learned positional embeddings) together with a dict of raw SVD products
    for downstream logging.
    """
    def __init__(self, name, img_size=32, patch_size=4, in_channels=3,
                 conv_channels=64, d_model=256, svd_rank=16):
        super().__init__(name)
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.d_model = d_model
        self.svd_rank = svd_rank

        # Two conv blocks at full resolution before patchification.
        self.conv_frontend = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
            nn.Conv2d(conv_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
        )
        self.svd_observer = SVDObserver(conv_channels, svd_rank)
        # Non-overlapping patch projection (stride == kernel == patch_size).
        self.patch_proj = nn.Conv2d(
            conv_channels, d_model, kernel_size=patch_size,
            stride=patch_size, bias=False)
        self.patch_norm = nn.LayerNorm(d_model)

        # FiLM heads driven by global SVD features, initialized to identity:
        # zero weights with gamma bias 1 and beta bias 0.
        svd_feat_dim = self.svd_observer.feature_dim
        self.svd_to_gamma = nn.Linear(svd_feat_dim, d_model)
        self.svd_to_beta = nn.Linear(svd_feat_dim, d_model)
        nn.init.zeros_(self.svd_to_gamma.weight); nn.init.ones_(self.svd_to_gamma.bias)
        nn.init.zeros_(self.svd_to_beta.weight); nn.init.zeros_(self.svd_to_beta.bias)

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, d_model) * 0.02)

    def forward(self, x):
        """Embed an image batch; returns (tokens, svd_state)."""
        batch = x.shape[0]
        fmap = self.conv_frontend(x)
        S, Vh, svd_feats, novelty = self.svd_observer(fmap)
        # Patchify: (B, d_model, H', W') -> (B, n_patches, d_model), then norm.
        patches = self.patch_norm(
            self.patch_proj(fmap).flatten(2).transpose(1, 2))
        # Broadcast FiLM modulation across all patch positions.
        scale = self.svd_to_gamma(svd_feats).unsqueeze(1)
        shift = self.svd_to_beta(svd_feats).unsqueeze(1)
        patches = scale * patches + shift
        # Prepend CLS and add positional embeddings.
        tokens = torch.cat(
            [self.cls_token.expand(batch, -1, -1), patches], dim=1)
        tokens = tokens + self.pos_embed
        state = {
            'singular_values': S, 'Vh': Vh,
            'svd_features': svd_feats, 'novelty': novelty,
        }
        if self.training:
            # EMA buffers only move during training passes.
            self.svd_observer.update_ema(S, Vh)
        return tokens, state
|
|
|
|
| |
| |
| |
|
|
class GeoViTClassifier(BaseTower):
    """Geometric Vision Transformer classifier.

    Composition: ConvSVDPatchEmbedding -> GeometricTransformer -> MLP head.
    Also forwards the transformer's geometric-regularization / InfoNCE /
    diagnostic hooks so the training loop can use them directly.
    """
    def __init__(self, name, config):
        super().__init__(name)
        self.config = config

        self.attach('patch_embed', ConvSVDPatchEmbedding(
            'patch_embed',
            img_size=config['img_size'],
            patch_size=config['patch_size'],
            in_channels=config['in_channels'],
            conv_channels=config['conv_channels'],
            d_model=config['d_model'],
            svd_rank=config['svd_rank'],
        ))
        self.attach('transformer', GeometricTransformer(
            'geo_cifar',
            d_model=config['d_model'],
            n_heads=config['n_heads'],
            n_layers=config['n_layers'],
            n_anchors=config['n_anchors'],
            manifold_dim=config['manifold_dim'],
            n_comp=config['n_comp'],
            d_comp=config['d_comp'],
            context_dim=config['context_dim'],
            quat_dim=config['quat_dim'],
            dropout=config['dropout'],
            cm_neighbors=config.get('cm_neighbors', 3),
            nce_bank_size=config.get('nce_bank_size', 4096),
            nce_temperature=config.get('nce_temperature', 0.1),
        ))
        self.attach('head', nn.Sequential(
            nn.LayerNorm(config['d_model']),
            nn.Linear(config['d_model'], config['d_model']),
            nn.GELU(),
            nn.Dropout(config['dropout']),
            nn.Linear(config['d_model'], config['num_classes']),
        ))

    def forward(self, x, return_geo_state=False):
        """Classify images; optionally also return per-layer geometric states."""
        tokens, svd_state = self['patch_embed'](x)
        if return_geo_state:
            feats, geo_states = self['transformer'](tokens, return_geo_state=True)
            # Head reads the CLS token only.
            return self['head'](feats[:, 0]), geo_states, svd_state
        feats = self['transformer'](tokens)
        return self['head'](feats[:, 0])

    def geometric_losses(self):
        """CV-band + anchor-spread regularization from the transformer."""
        cfg = self.config
        return self['transformer'].geometric_losses(
            cv_target=cfg.get('cv_target', 0.215),
            cv_weight=cfg.get('cv_weight', 0.1),
            spread_weight=cfg.get('spread_weight', 0.01),
        )

    def infonce_loss(self):
        """InfoNCE loss on the CLS geometric residual cached by the last forward."""
        return self['transformer'].infonce_loss()

    def update_nce_bank(self):
        """Enqueue the latest residuals into the NCE bank (call AFTER backward)."""
        self['transformer'].update_nce_bank()

    def anchor_diagnostics(self):
        """Per-layer anchor geometry diagnostics from the transformer."""
        return self['transformer'].anchor_diagnostics()
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def compute_cv(points):
    """Coefficient of variation of pairwise cosine distances on S^(d-1).

    CV = std(pairwise_cosine_distances) / mean(pairwise_cosine_distances).
    The "pentachoron band" target range is CV in [0.20, 0.23].

    Args:
        points: (n, d) tensor; rows are L2-normalized internally.

    Returns:
        Tuple of Python floats (cv, mean_dist, std_dist). With fewer than
        two points there are no pairs, so (0.0, 0.0, 0.0) is returned
        (the original produced NaN here); with exactly one pair the
        (unbiased) std is defined as 0.0 rather than NaN.
    """
    points = F.normalize(points.float(), dim=-1)
    n = points.shape[0]
    if n < 2:
        # No pairwise distances exist — avoid mean()/std() over empty tensors.
        return 0.0, 0.0, 0.0
    cos_sim = points @ points.T
    idx = torch.triu_indices(n, n, offset=1, device=points.device)
    pairwise_dist = 1.0 - cos_sim[idx[0], idx[1]]
    mean_d = pairwise_dist.mean()
    # torch.std uses Bessel's correction; a single pair would give NaN.
    std_d = (pairwise_dist.std() if pairwise_dist.numel() > 1
             else torch.zeros_like(mean_d))
    cv = (std_d / (mean_d + 1e-8)).item()
    return cv, mean_d.item(), std_d.item()
|
|
|
|
@torch.no_grad()
def log_geometric_analysis(model, writer, epoch, test_loader, device, config):
    """Full geometric analysis battery with CM diagnostics.

    Runs one probe batch (<= 64 images) through the model with
    `return_geo_state=True` and logs to TensorBoard: SVD spectrum stats,
    SVD-FiLM drift, anchor diagnostics, per-layer gate / stream /
    triangulation / patchwork / bridge / residual statistics, Cayley and
    FiLM deviations, and cross-layer trajectory correlations.

    Returns:
        dict with 'batch_acc' plus per-layer 'cv_trajectory',
        'cm_quality_trajectory', 'res_norms', 'bridge_kls' for offline use.
    """
    model.eval()

    # One probe batch from the test set, capped at 64 samples.
    images, labels = next(iter(test_loader))
    images = images[:min(64, images.shape[0])].to(device)
    labels = labels[:min(64, labels.shape[0])].to(device)

    logits, geo_states, svd_state = model(images, return_geo_state=True)

    n_layers = len(geo_states)  # NOTE(review): currently unused
    pred = logits.argmax(1)
    batch_acc = (pred == labels).float().mean().item()
    writer.add_scalar('analysis/batch_accuracy', batch_acc, epoch)

    # --- SVD spectrum statistics from the input stage ---
    S = svd_state['singular_values']
    s_norm = S / (S.sum(dim=-1, keepdim=True) + 1e-8)
    s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1)
    novelty = svd_state['novelty']

    writer.add_scalar('svd/entropy_mean', s_ent.mean().item(), epoch)
    writer.add_scalar('svd/entropy_std', s_ent.std().item(), epoch)
    writer.add_scalar('svd/novelty_norm', novelty.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar('svd/top1_ratio', (S[:, 0] / (S.sum(-1) + 1e-8)).mean().item(), epoch)
    writer.add_scalar('svd/condition_number',
                      (S[:, 0] / (S[:, -1].clamp(min=1e-8))).mean().item(), epoch)
    # Leading singular values individually (up to 5).
    for k in range(min(S.shape[1], 5)):
        writer.add_scalar(f'svd/S_{k}', S[:, k].mean().item(), epoch)

    # --- SVD-FiLM drift from identity init (gamma bias 1, beta bias 0) ---
    pe = model['patch_embed']
    writer.add_scalar('svd_film/gamma_weight_norm', pe.svd_to_gamma.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/gamma_bias_dev_from_1',
                      (pe.svd_to_gamma.bias.data - 1.0).abs().mean().item(), epoch)
    writer.add_scalar('svd_film/beta_weight_norm', pe.svd_to_beta.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/beta_bias_norm', pe.svd_to_beta.bias.data.abs().mean().item(), epoch)

    # --- Per-layer anchor diagnostics supplied by the model ---
    anchor_diag = model.anchor_diagnostics()
    for layer_name, d in anchor_diag.items():
        for k, v in d.items():
            writer.add_scalar(f'anchor_diag/{layer_name}_{k}', v, epoch)

    # --- Per-layer geometric state ---
    for i, gs in enumerate(geo_states):
        prefix = f'layer_{i}'

        emb = gs['embedding']
        transformer = model['transformer']
        layer = transformer[f'layer_{i}']
        # Anchor constellation CV (pentachoron band target: [0.20, 0.23]).
        anchors = F.normalize(
            layer['observer'].association.constellation.anchors, dim=-1)
        cv_anchors, mean_d_anchors, std_d_anchors = compute_cv(anchors)
        writer.add_scalar(f'{prefix}/cv_anchors', cv_anchors, epoch)
        writer.add_scalar(f'{prefix}/anchor_mean_dist', mean_d_anchors, epoch)
        writer.add_scalar(f'{prefix}/anchor_std_dist', std_d_anchors, epoch)

        # Embedding CV on a random subsample of positions (cap 512).
        emb_flat = emb.reshape(-1, emb.shape[-1])
        n_sample = min(512, emb_flat.shape[0])
        idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
        cv_emb, mean_d_emb, std_d_emb = compute_cv(emb_flat[idx])
        writer.add_scalar(f'{prefix}/cv_embeddings', cv_emb, epoch)
        writer.add_scalar(f'{prefix}/embedding_mean_dist', mean_d_emb, epoch)

        # CM gate diagnostics (keys may be absent on some layers).
        gate_info = gs.get('gate_info', {})
        gate_values = gs.get('gate_values')
        cm_quality = gs.get('cm_quality')

        if gate_info:
            writer.add_scalar(f'{prefix}/cm_active_anchors',
                              gate_info.get('active', 0), epoch)
            writer.add_scalar(f'{prefix}/cm_gate_mean',
                              gate_info.get('gate_mean', 0), epoch)
            writer.add_scalar(f'{prefix}/cm_positive_frac',
                              gate_info.get('cm_positive_frac', 0), epoch)

        if gate_values is not None:
            gv = gate_values
            writer.add_scalar(f'{prefix}/gate_values_min', gv.min().item(), epoch)
            writer.add_scalar(f'{prefix}/gate_values_max', gv.max().item(), epoch)
            writer.add_scalar(f'{prefix}/gate_values_std', gv.std().item(), epoch)
            # Spread of mean gate value across anchors
            # (assumes gv is (B, pos, anchors) — TODO confirm).
            gv_per_anchor = gv.mean(dim=0).mean(dim=0)
            writer.add_scalar(f'{prefix}/gate_anchor_spread',
                              gv_per_anchor.std().item(), epoch)
            # Fraction of positions whose (mean) gate exceeds 0.5.
            if gv.dim() == 3:
                pos_open_frac = (gv.mean(dim=-1) > 0.5).float().mean().item()
            else:
                pos_open_frac = (gv > 0.5).float().mean().item()
            writer.add_scalar(f'{prefix}/gate_positions_open_frac', pos_open_frac, epoch)

        if cm_quality is not None:
            writer.add_scalar(f'{prefix}/cm_quality_mean', cm_quality.mean().item(), epoch)
            writer.add_scalar(f'{prefix}/cm_quality_std', cm_quality.std().item(), epoch)
            writer.add_scalar(f'{prefix}/cm_quality_min', cm_quality.min().item(), epoch)

        # Content vs geometric stream agreement (per-position cosine).
        content = gs['content']
        geometric = gs['geometric']
        agreement = F.cosine_similarity(
            content.reshape(-1, content.shape[-1]),
            geometric.reshape(-1, geometric.shape[-1]), dim=-1)
        writer.add_scalar(f'{prefix}/stream_agreement_mean', agreement.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/stream_agreement_std', agreement.std().item(), epoch)

        writer.add_scalar(f'{prefix}/content_norm', content.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/geometric_norm', geometric.norm(dim=-1).mean().item(), epoch)

        # Magnitudes of stream difference and elementwise product.
        disagree = content - geometric
        agree = content * geometric
        writer.add_scalar(f'{prefix}/disagree_norm', disagree.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/agree_norm', agree.norm(dim=-1).mean().item(), epoch)

        # Anchor utilization from hard nearest-anchor assignments.
        tri = gs['triangulation']
        assignment = gs['assignment']
        nearest = gs['nearest']
        n_anchors = tri.shape[-1]

        nearest_flat = nearest.reshape(-1)
        counts = torch.bincount(nearest_flat, minlength=n_anchors).float()
        total_assignments = counts.sum()

        # Normalized entropy of the anchor usage distribution (1 = uniform).
        probs = counts / (total_assignments + 1e-8)
        anchor_entropy = -(probs * torch.log(probs.clamp(min=1e-8))).sum().item()
        max_entropy = math.log(n_anchors)
        writer.add_scalar(f'{prefix}/anchor_entropy_normalized',
                          anchor_entropy / (max_entropy + 1e-8), epoch)
        active = (counts > 0).sum().item()
        writer.add_scalar(f'{prefix}/anchors_active', active, epoch)
        writer.add_scalar(f'{prefix}/anchors_active_frac', active / n_anchors, epoch)
        dead = (counts == 0).sum().item()
        writer.add_scalar(f'{prefix}/anchors_dead', dead, epoch)

        # Triangulation distance distribution.
        writer.add_scalar(f'{prefix}/tri_mean', tri.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/tri_std', tri.std().item(), epoch)

        # Soft-assignment entropy and confidence.
        assign_ent = -(assignment * torch.log(assignment.clamp(min=1e-8))).sum(-1)
        writer.add_scalar(f'{prefix}/assignment_entropy_mean', assign_ent.mean().item(), epoch)
        writer.add_scalar(f'{prefix}/assignment_max_prob',
                          assignment.max(dim=-1).values.mean().item(), epoch)

        # Patchwork activation statistics (sparsity = |x| < 0.01).
        pw = gs['patchwork']
        writer.add_scalar(f'{prefix}/patchwork_norm', pw.norm(dim=-1).mean().item(), epoch)
        writer.add_scalar(f'{prefix}/patchwork_std', pw.std().item(), epoch)
        pw_sparsity = (pw.abs() < 0.01).float().mean().item()
        writer.add_scalar(f'{prefix}/patchwork_sparsity', pw_sparsity, epoch)

        # Bridge vs assignment consistency: KL(assignment || softmax(bridge)).
        # F.kl_div takes log-probabilities as input, probabilities as target.
        bridge = gs['bridge']
        bridge_soft = F.softmax(bridge, dim=-1)
        bridge_assign_kl = F.kl_div(
            bridge_soft.log().reshape(-1, n_anchors),
            assignment.reshape(-1, n_anchors),
            reduction='batchmean', log_target=False)
        writer.add_scalar(f'{prefix}/bridge_assignment_kl', bridge_assign_kl.item(), epoch)

        # Quaternion-composed feature norm.
        composed = gs['composed']
        writer.add_scalar(f'{prefix}/composed_norm', composed.norm(dim=-1).mean().item(), epoch)

        # Geometric context vector norm.
        geo_ctx = gs['geo_ctx']
        writer.add_scalar(f'{prefix}/geo_ctx_norm', geo_ctx.norm(dim=-1).mean().item(), epoch)

        # Geometric residual: magnitude, sparsity, directional consistency.
        geo_res = gs.get('geo_residual')
        if geo_res is not None:
            res_norms = geo_res.norm(dim=-1)
            writer.add_scalar(f'{prefix}/geo_res_norm', res_norms.mean().item(), epoch)
            writer.add_scalar(f'{prefix}/geo_res_std', geo_res.std().item(), epoch)
            writer.add_scalar(f'{prefix}/geo_res_sparsity',
                              (geo_res.abs() < 0.01).float().mean().item(), epoch)
            # Mean pairwise cosine among up to 256 sampled residual vectors.
            geo_res_flat = geo_res.reshape(-1, geo_res.shape[-1])
            n_s = min(256, geo_res_flat.shape[0])
            idx_s = torch.randperm(geo_res_flat.shape[0], device=geo_res.device)[:n_s]
            sampled = F.normalize(geo_res_flat[idx_s], dim=-1)
            cos_mat = sampled @ sampled.T
            triu = torch.triu_indices(n_s, n_s, offset=1, device=geo_res.device)
            writer.add_scalar(f'{prefix}/geo_res_consistency',
                              cos_mat[triu[0], triu[1]].mean().item(), epoch)

    # --- Cayley rotations: Frobenius distance of R from identity ---
    for name, mod in model.named_modules():
        if isinstance(mod, CayleyOrthogonal):
            R = mod.get_rotation()
            I = torch.eye(R.shape[0], device=R.device)
            r_dist = (R - I).norm().item()
            clean_name = name.replace('.', '_')
            writer.add_scalar(f'cayley/{clean_name}_R_minus_I', r_dist, epoch)

    # --- FiLM layers: deviation from identity modulation ---
    film_idx = 0
    for name, mod in model.named_modules():
        if isinstance(mod, FiLMLayer):
            g_b = mod.to_gamma.bias.data
            b_b = mod.to_beta.bias.data
            writer.add_scalar(f'film/{film_idx}_gamma_dev',
                              (g_b - 1.0).abs().mean().item(), epoch)
            writer.add_scalar(f'film/{film_idx}_beta_dev',
                              b_b.abs().mean().item(), epoch)
            film_idx += 1

    # --- Cross-layer trajectories for cooperation analysis ---
    cv_trajectory = []
    cm_quality_trajectory = []
    res_norms = []
    bridge_kls = []

    for i, gs in enumerate(geo_states):
        # Embedding CV (re-sampled; uses a fresh random subset per layer).
        emb = gs['embedding']
        emb_flat = emb.reshape(-1, emb.shape[-1])
        n_sample = min(512, emb_flat.shape[0])
        idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
        cv, _, _ = compute_cv(emb_flat[idx])
        cv_trajectory.append(cv)

        cm_q = gs.get('cm_quality')
        if cm_q is not None:
            cm_quality_trajectory.append(cm_q.mean().item())

        geo_res = gs.get('geo_residual')
        if geo_res is not None:
            res_norms.append(geo_res.norm(dim=-1).mean().item())

        n_anchors = gs['assignment'].shape[-1]
        bridge_soft = F.softmax(gs['bridge'], dim=-1)
        bkl = F.kl_div(
            bridge_soft.log().reshape(-1, n_anchors),
            gs['assignment'].reshape(-1, n_anchors),
            reduction='batchmean', log_target=False).item()
        bridge_kls.append(bkl)

    # Pentachoron-band membership across layers.
    writer.add_scalar('cv/trajectory_mean', np.mean(cv_trajectory), epoch)
    writer.add_scalar('cv/trajectory_std', np.std(cv_trajectory), epoch)
    in_band = sum(1 for cv in cv_trajectory if 0.20 <= cv <= 0.23)
    writer.add_scalar('cv/layers_in_pentachoron_band', in_band, epoch)
    writer.add_scalar('cv/layers_in_band_frac', in_band / len(cv_trajectory), epoch)

    if cm_quality_trajectory:
        writer.add_scalar('cm/quality_trajectory_mean',
                          np.mean(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_trajectory_std',
                          np.std(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_min_layer',
                          np.min(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_max_layer',
                          np.max(cm_quality_trajectory), epoch)

    if res_norms:
        # How residual magnitude accumulates from first to last layer.
        writer.add_scalar('geo_res/trajectory_start', res_norms[0], epoch)
        writer.add_scalar('geo_res/trajectory_end', res_norms[-1], epoch)
        writer.add_scalar('geo_res/accumulation_ratio',
                          res_norms[-1] / (res_norms[0] + 1e-8), epoch)
        growth = [res_norms[j+1] - res_norms[j] for j in range(len(res_norms)-1)]
        writer.add_scalar('geo_res/growth_mean', np.mean(growth), epoch)
        writer.add_scalar('geo_res/growth_std', np.std(growth), epoch)

    # Correlations need at least 4 layers with residuals to be meaningful.
    if len(res_norms) >= 4:
        cv_corr = float(np.corrcoef(res_norms, cv_trajectory)[0, 1])
        bkl_corr = float(np.corrcoef(res_norms, bridge_kls)[0, 1])
        writer.add_scalar('cooperation/geo_res_vs_cv', cv_corr, epoch)
        writer.add_scalar('cooperation/geo_res_vs_bridge_kl', bkl_corr, epoch)

        if len(cm_quality_trajectory) == len(res_norms):
            cm_corr = float(np.corrcoef(
                res_norms, cm_quality_trajectory)[0, 1])
            writer.add_scalar('cooperation/geo_res_vs_cm_quality', cm_corr, epoch)
            cm_cv_corr = float(np.corrcoef(
                cm_quality_trajectory, cv_trajectory)[0, 1])
            writer.add_scalar('cooperation/cm_quality_vs_cv', cm_cv_corr, epoch)

    return {
        'batch_acc': batch_acc,
        'cv_trajectory': cv_trajectory,
        'cm_quality_trajectory': cm_quality_trajectory,
        'res_norms': res_norms,
        'bridge_kls': bridge_kls,
    }
|
|
|
|
@torch.no_grad()
def log_gradient_norms(model, writer, epoch):
    """Log gradient norms bucketed by component type (includes cm_gate).

    Each parameter with a gradient is routed to the first matching rule
    below (substring tests on the qualified name); unmatched parameters
    fall into 'other'. Per-bucket mean and max norms are logged, plus the
    global gradient norm.
    """
    # Ordered dispatch table; earlier rules win. The original's
    # "'projection' in name and 'proj' in name" reduces to the first test
    # alone since 'proj' is a substring of 'projection'.
    rules = [
        ('manifold_proj', lambda n: 'projection' in n),
        ('cm_gate', lambda n: 'cm_gate' in n),
        ('constellation', lambda n: 'observer' in n or 'constellation' in n or 'anchor' in n),
        ('geo_context', lambda n: 'context' in n),
        ('content_attn', lambda n: 'content' in n),
        ('geo_attn', lambda n: 'geometric' in n and 'film' not in n),
        ('film', lambda n: 'film' in n),
        ('cayley', lambda n: 'rotation' in n or 'cayley' in n or 'A_upper' in n),
        ('quaternion', lambda n: 'compose' in n or 'quat' in n or 'proj_w' in n),
        ('decode', lambda n: 'decode' in n),
        ('gate', lambda n: 'gate' in n),
        ('input_stage', lambda n: 'conv' in n or 'patch' in n),
        ('head', lambda n: 'head' in n),
        ('svd', lambda n: 'svd' in n),
        ('geo_residual_proj', lambda n: 'geo_proj' in n),
    ]

    buckets = {}
    for pname, param in model.named_parameters():
        if param.grad is None:
            continue
        bucket = next((key for key, pred in rules if pred(pname)), 'other')
        buckets.setdefault(bucket, []).append(param.grad.norm().item())

    for bucket, norms in buckets.items():
        writer.add_scalar(f'grad_norm/{bucket}_mean', np.mean(norms), epoch)
        writer.add_scalar(f'grad_norm/{bucket}_max', np.max(norms), epoch)

    # Global gradient norm: sqrt of the sum of squared per-tensor norms.
    total_sq = sum(p.grad.norm().item() ** 2
                   for p in model.parameters() if p.grad is not None)
    writer.add_scalar('grad_norm/total', total_sq ** 0.5, epoch)
|
|
|
|
@torch.no_grad()
def log_weight_norms(model, writer, epoch):
    """Log L2 norms of the Cayley `A_upper` parameters (skew generators)."""
    for pname, param in model.named_parameters():
        if 'A_upper' not in pname:
            continue
        tag = pname.replace('.', '_')
        writer.add_scalar(f'weights/{tag}_norm', param.norm().item(), epoch)
|
|
|
|
| |
| |
| |
|
|
def get_dataloaders(config):
    """Build CIFAR-100 train/test DataLoaders (downloads into ./data).

    Train pipeline: RandomCrop(32, pad 4) -> HFlip -> TrivialAugmentWide ->
    ToTensor -> Normalize -> RandomErasing. Test: ToTensor -> Normalize.
    """
    import torchvision
    import torchvision.transforms as T

    # NOTE(review): these are the canonical CIFAR-10 statistics; CIFAR-100's
    # differ slightly (~0.5071/0.2673) — confirm the choice is intentional.
    norm_mean = (0.4914, 0.4822, 0.4465)
    norm_std = (0.2470, 0.2435, 0.2616)

    aug_train = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.TrivialAugmentWide(),
        T.ToTensor(),
        T.Normalize(norm_mean, norm_std),
        T.RandomErasing(p=config.get('random_erasing_p', 0.25),
                        scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ])
    aug_test = T.Compose([
        T.ToTensor(),
        T.Normalize(norm_mean, norm_std),
    ])

    ds_train = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=aug_train)
    ds_test = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=aug_test)

    # drop_last keeps the per-step batch size constant during training.
    loader_train = torch.utils.data.DataLoader(
        ds_train, batch_size=config['batch_size'], shuffle=True,
        num_workers=config['num_workers'], pin_memory=True, drop_last=True)
    loader_test = torch.utils.data.DataLoader(
        ds_test, batch_size=config['batch_size'], shuffle=False,
        num_workers=config['num_workers'], pin_memory=True)
    return loader_train, loader_test
|
|
|
|
| |
| |
| |
|
|
def cutmix_batch(images, labels, alpha=1.0):
    """Apply CutMix to a batch in place; return mixed images + label pair + lambda.

    A rectangular window of each image is overwritten with the same window
    taken from a randomly permuted partner image. Positions inside each
    region keep coherent geometry (valid CM simplices); the seam between
    regions has mixed geometric context, which the CM gate should learn to
    suppress.

    Args:
        images: (B, C, H, W) batch; modified in place.
        labels: (B,) integer labels.
        alpha: Beta(alpha, alpha) parameter (1.0 = uniform box sizes).

    Returns:
        Tuple (images, labels_a, labels_b, lam) where labels_a are the
        original labels (untouched region), labels_b the partners', and
        lam is the exact fraction of each image left untouched.
    """
    # RNG call order matters for reproducibility: beta, randperm, 2x randint.
    mix = np.random.beta(alpha, alpha)
    partner = torch.randperm(images.size(0), device=images.device)

    H, W = images.shape[2], images.shape[3]
    side_ratio = (1.0 - mix) ** 0.5
    box_w = int(W * side_ratio)
    box_h = int(H * side_ratio)
    center_x = np.random.randint(W)
    center_y = np.random.randint(H)
    left = max(center_x - box_w // 2, 0)
    right = min(center_x + box_w // 2, W)
    top = max(center_y - box_h // 2, 0)
    bottom = min(center_y + box_h // 2, H)

    # Paste the partner's window over each image.
    images[:, :, top:bottom, left:right] = images[partner, :, top:bottom, left:right]
    # Recompute lambda from the clipped box so it is exact.
    kept_frac = 1.0 - (right - left) * (bottom - top) / (W * H)
    return images, labels, labels[partner], kept_frac
|
|
|
|
| |
| |
| |
|
|
def train_epoch(model, loader, optimizer, scheduler, epoch, config, writer):
    """Run one training epoch.

    Combines label-smoothed cross-entropy (with probabilistic CutMix),
    the model's geometric regularization, and a weighted InfoNCE term.
    Uses the module-level `device`. The per-step order is deliberate:
    backward -> update_nce_bank (bank must see this batch's residuals
    after grads are taken) -> optional grad logging -> clip -> step.

    Returns:
        (avg_ce, avg_geo, avg_nce, accuracy) averaged over all samples.
        Under CutMix, "accuracy" is the lambda-weighted soft accuracy.
    """
    model.train()
    total_loss = 0
    total_geo_loss = 0
    total_nce_loss = 0
    correct = 0
    total = 0

    cutmix_alpha = config.get('cutmix_alpha', 1.0)
    cutmix_prob = config.get('cutmix_prob', 0.5)
    label_smoothing = config.get('label_smoothing', 0.1)
    nce_weight = config.get('nce_weight', 0.1)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # CutMix with probability cutmix_prob; loss and accuracy are then
        # lambda-weighted mixtures over the two label sets.
        use_cutmix = np.random.rand() < cutmix_prob
        if use_cutmix:
            images, labels_a, labels_b, lam = cutmix_batch(
                images, labels, alpha=cutmix_alpha)
            logits = model(images)
            ce_loss = lam * criterion(logits, labels_a) + \
                      (1.0 - lam) * criterion(logits, labels_b)
            pred = logits.argmax(1)
            correct += (lam * (pred == labels_a).float() +
                        (1.0 - lam) * (pred == labels_b).float()).sum().item()
        else:
            logits = model(images)
            ce_loss = criterion(logits, labels)
            correct += (logits.argmax(1) == labels).sum().item()

        # Geometric regularization from state cached by the forward pass.
        geo_losses = model.geometric_losses()
        geo_loss = geo_losses.get('geo_total', torch.tensor(0.0, device=device))

        # InfoNCE on the cached CLS geometric residual.
        nce_losses = model.infonce_loss()
        nce_loss = nce_losses.get('nce', torch.tensor(0.0, device=device))

        loss = ce_loss + geo_loss + nce_weight * nce_loss

        optimizer.zero_grad()
        loss.backward()

        # Enqueue residuals only AFTER backward (per the model's API).
        model.update_nce_bank()

        # Log raw (pre-clip) gradient norms on the first batch of the epoch.
        if epoch % config['log_grads_every'] == 0 and batch_idx == 0:
            log_gradient_norms(model, writer, epoch)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            # Per-step (not per-epoch) LR schedule.
            scheduler.step()

        # Sample-weighted running sums for epoch averages.
        total_loss += ce_loss.item() * images.size(0)
        total_geo_loss += geo_loss.item() * images.size(0)
        total_nce_loss += nce_loss.item() * images.size(0)
        total += images.size(0)

    avg_ce = total_loss / total
    avg_geo = total_geo_loss / total
    avg_nce = total_nce_loss / total
    return avg_ce, avg_geo, avg_nce, correct / total
|
|
|
|
@torch.no_grad()
def evaluate(model, loader):
    """Return top-1 accuracy of `model` over `loader` (eval mode, no grads)."""
    model.eval()
    n_correct = 0
    n_seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        preds = model(batch_x).argmax(1)
        n_correct += (preds == batch_y).sum().item()
        n_seen += batch_x.size(0)
    return n_correct / n_seen
|
|
|
|
def main() -> None:
    """End-to-end CIFAR-100 training driver for the geometric ViT.

    Builds dataloaders and the model, sets up AdamW with a linear-warmup +
    cosine-decay LR schedule, then trains for ``config['epochs']`` epochs.
    Per epoch it logs losses/accuracy/LR to TensorBoard, checkpoints the
    best test accuracy, and periodically runs the geometric analysis pass.
    """
    # Copy so any mutation here never leaks back into the module-level CONFIG.
    config = CONFIG.copy()


    # --- Run banner: echo the architecture / augmentation / NCE settings. ---
    print("=" * 60)
    print(" Geometric Transformer β CIFAR-100 (CM-Validated)")
    print(f" Input: conv({config['in_channels']}β{config['conv_channels']}) + "
          f"SVD(rank={config['svd_rank']}) + "
          f"{config['patch_size']}Γ{config['patch_size']} patches = "
          f"{(config['img_size']//config['patch_size'])**2} tokens + CLS")
    print(f" Model: d={config['d_model']}, heads={config['n_heads']}, "
          f"layers={config['n_layers']}, anchors={config['n_anchors']}")
    print(f" CM: neighbors={config['cm_neighbors']}, "
          f"cv_target={config['cv_target']}, "
          f"cv_weight={config['cv_weight']}, "
          f"spread_weight={config['spread_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(Ξ±={config['cutmix_alpha']}, "
          f"p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" NCE: bank={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, "
          f"weight={config['nce_weight']}")
    print("=" * 60)


    # TensorBoard writer; the full config is stored as a text record so the
    # run is reproducible from the event file alone.
    writer = SummaryWriter(config['log_dir'])
    writer.add_text('config', json.dumps(config, indent=2))


    print("\nLoading CIFAR-100...")
    train_loader, test_loader = get_dataloaders(config)
    print(f" Train: {len(train_loader.dataset):,} | Test: {len(test_loader.dataset):,}")


    model = GeoViTClassifier('geo_vit_cifar100', config)
    # NOTE(review): `network_to` is presumably a project-specific device-move
    # helper (strict=False suggests partial moves are tolerated); falls back
    # to the standard nn.Module.to() when absent — confirm against the class.
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)


    # Parameter accounting (total vs. trainable) for the banner + TensorBoard.
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n Total params: {n_params:,}")
    print(f" Trainable params: {n_trainable:,}")


    # Per-top-level-submodule parameter breakdown (skips empty containers).
    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        if n > 0:
            print(f" {name:<20s}: {n:,}")


    writer.add_scalar('model/total_params', n_params, 0)


    # Snapshot anchor geometry before any training so drift can be compared
    # against the final diagnostics printed at the end of the run.
    
    print(f"\n Initial anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"min_dist={d['min_pairwise_dist']:.4f}")


    # Optimizer + LR schedule: linear warmup for `warmup_epochs`, then cosine
    # decay to zero over the remaining steps (stepped per-batch, not per-epoch).
    
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])


    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = config['warmup_epochs'] * len(train_loader)


    def lr_lambda(step: int) -> float:
        # Multiplicative LR factor: step/warmup during warmup, cosine after.
        # NOTE(review): if warmup_epochs == 0 the warmup branch never runs
        # (0 < 0 is False), so no div-by-zero; but total_steps == warmup_steps
        # would divide by zero below — assumes warmup_epochs < epochs.
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))


    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


    # --- Training-schedule banner. ---
    print(f"\n{'β'*60}")
    print(f" Training for {config['epochs']} epochs")
    print(f" Warmup: {config['warmup_epochs']} epochs, "
          f"LR: {config['lr']}, WD: {config['weight_decay']}")
    print(f" Geo reg: cv_w={config['cv_weight']}, spread_w={config['spread_weight']}")
    print(f" NCE bank: size={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, weight={config['nce_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" TensorBoard: {config['log_dir']}")
    print(f" Geo analysis every {config['log_geo_every']} epochs")
    print(f"{'β'*60}\n")


    best_acc = 0
    save_dir = Path('geo_cifar100'); save_dir.mkdir(exist_ok=True)


    # --- Main epoch loop. ---
    for epoch in tqdm(range(config['epochs']), desc="Epochs"):
        t0 = time.time()


        # train_epoch is expected to step optimizer+scheduler and return the
        # three loss components plus train accuracy (defined elsewhere).
        ce_loss, geo_loss, nce_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scheduler, epoch, config, writer)


        test_acc = evaluate(model, test_loader)
        elapsed = time.time() - t0


        # Per-epoch scalar logging; LR is read post-scheduler-step, so this
        # is the rate the *next* epoch will start from.
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/ce_loss', ce_loss, epoch)
        writer.add_scalar('train/geo_loss', geo_loss, epoch)
        writer.add_scalar('train/nce_loss', nce_loss, epoch)
        writer.add_scalar('train/total_loss', ce_loss + geo_loss + nce_loss, epoch)
        writer.add_scalar('train/accuracy', train_acc, epoch)
        writer.add_scalar('test/accuracy', test_acc, epoch)
        writer.add_scalar('train/lr', lr, epoch)
        writer.add_scalar('train/epoch_time', elapsed, epoch)
        writer.add_scalar('gap/train_test', train_acc - test_acc, epoch)


        log_weight_norms(model, writer, epoch)


        # Checkpoint on new best test accuracy; tensors are moved to CPU so
        # the checkpoint loads on any device.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                'epoch': epoch,
                'test_acc': test_acc,
                'config': config,
            }, save_dir / 'best.pt')


        # Periodic (and final-epoch) geometric analysis with a verbose
        # multi-line progress print; otherwise a compact line every 5 epochs.
        
        if epoch % config['log_geo_every'] == 0 or epoch == config['epochs'] - 1:
            geo_info = log_geometric_analysis(
                model, writer, epoch, test_loader, device, config)


            cv_str = ', '.join(f'{cv:.3f}' for cv in geo_info['cv_trajectory'])
            cm_str = ', '.join(f'{q:.3f}' for q in geo_info.get('cm_quality_trajectory', []))
            res_str = ', '.join(f'{r:.3f}' for r in geo_info.get('res_norms', []))
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s"
                f"\n CV=[{cv_str}]"
                f"\n CM=[{cm_str}]"
                f"\n GR=[{res_str}]")
        elif epoch % 5 == 0:
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s")


    # --- Final summary. ---
    
    print(f"\n{'β'*60}")
    print(f" CIFAR-100 RESULTS (CM-Validated)")
    print(f"{'β'*60}")
    print(f" Best test accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
    print(f" Parameters: {n_params:,}")
    print(f" Checkpoint: {save_dir}/best.pt")
    print(f" TensorBoard: {config['log_dir']}")


    # One last geometric-analysis pass, logged at step == epochs (one past
    # the final training epoch) so it is distinguishable in TensorBoard.
    
    print(f"\n Final geometric state:")
    geo_info = log_geometric_analysis(
        model, writer, config['epochs'], test_loader, device, config)


    # Final anchor diagnostics, mirroring the pre-training snapshot above
    # (note: prints cm_mean here instead of min_pairwise_dist).
    print(f"\n Final anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"cm_mean={d['cm_mean']:.4f}")


    writer.close()
    print(f"\nDone.")
|
|
|
|
# Script entry point: run the full training pipeline when executed directly.
if __name__ == '__main__':
    main()