# geometric-transformer-v3 / stage1_analysis_trainer.py
# Author: AbstractPhil — commit 394b68b ("Create stage1_analysis_trainer.py", verified)
"""
Geometric Transformer — CIFAR-100 Training with CM-Validated Analysis
Changes from previous version:
- CM gate diagnostics per layer: active anchors, gate_mean, cm_positive_frac
- CM quality in geometric residual analysis (replaces blind gate)
- Geometric regularization losses (CV target + anchor spread) in training loop
- Anchor diagnostics via model.anchor_diagnostics()
- CM quality trajectory alongside CV and bridge KL for cooperation analysis
TensorBoard logging of every geometric feature element:
- CV (coefficient of variation) per layer — the pentachoron band metric
- CM gate: active anchors, gate mean, cm_positive_frac, quality per position
- Stream agreement/divergence per layer
- Anchor utilization, entropy, spread
- Patchwork activation statistics (from CM-validated triangulation)
- Bridge vs assignment consistency
- Triangulation distance distributions
- SVD spectrum, entropy, novelty
- Quaternion arm norms and composition statistics
- Cayley rotation ‖R − I‖ per layer
- FiLM gamma/beta deviation from identity
- Gate activation statistics
- Gradient norms per component type (including cm_gate)
- Weight norms per component type
- Geometric regularization: CV loss, spread loss per epoch
!pip install geolip-core torchvision tqdm tensorboard
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time, json, math
from pathlib import Path
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
# Resolve the compute device once at import time; the training and evaluation
# helpers in this file read this module-level `device` directly.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")
if device.type == 'cuda':
    # Report the active GPU's name and the total memory of device 0 in GB.
    print(f" GPU: {torch.cuda.get_device_name()}")
    print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
# ═══════════════════════════════════════════════════════════════════════════════
# IMPORT TRANSFORMER
# ═══════════════════════════════════════════════════════════════════════════════
# Try geolip_core installed package first, fall back to local file
try:
from geolip_core.pipeline.components.geometric_transformer import (
GeometricTransformer, GeometricTransformerLayer,
CayleyOrthogonal, QuaternionCompose, FiLMLayer,
ContentAttention, GeometricAttention, CMValidatedGate,
TorchComponent, BaseTower,
anchor_neighborhood_cm,
)
print(" Imported from geolip_core (installed)")
except ImportError:
try:
from geometric_transformer import (
GeometricTransformer, GeometricTransformerLayer,
CayleyOrthogonal, QuaternionCompose, FiLMLayer,
ContentAttention, GeometricAttention, CMValidatedGate,
TorchComponent, BaseTower,
anchor_neighborhood_cm,
)
print(" Imported from local geometric_transformer.py")
except ImportError:
raise ImportError(
"Cannot find geometric_transformer. Place geometric_transformer.py "
"in the working directory or install geolip-core.")
# 'high' allows PyTorch to use faster reduced-precision internal math for
# float32 matmuls (e.g. TF32) where the hardware supports it.
torch.set_float32_matmul_precision('high')
# ═══════════════════════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════════════════════
# Hyper-parameters for the whole run. main() works on CONFIG.copy() and logs it
# to TensorBoard as JSON text, so every value must stay JSON-serialisable.
CONFIG = {
    # Model
    'd_model': 256,
    'n_heads': 8,
    'n_layers': 8,
    'n_anchors': 128,
    'manifold_dim': 128,
    'n_comp': 4,
    'd_comp': 16,
    'context_dim': 64,
    'quat_dim': 32,
    'dropout': 0.1,  # used by the transformer and by the classifier head
    'cm_neighbors': 3, # CM simplex neighbors
    # Input stage
    'patch_size': 4,  # with img_size=32: (32 // 4) ** 2 = 64 patch tokens + CLS
    'img_size': 32,
    'in_channels': 3,
    'conv_channels': 64,  # width of the two 3x3 conv layers before patching
    'svd_rank': 16,  # output channels of the SVDObserver projection
    # Training
    'epochs': 100,
    'batch_size': 1024,  # used by both loaders; the train loader drops the last partial batch
    'lr': 1e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 5,
    'label_smoothing': 0.1,  # passed to nn.CrossEntropyLoss in train_epoch
    'num_workers': 8,  # DataLoader worker processes for both loaders
    # Geometric regularization
    'cv_target': 0.215, # pentachoron band center
    'cv_weight': 0.1, # CV loss weight
    'spread_weight': 0.01, # anchor spread loss weight
    # Augmentation — tuned for CM gate training
    'cutmix_alpha': 1.0, # CutMix beta distribution α (1.0 = uniform box sizes)
    'cutmix_prob': 0.5, # probability of applying CutMix per batch
    'random_erasing_p': 0.25, # probability of erasing per image
    # InfoNCE memory bank on geometric residual
    'nce_bank_size': 4096, # queue size (0 to disable)
    'nce_temperature': 0.1, # InfoNCE temperature
    'nce_weight': 0.1, # loss weight
    # Data
    'num_classes': 100,
    # Logging
    'log_geo_every': 5, # full geometric analysis every N epochs
    'log_grads_every': 10, # gradient norms every N epochs (first batch only)
    'log_dir': 'runs/geo_cifar100',  # SummaryWriter output directory
}
# ═══════════════════════════════════════════════════════════════════════════════
# INPUT STAGE
# ═══════════════════════════════════════════════════════════════════════════════
try:
    from geolip_core.core.input.svd import SVDObserver
    _HAS_SVD = True
except ImportError:
    _HAS_SVD = False

    class SVDObserver(nn.Module):
        """Fallback SVDObserver, defined only when geolip_core is unavailable.

        A 1x1 conv maps the input feature map to ``svd_rank`` channels. The
        resulting (H*W, svd_rank) token matrix of each sample is summarised by
        its singular values S (descending) and right singular vectors Vh,
        obtained from an eigendecomposition of the svd_rank x svd_rank Gram
        matrix. That decomposition runs under ``torch.no_grad()``, so S, Vh
        and everything derived from them carry no gradient back into
        ``to_svd``.
        NOTE(review): confirm a frozen projection is intended for the fallback.
        """

        def __init__(self, in_channels, svd_rank=24):
            super().__init__()
            self.svd_rank = svd_rank
            self.to_svd = nn.Conv2d(in_channels, svd_rank, 1, bias=False)
            self.register_buffer('ema_s', torch.ones(svd_rank))
            self.register_buffer('ema_vh_flat', torch.eye(svd_rank).reshape(-1))
            self.ema_momentum = 0.99

        def extract_features(self, S, Vh):
            """Build a (B, 2*k + 2) feature vector per sample.

            Layout: normalised spectrum (k), diagonal of Vh (k), squared
            off-diagonal energy of Vh (1), spectral entropy (1). Non-finite
            entries are replaced with zero.
            """
            _batch, _rank = S.shape  # S must be 2-D (B, k)
            s_safe = S.clamp(min=1e-6)
            spectrum = s_safe / (s_safe.sum(dim=-1, keepdim=True) + 1e-8)
            diag = Vh.diagonal(dim1=-2, dim2=-1)
            total_energy = Vh.pow(2).sum((-2, -1))
            diag_energy = diag.pow(2).sum(-1)
            off_diag = (total_energy - diag_energy).unsqueeze(-1).clamp(min=0)
            log_spectrum = torch.log(spectrum.clamp(min=1e-8))
            entropy = -(spectrum * log_spectrum).sum(-1, keepdim=True)
            feats = torch.cat([spectrum, diag, off_diag, entropy], dim=-1)
            return torch.where(torch.isfinite(feats), feats, torch.zeros_like(feats))

        def compute_novelty(self, S):
            """Deviation of S from its running EMA (broadcast over the batch)."""
            baseline = self.ema_s.clone()[None, :]
            return S - baseline

        def forward(self, x):
            """Return (S, Vh, features, novelty) for a (B, C, H, W) input."""
            batch, _channels, height, width = x.shape
            tokens = self.to_svd(x).permute(0, 2, 3, 1).reshape(
                batch, height * width, self.svd_rank)
            # Full-precision, gradient-free decomposition: A^T A = V diag(S^2) V^T.
            with torch.amp.autocast('cuda', enabled=False), torch.no_grad():
                t32 = tokens.float()
                gram = torch.bmm(t32.transpose(1, 2), t32)
                eigvals, eigvecs = torch.linalg.eigh(gram)
                # eigh returns ascending eigenvalues; flip to descending order.
                S = eigvals.flip(-1).clamp(min=1e-12).sqrt()
                Vh = eigvecs.flip(-1).transpose(-2, -1)
            S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
            Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
            features = self.extract_features(S, Vh)
            novelty = self.compute_novelty(S)
            return S, Vh, features, novelty

        @torch.no_grad()
        def update_ema(self, S, Vh):
            """Blend the batch-mean S and Vh into the EMA buffers."""
            keep = self.ema_momentum
            self.ema_s.mul_(keep).add_(S.detach().mean(0), alpha=1 - keep)
            self.ema_vh_flat.mul_(keep).add_(
                Vh.detach().mean(0).reshape(-1), alpha=1 - keep)

        @property
        def feature_dim(self):
            """Width of the vector returned by ``extract_features``."""
            return 2 * self.svd_rank + 2
class ConvSVDPatchEmbedding(TorchComponent):
    """Input stage: conv frontend -> SVDObserver -> FiLM-modulated patch tokens.

    Two 3x3 conv/BN/GELU layers produce a feature map. It is (a) summarised
    by an SVDObserver and (b) cut into non-overlapping ``patch_size`` patches
    and projected to ``d_model``. The SVD features drive a FiLM modulation
    (gamma * tokens + beta) of the patch tokens, which is the identity at
    init. A CLS token is prepended and a learned positional embedding added.
    """

    def __init__(self, name, img_size=32, patch_size=4, in_channels=3,
                 conv_channels=64, d_model=256, svd_rank=16):
        super().__init__(name)
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.d_model = d_model
        self.svd_rank = svd_rank

        def conv_bn_gelu(c_in):
            # 3x3 same-padding conv -> BatchNorm -> GELU, output conv_channels.
            return [
                nn.Conv2d(c_in, conv_channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(conv_channels),
                nn.GELU(),
            ]

        self.conv_frontend = nn.Sequential(
            *conv_bn_gelu(in_channels), *conv_bn_gelu(conv_channels))
        self.svd_observer = SVDObserver(conv_channels, svd_rank)
        self.patch_proj = nn.Conv2d(
            conv_channels, d_model, kernel_size=patch_size,
            stride=patch_size, bias=False)
        self.patch_norm = nn.LayerNorm(d_model)

        film_in = self.svd_observer.feature_dim
        self.svd_to_gamma = nn.Linear(film_in, d_model)
        self.svd_to_beta = nn.Linear(film_in, d_model)
        # Identity FiLM at init: gamma == 1 and beta == 0 for any SVD features.
        for film_layer, bias_value in ((self.svd_to_gamma, 1.0),
                                       (self.svd_to_beta, 0.0)):
            nn.init.zeros_(film_layer.weight)
            nn.init.constant_(film_layer.bias, bias_value)

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, d_model) * 0.02)

    def forward(self, x):
        """Return (tokens, svd_state); tokens is (B, n_patches + 1, d_model)."""
        batch = x.shape[0]
        feat = self.conv_frontend(x)
        S, Vh, svd_features, novelty = self.svd_observer(feat)

        patches = self.patch_norm(self.patch_proj(feat).flatten(2).transpose(1, 2))
        gamma = self.svd_to_gamma(svd_features)[:, None]
        beta = self.svd_to_beta(svd_features)[:, None]
        patches = gamma * patches + beta

        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, patches], dim=1) + self.pos_embed

        svd_state = dict(
            singular_values=S, Vh=Vh,
            svd_features=svd_features, novelty=novelty,
        )
        # The observer's running statistics only move while training.
        if self.training:
            self.svd_observer.update_ema(S, Vh)
        return tokens, svd_state
# ═══════════════════════════════════════════════════════════════════════════════
# CLASSIFIER (uses GeometricTransformer with CM gates)
# ═══════════════════════════════════════════════════════════════════════════════
class GeoViTClassifier(BaseTower):
    """Geometric Vision Transformer for image classification.

    Components, attached in this order:
      * ``patch_embed`` — ConvSVDPatchEmbedding producing CLS + patch tokens.
      * ``transformer`` — GeometricTransformer built from ``config``.
      * ``head`` — LayerNorm -> Linear -> GELU -> Dropout -> Linear applied to
        the CLS output.

    Also forwards the transformer's geometric regularisation, InfoNCE and
    anchor-diagnostic hooks so the training loop only needs the classifier.
    """

    def __init__(self, name, config):
        super().__init__(name)
        self.config = config
        d_model = config['d_model']

        embed = ConvSVDPatchEmbedding(
            'patch_embed',
            img_size=config['img_size'],
            patch_size=config['patch_size'],
            in_channels=config['in_channels'],
            conv_channels=config['conv_channels'],
            d_model=d_model,
            svd_rank=config['svd_rank'],
        )
        self.attach('patch_embed', embed)

        transformer = GeometricTransformer(
            'geo_cifar',
            d_model=d_model,
            n_heads=config['n_heads'],
            n_layers=config['n_layers'],
            n_anchors=config['n_anchors'],
            manifold_dim=config['manifold_dim'],
            n_comp=config['n_comp'],
            d_comp=config['d_comp'],
            context_dim=config['context_dim'],
            quat_dim=config['quat_dim'],
            dropout=config['dropout'],
            cm_neighbors=config.get('cm_neighbors', 3),
            nce_bank_size=config.get('nce_bank_size', 4096),
            nce_temperature=config.get('nce_temperature', 0.1),
        )
        self.attach('transformer', transformer)

        head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(config['dropout']),
            nn.Linear(d_model, config['num_classes']),
        )
        self.attach('head', head)

    def forward(self, x, return_geo_state=False):
        """Return logits, or (logits, geo_states, svd_state) when requested."""
        tokens, svd_state = self['patch_embed'](x)
        transformer = self['transformer']
        if not return_geo_state:
            return self['head'](transformer(tokens)[:, 0])
        features, geo_states = transformer(tokens, return_geo_state=True)
        logits = self['head'](features[:, 0])
        return logits, geo_states, svd_state

    def geometric_losses(self):
        """Transformer's geometric regularisation, configured from ``config``."""
        defaults = {'cv_target': 0.215, 'cv_weight': 0.1, 'spread_weight': 0.01}
        return self['transformer'].geometric_losses(
            **{key: self.config.get(key, value) for key, value in defaults.items()})

    def infonce_loss(self):
        """InfoNCE loss on the geometric residual cached by the last forward."""
        return self['transformer'].infonce_loss()

    def update_nce_bank(self):
        """Enqueue the current batch's residuals. Call AFTER backward."""
        self['transformer'].update_nce_bank()

    def anchor_diagnostics(self):
        """Per-layer anchor diagnostics reported by the transformer."""
        return self['transformer'].anchor_diagnostics()
# ═══════════════════════════════════════════════════════════════════════════════
# GEOMETRIC ANALYSIS BATTERY (includes CM diagnostics)
# ═══════════════════════════════════════════════════════════════════════════════
@torch.no_grad()
def compute_cv(points):
    """Coefficient of variation of pairwise cosine distances on S^(d-1).

    Rows of ``points`` are L2-normalised (in float32) and, over every
    unordered pair i < j, d_ij = 1 - cos(p_i, p_j) is collected. Then
    CV = std(d) / mean(d), using the unbiased std. The pentachoron band is
    CV in [0.20, 0.23].

    Returns:
        (cv, mean_distance, std_distance) as Python floats.
    """
    unit = F.normalize(points.float(), dim=-1)
    n = unit.shape[0]
    rows, cols = torch.triu_indices(n, n, offset=1, device=unit.device)
    distances = 1.0 - (unit @ unit.T)[rows, cols]
    mean_d, std_d = distances.mean(), distances.std()
    return (std_d / (mean_d + 1e-8)).item(), mean_d.item(), std_d.item()
def _bridge_assignment_kl(bridge, assignment):
    """Return KL(assignment || softmax(bridge)) as a Python float.

    Both tensors are flattened to (-1, n_anchors), with n_anchors taken from
    ``assignment``'s last dim. ``reduction='batchmean'`` averages over those
    rows. ``log_softmax`` is used instead of ``softmax(...).log()``: the
    latter underflows to -inf for strongly negative logits and turns the KL
    into inf/nan.
    """
    n_anchors = assignment.shape[-1]
    log_q = F.log_softmax(bridge, dim=-1).reshape(-1, n_anchors)
    target = assignment.reshape(-1, n_anchors)
    return F.kl_div(log_q, target, reduction='batchmean',
                    log_target=False).item()


def _paired_corr(xs, ys, min_pairs=4):
    """Pearson correlation of two per-layer series, aligned layer-by-layer.

    Entries that are None (the layer did not report that metric) are dropped
    pairwise, so the series never get misaligned or mismatched in length.
    Returns None when fewer than ``min_pairs`` aligned layers remain. A
    constant series yields nan (as np.corrcoef does) without a warning.
    """
    pairs = [(x, y) for x, y in zip(xs, ys) if x is not None and y is not None]
    if len(pairs) < min_pairs:
        return None
    a, b = zip(*pairs)
    with np.errstate(invalid='ignore', divide='ignore'):
        return float(np.corrcoef(a, b)[0, 1])


def _log_svd_stage(model, writer, epoch, svd_state):
    """Log input-stage SVD spectrum statistics and its FiLM head drift."""
    S = svd_state['singular_values']
    s_norm = S / (S.sum(dim=-1, keepdim=True) + 1e-8)
    s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1)
    novelty = svd_state['novelty']
    writer.add_scalar('svd/entropy_mean', s_ent.mean().item(), epoch)
    writer.add_scalar('svd/entropy_std', s_ent.std().item(), epoch)
    writer.add_scalar('svd/novelty_norm', novelty.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar('svd/top1_ratio', (S[:, 0] / (S.sum(-1) + 1e-8)).mean().item(), epoch)
    writer.add_scalar('svd/condition_number',
                      (S[:, 0] / (S[:, -1].clamp(min=1e-8))).mean().item(), epoch)
    for k in range(min(S.shape[1], 5)):
        writer.add_scalar(f'svd/S_{k}', S[:, k].mean().item(), epoch)

    # SVD FiLM deviation from its identity initialisation (gamma=1, beta=0).
    pe = model['patch_embed']
    writer.add_scalar('svd_film/gamma_weight_norm', pe.svd_to_gamma.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/gamma_bias_dev_from_1',
                      (pe.svd_to_gamma.bias.data - 1.0).abs().mean().item(), epoch)
    writer.add_scalar('svd_film/beta_weight_norm', pe.svd_to_beta.weight.data.norm().item(), epoch)
    writer.add_scalar('svd_film/beta_bias_norm', pe.svd_to_beta.bias.data.abs().mean().item(), epoch)


def _log_layer(model, writer, epoch, i, gs, device):
    """Log every per-layer metric for layer ``i`` and return trajectory inputs.

    Returns:
        dict with 'cv' (embedding CV), 'cm_quality' (mean CM quality or None),
        'res_norm' (mean geometric-residual norm or None) and 'bridge_kl'.
    """
    prefix = f'layer_{i}'

    # === CV — pentachoron band metric, on the anchors and on the embeddings ===
    layer = model['transformer'][f'layer_{i}']
    anchors = F.normalize(
        layer['observer'].association.constellation.anchors, dim=-1)
    cv_anchors, mean_d_anchors, std_d_anchors = compute_cv(anchors)
    writer.add_scalar(f'{prefix}/cv_anchors', cv_anchors, epoch)
    writer.add_scalar(f'{prefix}/anchor_mean_dist', mean_d_anchors, epoch)
    writer.add_scalar(f'{prefix}/anchor_std_dist', std_d_anchors, epoch)

    emb = gs['embedding']
    emb_flat = emb.reshape(-1, emb.shape[-1])
    # Random subsample of at most 512 rows keeps the O(n^2) pairwise CV cheap.
    idx = torch.randperm(emb_flat.shape[0], device=device)[:512]
    cv_emb, mean_d_emb, _ = compute_cv(emb_flat[idx])
    writer.add_scalar(f'{prefix}/cv_embeddings', cv_emb, epoch)
    writer.add_scalar(f'{prefix}/embedding_mean_dist', mean_d_emb, epoch)

    # === CM Gate Diagnostics ===
    gate_info = gs.get('gate_info', {})
    gate_values = gs.get('gate_values')
    cm_quality = gs.get('cm_quality')
    if gate_info:
        writer.add_scalar(f'{prefix}/cm_active_anchors', gate_info.get('active', 0), epoch)
        writer.add_scalar(f'{prefix}/cm_gate_mean', gate_info.get('gate_mean', 0), epoch)
        writer.add_scalar(f'{prefix}/cm_positive_frac',
                          gate_info.get('cm_positive_frac', 0), epoch)
    if gate_values is not None:
        gv = gate_values
        writer.add_scalar(f'{prefix}/gate_values_min', gv.min().item(), epoch)
        writer.add_scalar(f'{prefix}/gate_values_max', gv.max().item(), epoch)
        writer.add_scalar(f'{prefix}/gate_values_std', gv.std().item(), epoch)
        # Per-anchor gate mean (which anchors are consistently open/closed):
        # averages over the first two dims (assumed B and L — TODO confirm).
        gv_per_anchor = gv.mean(dim=0).mean(dim=0)
        writer.add_scalar(f'{prefix}/gate_anchor_spread', gv_per_anchor.std().item(), epoch)
        # Fraction of positions with >50% anchors open.
        if gv.dim() == 3:
            pos_open_frac = (gv.mean(dim=-1) > 0.5).float().mean().item()
        else:
            pos_open_frac = (gv > 0.5).float().mean().item()
        writer.add_scalar(f'{prefix}/gate_positions_open_frac', pos_open_frac, epoch)
    cm_quality_mean = None
    if cm_quality is not None:
        cm_quality_mean = cm_quality.mean().item()
        writer.add_scalar(f'{prefix}/cm_quality_mean', cm_quality_mean, epoch)
        writer.add_scalar(f'{prefix}/cm_quality_std', cm_quality.std().item(), epoch)
        writer.add_scalar(f'{prefix}/cm_quality_min', cm_quality.min().item(), epoch)

    # === Stream agreement / disagreement ===
    content = gs['content']
    geometric = gs['geometric']
    agreement = F.cosine_similarity(
        content.reshape(-1, content.shape[-1]),
        geometric.reshape(-1, geometric.shape[-1]), dim=-1)
    writer.add_scalar(f'{prefix}/stream_agreement_mean', agreement.mean().item(), epoch)
    writer.add_scalar(f'{prefix}/stream_agreement_std', agreement.std().item(), epoch)
    writer.add_scalar(f'{prefix}/content_norm', content.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar(f'{prefix}/geometric_norm', geometric.norm(dim=-1).mean().item(), epoch)
    disagree = content - geometric
    agree = content * geometric
    writer.add_scalar(f'{prefix}/disagree_norm', disagree.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar(f'{prefix}/agree_norm', agree.norm(dim=-1).mean().item(), epoch)

    # === Anchor Utilization (hard nearest-anchor histogram) ===
    tri = gs['triangulation']
    assignment = gs['assignment']
    nearest = gs['nearest']
    n_anchors = tri.shape[-1]
    counts = torch.bincount(nearest.reshape(-1), minlength=n_anchors).float()
    probs = counts / (counts.sum() + 1e-8)
    anchor_entropy = -(probs * torch.log(probs.clamp(min=1e-8))).sum().item()
    max_entropy = math.log(n_anchors)
    writer.add_scalar(f'{prefix}/anchor_entropy_normalized',
                      anchor_entropy / (max_entropy + 1e-8), epoch)
    active = (counts > 0).sum().item()
    writer.add_scalar(f'{prefix}/anchors_active', active, epoch)
    writer.add_scalar(f'{prefix}/anchors_active_frac', active / n_anchors, epoch)
    writer.add_scalar(f'{prefix}/anchors_dead', (counts == 0).sum().item(), epoch)

    # === Triangulation / soft assignment / patchwork statistics ===
    writer.add_scalar(f'{prefix}/tri_mean', tri.mean().item(), epoch)
    writer.add_scalar(f'{prefix}/tri_std', tri.std().item(), epoch)
    assign_ent = -(assignment * torch.log(assignment.clamp(min=1e-8))).sum(-1)
    writer.add_scalar(f'{prefix}/assignment_entropy_mean', assign_ent.mean().item(), epoch)
    writer.add_scalar(f'{prefix}/assignment_max_prob',
                      assignment.max(dim=-1).values.mean().item(), epoch)
    pw = gs['patchwork']
    writer.add_scalar(f'{prefix}/patchwork_norm', pw.norm(dim=-1).mean().item(), epoch)
    writer.add_scalar(f'{prefix}/patchwork_std', pw.std().item(), epoch)
    writer.add_scalar(f'{prefix}/patchwork_sparsity',
                      (pw.abs() < 0.01).float().mean().item(), epoch)

    # === Bridge consistency (computed once, reused by the trajectories) ===
    bridge_kl = _bridge_assignment_kl(gs['bridge'], assignment)
    writer.add_scalar(f'{prefix}/bridge_assignment_kl', bridge_kl, epoch)

    # === Quaternion composition / geo context ===
    writer.add_scalar(f'{prefix}/composed_norm', gs['composed'].norm(dim=-1).mean().item(), epoch)
    writer.add_scalar(f'{prefix}/geo_ctx_norm', gs['geo_ctx'].norm(dim=-1).mean().item(), epoch)

    # === Geometric residual stream (CM-conditioned) ===
    res_norm = None
    geo_res = gs.get('geo_residual')
    if geo_res is not None:
        res_norm = geo_res.norm(dim=-1).mean().item()
        writer.add_scalar(f'{prefix}/geo_res_norm', res_norm, epoch)
        writer.add_scalar(f'{prefix}/geo_res_std', geo_res.std().item(), epoch)
        writer.add_scalar(f'{prefix}/geo_res_sparsity',
                          (geo_res.abs() < 0.01).float().mean().item(), epoch)
        # Cross-position consistency: mean pairwise cosine of <=256 residuals.
        geo_res_flat = geo_res.reshape(-1, geo_res.shape[-1])
        n_s = min(256, geo_res_flat.shape[0])
        idx_s = torch.randperm(geo_res_flat.shape[0], device=geo_res.device)[:n_s]
        sampled = F.normalize(geo_res_flat[idx_s], dim=-1)
        cos_mat = sampled @ sampled.T
        triu = torch.triu_indices(n_s, n_s, offset=1, device=geo_res.device)
        writer.add_scalar(f'{prefix}/geo_res_consistency',
                          cos_mat[triu[0], triu[1]].mean().item(), epoch)

    return {'cv': cv_emb, 'cm_quality': cm_quality_mean,
            'res_norm': res_norm, 'bridge_kl': bridge_kl}


def _log_rotation_and_film(model, writer, epoch):
    """Log ||R - I|| for every CayleyOrthogonal and FiLM bias drift."""
    for name, mod in model.named_modules():
        if isinstance(mod, CayleyOrthogonal):
            R = mod.get_rotation()
            eye = torch.eye(R.shape[0], device=R.device)
            clean_name = name.replace('.', '_')
            writer.add_scalar(f'cayley/{clean_name}_R_minus_I', (R - eye).norm().item(), epoch)
    films = [mod for _, mod in model.named_modules() if isinstance(mod, FiLMLayer)]
    for film_idx, mod in enumerate(films):
        writer.add_scalar(f'film/{film_idx}_gamma_dev',
                          (mod.to_gamma.bias.data - 1.0).abs().mean().item(), epoch)
        writer.add_scalar(f'film/{film_idx}_beta_dev',
                          mod.to_beta.bias.data.abs().mean().item(), epoch)


def _log_trajectories(writer, epoch, per_layer):
    """Log cross-layer trajectories and cooperation correlations.

    Returns the trajectory lists exposed by ``log_geometric_analysis``.
    """
    cv_trajectory = [m['cv'] for m in per_layer]
    layer_cm = [m['cm_quality'] for m in per_layer]
    layer_res = [m['res_norm'] for m in per_layer]
    bridge_kls = [m['bridge_kl'] for m in per_layer]
    cm_quality_trajectory = [q for q in layer_cm if q is not None]
    res_norms = [r for r in layer_res if r is not None]

    if cv_trajectory:
        writer.add_scalar('cv/trajectory_mean', np.mean(cv_trajectory), epoch)
        writer.add_scalar('cv/trajectory_std', np.std(cv_trajectory), epoch)
        in_band = sum(1 for cv in cv_trajectory if 0.20 <= cv <= 0.23)
        writer.add_scalar('cv/layers_in_pentachoron_band', in_band, epoch)
        writer.add_scalar('cv/layers_in_band_frac', in_band / len(cv_trajectory), epoch)

    if cm_quality_trajectory:
        writer.add_scalar('cm/quality_trajectory_mean', np.mean(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_trajectory_std', np.std(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_min_layer', np.min(cm_quality_trajectory), epoch)
        writer.add_scalar('cm/quality_max_layer', np.max(cm_quality_trajectory), epoch)

    if res_norms:
        writer.add_scalar('geo_res/trajectory_start', res_norms[0], epoch)
        writer.add_scalar('geo_res/trajectory_end', res_norms[-1], epoch)
        writer.add_scalar('geo_res/accumulation_ratio',
                          res_norms[-1] / (res_norms[0] + 1e-8), epoch)
        growth = np.diff(res_norms)
        if growth.size:  # needs at least two layers with a residual
            writer.add_scalar('geo_res/growth_mean', growth.mean(), epoch)
            writer.add_scalar('geo_res/growth_std', growth.std(), epoch)

    # Cooperation analysis: correlations over layers reporting both metrics.
    for tag, xs, ys in (
        ('cooperation/geo_res_vs_cv', layer_res, cv_trajectory),
        ('cooperation/geo_res_vs_bridge_kl', layer_res, bridge_kls),
        ('cooperation/geo_res_vs_cm_quality', layer_res, layer_cm),
        # Do layers with better CM quality also have better CV?
        ('cooperation/cm_quality_vs_cv', layer_cm, cv_trajectory),
    ):
        corr = _paired_corr(xs, ys)
        if corr is not None:
            writer.add_scalar(tag, corr, epoch)

    return {
        'cv_trajectory': cv_trajectory,
        'cm_quality_trajectory': cm_quality_trajectory,
        'res_norms': res_norms,
        'bridge_kls': bridge_kls,
    }


@torch.no_grad()
def log_geometric_analysis(model, writer, epoch, test_loader, device, config):
    """Full geometric analysis battery with CM diagnostics.

    Puts ``model`` in eval mode (and leaves it there). It then runs the model
    with ``return_geo_state=True`` on up to 64 images from the first batch of
    a fresh ``test_loader`` iterator and writes SVD-input, anchor, per-layer,
    Cayley/FiLM and cross-layer metrics to ``writer`` at step ``epoch``.
    Each layer's embedding CV and bridge KL are computed once, so the
    per-layer scalars and the trajectory analysis use the same values.

    Args:
        model: GeoViTClassifier-like tower.
        writer: SummaryWriter-like object with ``add_scalar``.
        epoch: global step for every scalar.
        test_loader: loader yielding (images, labels) batches.
        device: device the analysis batch is moved to.
        config: accepted for interface compatibility; not read here.

    Returns:
        dict with 'batch_acc', 'cv_trajectory', 'cm_quality_trajectory',
        'res_norms' and 'bridge_kls'.
    """
    model.eval()
    # NOTE(review): iter() on a multi-worker loader spawns workers and loads a
    # full batch just to take 64 images.
    images, labels = next(iter(test_loader))
    images = images[:64].to(device)
    labels = labels[:64].to(device)
    logits, geo_states, svd_state = model(images, return_geo_state=True)

    batch_acc = (logits.argmax(1) == labels).float().mean().item()
    writer.add_scalar('analysis/batch_accuracy', batch_acc, epoch)

    _log_svd_stage(model, writer, epoch, svd_state)

    # Anchor diagnostics (built into the transformer).
    for layer_name, diag in model.anchor_diagnostics().items():
        for key, value in diag.items():
            writer.add_scalar(f'anchor_diag/{layer_name}_{key}', value, epoch)

    per_layer = [
        _log_layer(model, writer, epoch, i, gs, device)
        for i, gs in enumerate(geo_states)
    ]
    _log_rotation_and_film(model, writer, epoch)
    trajectories = _log_trajectories(writer, epoch, per_layer)
    return {'batch_acc': batch_acc, **trajectories}
def _grad_component_key(name):
    """Map a parameter name to a coarse component bucket (first match wins).

    The rule order is significant. Note that ``'observer'`` also matches the
    input stage's ``svd_observer`` parameters, which therefore land in
    'constellation'. NOTE(review): confirm that bucketing is intended.
    """
    if 'projection' in name:  # also covers the old redundant 'proj' test
        return 'manifold_proj'
    if 'cm_gate' in name:
        return 'cm_gate'
    if 'observer' in name or 'constellation' in name or 'anchor' in name:
        return 'constellation'
    if 'context' in name:
        return 'geo_context'
    if 'content' in name:
        return 'content_attn'
    if 'geometric' in name and 'film' not in name:
        return 'geo_attn'
    if 'film' in name:
        return 'film'
    if 'rotation' in name or 'cayley' in name or 'A_upper' in name:
        return 'cayley'
    if 'compose' in name or 'quat' in name or 'proj_w' in name:
        return 'quaternion'
    if 'decode' in name:
        return 'decode'
    if 'gate' in name:
        return 'gate'
    if 'conv' in name or 'patch' in name:
        return 'input_stage'
    if 'head' in name:
        return 'head'
    if 'svd' in name:
        return 'svd'
    if 'geo_proj' in name:
        return 'geo_residual_proj'
    return 'other'


@torch.no_grad()
def log_gradient_norms(model, writer, epoch):
    """Log gradient norms per component type (includes cm_gate).

    Every parameter with a gradient is bucketed by ``_grad_component_key``.
    For each bucket, ``grad_norm/<key>_mean`` and ``grad_norm/<key>_max`` of
    the per-parameter L2 norms are written. ``grad_norm/total`` is the global
    L2 norm over all gradients (0.0 when none exist). Each parameter's norm
    is computed exactly once and reused for both the buckets and the total.
    """
    named_grads = [(name, param.grad) for name, param in model.named_parameters()
                   if param.grad is not None]
    norms = [grad.norm().item() for _, grad in named_grads]

    buckets = {}
    for (name, _), norm in zip(named_grads, norms):
        buckets.setdefault(_grad_component_key(name), []).append(norm)
    for key, values in buckets.items():
        writer.add_scalar(f'grad_norm/{key}_mean', np.mean(values), epoch)
        writer.add_scalar(f'grad_norm/{key}_max', np.max(values), epoch)

    total = sum(value ** 2 for value in norms) ** 0.5
    writer.add_scalar('grad_norm/total', total, epoch)
@torch.no_grad()
def log_weight_norms(model, writer, epoch):
    """Log the L2 (Frobenius) norm of every parameter named ``*A_upper*``.

    Despite the broader intent, only parameters whose qualified name contains
    ``'A_upper'`` are logged. Each goes under
    ``weights/<name with '.' -> '_'>_norm``.
    """
    selected = ((pname, p) for pname, p in model.named_parameters() if 'A_upper' in pname)
    for pname, param in selected:
        tag = 'weights/' + pname.replace('.', '_') + '_norm'
        writer.add_scalar(tag, param.norm().item(), epoch)
# ═══════════════════════════════════════════════════════════════════════════════
# DATA
# ═══════════════════════════════════════════════════════════════════════════════
def get_dataloaders(config):
    """Build CIFAR-100 (train_loader, test_loader); downloads to ./data if needed.

    Train augmentation is tuned for the geometric transformer:
      * TrivialAugmentWide gives a continuous severity spectrum of geometric
        and photometric transforms. Mild distortion keeps CM quality high,
        while severe distortion creates partially-degenerate simplices.
      * RandomErasing (after normalisation) creates degenerate manifold
        projections (zero-volume CM simplices), which trains the CM gate to
        close on corrupted regions.
      * CutMix is applied per batch in train_epoch, not here.
    The train loader shuffles and drops the last partial batch. The test
    loader keeps order and every sample.
    """
    import torchvision
    import torchvision.transforms as T

    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2470, 0.2435, 0.2616)

    train_tf = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.TrivialAugmentWide(),
        T.ToTensor(),
        T.Normalize(cifar_mean, cifar_std),
        T.RandomErasing(p=config.get('random_erasing_p', 0.25),
                        scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ])
    test_tf = T.Compose([T.ToTensor(), T.Normalize(cifar_mean, cifar_std)])

    def make_split(is_train, transform):
        return torchvision.datasets.CIFAR100(
            root='./data', train=is_train, download=True, transform=transform)

    train_ds = make_split(True, train_tf)
    test_ds = make_split(False, test_tf)

    loader_kwargs = dict(batch_size=config['batch_size'],
                         num_workers=config['num_workers'], pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=True, drop_last=True, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader
# ═══════════════════════════════════════════════════════════════════════════════
# CUTMIX — batch-level augmentation for CM gate boundary training
# ═══════════════════════════════════════════════════════════════════════════════
def cutmix_batch(images, labels, alpha=1.0):
    """Apply CutMix to a batch and return (images, labels_a, labels_b, lam).

    A rectangle of each image is overwritten with the same rectangle from a
    randomly permuted partner in the batch. Positions inside each region keep
    coherent geometry (valid CM simplices). The boundary mixes geometric
    context, which the CM gate should learn to suppress.

    Args:
        images: (B, C, H, W) batch; modified in place and also returned.
        labels: (B,) integer labels.
        alpha: Beta(alpha, alpha) parameter for the target mix ratio
            (1.0 = uniform box sizes).

    Returns:
        images: the same tensor object, now mixed.
        labels_a: ``labels`` (the original images' labels).
        labels_b: labels of the pasted-in partner images.
        lam: fraction of each image still from image A, recomputed from the
            box actually pasted after clipping to the image bounds.
    """
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(images.size(0), device=images.device)
    height, width = images.shape[2], images.shape[3]

    side_scale = (1.0 - lam) ** 0.5
    box_w, box_h = int(width * side_scale), int(height * side_scale)
    center_x = np.random.randint(width)
    center_y = np.random.randint(height)
    left, right = max(center_x - box_w // 2, 0), min(center_x + box_w // 2, width)
    top, bottom = max(center_y - box_h // 2, 0), min(center_y + box_h // 2, height)

    # RHS uses advanced indexing, so it is a copy and safe to assign in place.
    images[:, :, top:bottom, left:right] = images[perm, :, top:bottom, left:right]
    pasted_area = (right - left) * (bottom - top)
    return images, labels, labels[perm], 1.0 - pasted_area / (width * height)
# ═══════════════════════════════════════════════════════════════════════════════
# TRAINING (geometric losses + CutMix integrated)
# ═══════════════════════════════════════════════════════════════════════════════
def train_epoch(model, loader, optimizer, scheduler, epoch, config, writer):
    """Run one training epoch and return per-sample averages.

    Each batch:
      * Optional CutMix, with probability ``cutmix_prob``.
      * Loss = label-smoothed cross-entropy + the model's geometric
        regularisation (``geo_total``) + ``nce_weight`` x the InfoNCE loss.
      * The NCE bank is updated after ``backward()``.
      * Gradients are clipped to a global norm of 1.0 before
        ``optimizer.step()``, and ``scheduler`` (if given) steps once per
        batch.
    On the first batch of epochs divisible by ``config['log_grads_every']``,
    the pre-clipping gradient norms are written to ``writer``.

    Returns:
        (mean CE loss, mean geometric loss, mean NCE loss, accuracy). Under
        CutMix a prediction contributes lam for label A and (1 - lam) for B.
    """
    model.train()
    mix_alpha = config.get('cutmix_alpha', 1.0)
    mix_prob = config.get('cutmix_prob', 0.5)
    smoothing = config.get('label_smoothing', 0.1)
    nce_w = config.get('nce_weight', 0.1)
    criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    ce_sum = geo_sum = nce_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)

        if np.random.rand() < mix_prob:
            images, labels_a, labels_b, lam = cutmix_batch(
                images, labels, alpha=mix_alpha)
            logits = model(images)
            ce_loss = (lam * criterion(logits, labels_a)
                       + (1.0 - lam) * criterion(logits, labels_b))
            pred = logits.argmax(1)
            hit_a = (pred == labels_a).float()
            hit_b = (pred == labels_b).float()
            n_correct += (lam * hit_a + (1.0 - lam) * hit_b).sum().item()
        else:
            logits = model(images)
            ce_loss = criterion(logits, labels)
            n_correct += (logits.argmax(1) == labels).sum().item()

        # Geometric regularisation (CV target + anchor spread), then InfoNCE
        # on the geometric residual for discriminative pressure.
        geo_loss = model.geometric_losses().get(
            'geo_total', torch.tensor(0.0, device=device))
        nce_loss = model.infonce_loss().get(
            'nce', torch.tensor(0.0, device=device))
        loss = ce_loss + geo_loss + nce_w * nce_loss

        optimizer.zero_grad()
        loss.backward()
        # Enqueue AFTER backward — the bank receives detached residuals.
        model.update_nce_bank()

        if epoch % config['log_grads_every'] == 0 and step == 0:
            log_gradient_norms(model, writer, epoch)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        batch_n = images.size(0)
        ce_sum += ce_loss.item() * batch_n
        geo_sum += geo_loss.item() * batch_n
        nce_sum += nce_loss.item() * batch_n
        n_seen += batch_n

    return ce_sum / n_seen, geo_sum / n_seen, nce_sum / n_seen, n_correct / n_seen
@torch.no_grad()
def evaluate(model, loader):
    """Return top-1 accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode and leaves it there. Batches are moved
    to the module-level ``device``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        predictions = model(batch_images).argmax(1)
        n_correct += (predictions == batch_labels).sum().item()
        n_seen += batch_images.size(0)
    return n_correct / n_seen
def main():
    """Train the geometric ViT on CIFAR-100 and log geometric diagnostics.

    Works on a shallow copy of the module-level ``CONFIG`` dict. It builds
    the dataloaders, the model, AdamW and a linear-warmup + cosine-decay
    LambdaLR schedule (stepped per batch by ``train_epoch``). It then trains
    for ``config['epochs']`` epochs and checkpoints the best test accuracy
    to ``geo_cifar100/best.pt``. Scalars go to TensorBoard under
    ``config['log_dir']``, and a full geometric analysis runs every
    ``config['log_geo_every']`` epochs.

    The SummaryWriter is closed in a ``finally`` block, so events are
    flushed even if training raises or is interrupted.
    """
    config = CONFIG.copy()

    print("=" * 60)
    print(" Geometric Transformer β€” CIFAR-100 (CM-Validated)")
    print(f" Input: conv({config['in_channels']}β†’{config['conv_channels']}) + "
          f"SVD(rank={config['svd_rank']}) + "
          f"{config['patch_size']}Γ—{config['patch_size']} patches = "
          f"{(config['img_size']//config['patch_size'])**2} tokens + CLS")
    print(f" Model: d={config['d_model']}, heads={config['n_heads']}, "
          f"layers={config['n_layers']}, anchors={config['n_anchors']}")
    print(f" CM: neighbors={config['cm_neighbors']}, "
          f"cv_target={config['cv_target']}, "
          f"cv_weight={config['cv_weight']}, "
          f"spread_weight={config['spread_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(Ξ±={config['cutmix_alpha']}, "
          f"p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" NCE: bank={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, "
          f"weight={config['nce_weight']}")
    print("=" * 60)

    writer = SummaryWriter(config['log_dir'])
    try:
        writer.add_text('config', json.dumps(config, indent=2))

        print("\nLoading CIFAR-100...")
        train_loader, test_loader = get_dataloaders(config)
        print(f" Train: {len(train_loader.dataset):,} | Test: {len(test_loader.dataset):,}")

        model = GeoViTClassifier('geo_vit_cifar100', config)
        if hasattr(model, 'network_to'):
            model.network_to(device=device, strict=False)
        else:
            model = model.to(device)

        n_params = sum(p.numel() for p in model.parameters())
        n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n Total params: {n_params:,}")
        print(f" Trainable params: {n_trainable:,}")
        for name, module in model.named_children():
            n = sum(p.numel() for p in module.parameters())
            if n > 0:
                print(f" {name:<20s}: {n:,}")
        writer.add_scalar('model/total_params', n_params, 0)

        # Initial anchor diagnostics
        print(f"\n Initial anchor diagnostics:")
        diag = model.anchor_diagnostics()
        for layer_name, d in diag.items():
            print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
                  f"cm_pos={d['cm_positive_frac']:.3f}, "
                  f"min_dist={d['min_pairwise_dist']:.4f}")

        # Optimizer + per-batch scheduler (linear warmup, then cosine decay)
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
        total_steps = config['epochs'] * len(train_loader)
        warmup_steps = config['warmup_epochs'] * len(train_loader)
        # Guard the cosine denominator: it is 0 when warmup_epochs == epochs.
        decay_steps = max(1, total_steps - warmup_steps)

        def lr_lambda(step):
            # The warmup branch is never taken when warmup_steps == 0,
            # so this division cannot hit zero.
            if step < warmup_steps:
                return step / warmup_steps
            # Clamp so an extra scheduler step cannot swing the cosine back up.
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            return 0.5 * (1 + np.cos(np.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        print(f"\n{'━'*60}")
        print(f" Training for {config['epochs']} epochs")
        print(f" Warmup: {config['warmup_epochs']} epochs, "
              f"LR: {config['lr']}, WD: {config['weight_decay']}")
        print(f" Geo reg: cv_w={config['cv_weight']}, spread_w={config['spread_weight']}")
        print(f" NCE bank: size={config['nce_bank_size']}, "
              f"temp={config['nce_temperature']}, weight={config['nce_weight']}")
        print(f" Aug: TrivialAugmentWide + CutMix(p={config['cutmix_prob']}) + "
              f"RandomErasing(p={config['random_erasing_p']})")
        print(f" TensorBoard: {config['log_dir']}")
        print(f" Geo analysis every {config['log_geo_every']} epochs")
        print(f"{'━'*60}\n")

        best_acc = 0
        save_dir = Path('geo_cifar100')
        save_dir.mkdir(exist_ok=True)

        for epoch in tqdm(range(config['epochs']), desc="Epochs"):
            t0 = time.time()
            ce_loss, geo_loss, nce_loss, train_acc = train_epoch(
                model, train_loader, optimizer, scheduler, epoch, config, writer)
            test_acc = evaluate(model, test_loader)
            elapsed = time.time() - t0
            lr = optimizer.param_groups[0]['lr']

            writer.add_scalar('train/ce_loss', ce_loss, epoch)
            writer.add_scalar('train/geo_loss', geo_loss, epoch)
            writer.add_scalar('train/nce_loss', nce_loss, epoch)
            # train_epoch optimises ce + geo + nce_weight * nce; log that same
            # weighted sum (assumes its nce_weight is config['nce_weight'] --
            # confirm against train_epoch).
            writer.add_scalar(
                'train/total_loss',
                ce_loss + geo_loss + config['nce_weight'] * nce_loss, epoch)
            writer.add_scalar('train/accuracy', train_acc, epoch)
            writer.add_scalar('test/accuracy', test_acc, epoch)
            writer.add_scalar('train/lr', lr, epoch)
            writer.add_scalar('train/epoch_time', elapsed, epoch)
            writer.add_scalar('gap/train_test', train_acc - test_acc, epoch)
            log_weight_norms(model, writer, epoch)

            # NOTE(review): best checkpoint is selected on the test split.
            if test_acc > best_acc:
                best_acc = test_acc
                torch.save({
                    'state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                    'epoch': epoch,
                    'test_acc': test_acc,
                    'config': config,
                }, save_dir / 'best.pt')

            # Full geometric analysis periodically (and on the last epoch)
            if epoch % config['log_geo_every'] == 0 or epoch == config['epochs'] - 1:
                geo_info = log_geometric_analysis(
                    model, writer, epoch, test_loader, device, config)
                cv_str = ', '.join(f'{cv:.3f}' for cv in geo_info['cv_trajectory'])
                cm_str = ', '.join(f'{q:.3f}' for q in geo_info.get('cm_quality_trajectory', []))
                res_str = ', '.join(f'{r:.3f}' for r in geo_info.get('res_norms', []))
                tqdm.write(
                    f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                    f"nce={nce_loss:.4f} "
                    f"train={train_acc:.4f} test={test_acc:.4f} "
                    f"best={best_acc:.4f} {elapsed:.1f}s"
                    f"\n CV=[{cv_str}]"
                    f"\n CM=[{cm_str}]"
                    f"\n GR=[{res_str}]")
            elif epoch % 5 == 0:
                tqdm.write(
                    f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                    f"nce={nce_loss:.4f} "
                    f"train={train_acc:.4f} test={test_acc:.4f} "
                    f"best={best_acc:.4f} {elapsed:.1f}s")

        # Final summary
        print(f"\n{'═'*60}")
        print(f" CIFAR-100 RESULTS (CM-Validated)")
        print(f"{'═'*60}")
        print(f" Best test accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
        print(f" Parameters: {n_params:,}")
        print(f" Checkpoint: {save_dir}/best.pt")
        print(f" TensorBoard: {config['log_dir']}")

        # Final geometric state + anchor diagnostics (of the last-epoch model)
        print(f"\n Final geometric state:")
        log_geometric_analysis(
            model, writer, config['epochs'], test_loader, device, config)
        print(f"\n Final anchor diagnostics:")
        diag = model.anchor_diagnostics()
        for layer_name, d in diag.items():
            print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
                  f"cm_pos={d['cm_positive_frac']:.3f}, "
                  f"cm_mean={d['cm_mean']:.4f}")
    finally:
        # Always flush/close the TensorBoard event file, even on error/Ctrl-C.
        writer.close()
    print(f"\nDone.")
if __name__ == '__main__':
    # Kick off training only when this file is executed directly,
    # not when it is imported as a module.
    main()