Create stage1_analysis_trainer.py
Browse files- stage1_analysis_trainer.py +1049 -0
stage1_analysis_trainer.py
ADDED
|
@@ -0,0 +1,1049 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Geometric Transformer — CIFAR-100 Training with CM-Validated Analysis
|
| 3 |
+
|
| 4 |
+
Changes from previous version:
|
| 5 |
+
- CM gate diagnostics per layer: active anchors, gate_mean, cm_positive_frac
|
| 6 |
+
- CM quality in geometric residual analysis (replaces blind gate)
|
| 7 |
+
- Geometric regularization losses (CV target + anchor spread) in training loop
|
| 8 |
+
- Anchor diagnostics via model.anchor_diagnostics()
|
| 9 |
+
- CM quality trajectory alongside CV and bridge KL for cooperation analysis
|
| 10 |
+
|
| 11 |
+
TensorBoard logging of every geometric feature element:
|
| 12 |
+
- CV (coefficient of variation) per layer — the pentachoron band metric
|
| 13 |
+
- CM gate: active anchors, gate mean, cm_positive_frac, quality per position
|
| 14 |
+
- Stream agreement/divergence per layer
|
| 15 |
+
- Anchor utilization, entropy, spread
|
| 16 |
+
- Patchwork activation statistics (from CM-validated triangulation)
|
| 17 |
+
- Bridge vs assignment consistency
|
| 18 |
+
- Triangulation distance distributions
|
| 19 |
+
- SVD spectrum, entropy, novelty
|
| 20 |
+
- Quaternion arm norms and composition statistics
|
| 21 |
+
- Cayley rotation ‖R-I‖ per layer
|
| 22 |
+
- FiLM gamma/beta deviation from identity
|
| 23 |
+
- Gate activation statistics
|
| 24 |
+
- Gradient norms per component type (including cm_gate)
|
| 25 |
+
- Weight norms per component type
|
| 26 |
+
- Geometric regularization: CV loss, spread loss per epoch
|
| 27 |
+
|
| 28 |
+
!pip install geolip-core torchvision tqdm tensorboard
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
import numpy as np
|
| 35 |
+
import time, json, math
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from tqdm.auto import tqdm
|
| 38 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 39 |
+
|
| 40 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 41 |
+
print(f"Device: {device}")
|
| 42 |
+
if device.type == 'cuda':
|
| 43 |
+
print(f" GPU: {torch.cuda.get_device_name()}")
|
| 44 |
+
print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
|
| 45 |
+
|
| 46 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 47 |
+
# IMPORT TRANSFORMER
|
| 48 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 49 |
+
|
| 50 |
+
# Try geolip_core installed package first, fall back to local file
|
| 51 |
+
try:
|
| 52 |
+
from geolip_core.pipeline.components.geometric_transformer import (
|
| 53 |
+
GeometricTransformer, GeometricTransformerLayer,
|
| 54 |
+
CayleyOrthogonal, QuaternionCompose, FiLMLayer,
|
| 55 |
+
ContentAttention, GeometricAttention, CMValidatedGate,
|
| 56 |
+
TorchComponent, BaseTower,
|
| 57 |
+
anchor_neighborhood_cm,
|
| 58 |
+
)
|
| 59 |
+
print(" Imported from geolip_core (installed)")
|
| 60 |
+
except ImportError:
|
| 61 |
+
try:
|
| 62 |
+
from geometric_transformer import (
|
| 63 |
+
GeometricTransformer, GeometricTransformerLayer,
|
| 64 |
+
CayleyOrthogonal, QuaternionCompose, FiLMLayer,
|
| 65 |
+
ContentAttention, GeometricAttention, CMValidatedGate,
|
| 66 |
+
TorchComponent, BaseTower,
|
| 67 |
+
anchor_neighborhood_cm,
|
| 68 |
+
)
|
| 69 |
+
print(" Imported from local geometric_transformer.py")
|
| 70 |
+
except ImportError:
|
| 71 |
+
raise ImportError(
|
| 72 |
+
"Cannot find geometric_transformer. Place geometric_transformer.py "
|
| 73 |
+
"in the working directory or install geolip-core.")
|
| 74 |
+
|
| 75 |
+
torch.set_float32_matmul_precision('high')
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 79 |
+
# CONFIG
|
| 80 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 81 |
+
|
| 82 |
+
CONFIG = {
|
| 83 |
+
# Model
|
| 84 |
+
'd_model': 256,
|
| 85 |
+
'n_heads': 8,
|
| 86 |
+
'n_layers': 8,
|
| 87 |
+
'n_anchors': 128,
|
| 88 |
+
'manifold_dim': 128,
|
| 89 |
+
'n_comp': 4,
|
| 90 |
+
'd_comp': 16,
|
| 91 |
+
'context_dim': 64,
|
| 92 |
+
'quat_dim': 32,
|
| 93 |
+
'dropout': 0.1,
|
| 94 |
+
'cm_neighbors': 3, # CM simplex neighbors
|
| 95 |
+
|
| 96 |
+
# Input stage
|
| 97 |
+
'patch_size': 4,
|
| 98 |
+
'img_size': 32,
|
| 99 |
+
'in_channels': 3,
|
| 100 |
+
'conv_channels': 64,
|
| 101 |
+
'svd_rank': 16,
|
| 102 |
+
|
| 103 |
+
# Training
|
| 104 |
+
'epochs': 100,
|
| 105 |
+
'batch_size': 1024,
|
| 106 |
+
'lr': 1e-3,
|
| 107 |
+
'weight_decay': 0.05,
|
| 108 |
+
'warmup_epochs': 5,
|
| 109 |
+
'label_smoothing': 0.1,
|
| 110 |
+
'num_workers': 8,
|
| 111 |
+
|
| 112 |
+
# Geometric regularization
|
| 113 |
+
'cv_target': 0.215, # pentachoron band center
|
| 114 |
+
'cv_weight': 0.1, # CV loss weight
|
| 115 |
+
'spread_weight': 0.01, # anchor spread loss weight
|
| 116 |
+
|
| 117 |
+
# Augmentation — tuned for CM gate training
|
| 118 |
+
'cutmix_alpha': 1.0, # CutMix beta distribution α (1.0 = uniform box sizes)
|
| 119 |
+
'cutmix_prob': 0.5, # probability of applying CutMix per batch
|
| 120 |
+
'random_erasing_p': 0.25, # probability of erasing per image
|
| 121 |
+
|
| 122 |
+
# InfoNCE memory bank on geometric residual
|
| 123 |
+
'nce_bank_size': 4096, # queue size (0 to disable)
|
| 124 |
+
'nce_temperature': 0.1, # InfoNCE temperature
|
| 125 |
+
'nce_weight': 0.1, # loss weight
|
| 126 |
+
|
| 127 |
+
# Data
|
| 128 |
+
'num_classes': 100,
|
| 129 |
+
|
| 130 |
+
# Logging
|
| 131 |
+
'log_geo_every': 5, # full geometric analysis every N epochs
|
| 132 |
+
'log_grads_every': 10, # gradient norms every N epochs
|
| 133 |
+
'log_dir': 'runs/geo_cifar100',
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 138 |
+
# INPUT STAGE
|
| 139 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 140 |
+
|
| 141 |
+
try:
    # Prefer the installed geolip_core implementation when available.
    from geolip_core.core.input.svd import SVDObserver
    _HAS_SVD = True
except ImportError:
    _HAS_SVD = False

    # NOTE(review): source paste lost indentation — the fallback class is
    # placed inside the except branch so it cannot shadow the installed
    # SVDObserver when the import above succeeds; confirm against original.
    class SVDObserver(nn.Module):
        """Fallback SVDObserver.

        Approximates a per-image SVD of conv features via an eigendecomposition
        of the Gram matrix, and tracks an EMA of the singular-value spectrum
        so `compute_novelty` can report deviation from the running average.
        """

        def __init__(self, in_channels, svd_rank=24):
            super().__init__()
            self.svd_rank = svd_rank
            # 1x1 conv projects channels down to the rank used for the SVD.
            self.to_svd = nn.Conv2d(in_channels, svd_rank, 1, bias=False)
            # EMA buffers: running singular values and flattened Vh basis.
            self.register_buffer('ema_s', torch.ones(svd_rank))
            self.register_buffer('ema_vh_flat', torch.eye(svd_rank).reshape(-1))
            self.ema_momentum = 0.99

        def extract_features(self, S, Vh):
            """Build a per-sample feature vector from (S, Vh).

            Layout: normalized spectrum (k) + Vh diagonal (k) +
            off-diagonal energy (1) + spectral entropy (1) = 2k + 2,
            matching the `feature_dim` property. Non-finite entries are
            zeroed so downstream FiLM projections stay stable.
            """
            B, k = S.shape
            S_safe = S.clamp(min=1e-6)
            # Spectrum normalized to a probability distribution.
            s_norm = S_safe / (S_safe.sum(dim=-1, keepdim=True) + 1e-8)
            vh_diag = Vh.diagonal(dim1=-2, dim2=-1)
            # Total energy minus diagonal energy = off-diagonal mass (scalar per sample).
            vh_offdiag = (Vh.pow(2).sum((-2, -1)) - vh_diag.pow(2).sum(-1)).unsqueeze(-1).clamp(min=0)
            # Shannon entropy of the normalized spectrum.
            s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1, keepdim=True)
            out = torch.cat([s_norm, vh_diag, vh_offdiag, s_ent], dim=-1)
            return torch.where(torch.isfinite(out), out, torch.zeros_like(out))

        def compute_novelty(self, S):
            # Deviation of current spectrum from the EMA spectrum.
            return S - self.ema_s.clone().unsqueeze(0)

        def forward(self, x):
            """Return (S, Vh, features, novelty) for a (B, C, H, W) batch.

            Singular values are obtained from eigh of the (rank x rank) Gram
            matrix in fp32 under no_grad — the SVD path is observational and
            does not backpropagate.
            """
            B, C, H, W = x.shape
            h = self.to_svd(x)
            # (B, H*W, rank): spatial positions become the "rows" of the matrix.
            h_flat = h.permute(0, 2, 3, 1).reshape(B, H * W, self.svd_rank)
            # Force fp32 — eigh is numerically fragile under autocast/fp16.
            with torch.amp.autocast('cuda', enabled=False):
                with torch.no_grad():
                    gram = torch.bmm(h_flat.float().transpose(1, 2), h_flat.float())
                    evals, evecs = torch.linalg.eigh(gram)
                    # eigh returns ascending order; flip to descending.
                    evals = evals.flip(-1).clamp(min=1e-12)
                    # Eigenvalues of Gram = squared singular values.
                    S = evals.sqrt()
                    Vh = evecs.flip(-1).transpose(-2, -1)
            # Guard against NaN/Inf from degenerate Gram matrices.
            S = torch.where(torch.isfinite(S), S, torch.ones_like(S))
            Vh = torch.where(torch.isfinite(Vh), Vh, torch.zeros_like(Vh))
            features = self.extract_features(S, Vh)
            novelty = self.compute_novelty(S)
            return S, Vh, features, novelty

        @torch.no_grad()
        def update_ema(self, S, Vh):
            """EMA-update the running spectrum/basis from a batch (call in training)."""
            m = self.ema_momentum
            self.ema_s.mul_(m).add_(S.detach().mean(0), alpha=1-m)
            self.ema_vh_flat.mul_(m).add_(Vh.detach().mean(0).reshape(-1), alpha=1-m)

        @property
        def feature_dim(self):
            # s_norm (k) + vh_diag (k) + off-diag energy (1) + entropy (1).
            return 2 * self.svd_rank + 2
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class ConvSVDPatchEmbedding(TorchComponent):
    """Input stage: conv frontend → SVDObserver → patch tokens.

    Tokens are modulated by a FiLM-style (gamma, beta) pair predicted from
    the SVD feature vector, then a CLS token and learned positional
    embeddings are added. Returns (tokens, svd_state).
    """

    def __init__(self, name, img_size=32, patch_size=4, in_channels=3,
                 conv_channels=64, d_model=256, svd_rank=16):
        super().__init__(name)
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.d_model = d_model
        self.svd_rank = svd_rank

        # Two 3x3 conv + BN + GELU stages at full resolution (stride 1).
        self.conv_frontend = nn.Sequential(
            nn.Conv2d(in_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
            nn.Conv2d(conv_channels, conv_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.GELU(),
        )
        self.svd_observer = SVDObserver(conv_channels, svd_rank)
        # Non-overlapping patchify: kernel == stride == patch_size.
        self.patch_proj = nn.Conv2d(
            conv_channels, d_model, kernel_size=patch_size,
            stride=patch_size, bias=False)
        self.patch_norm = nn.LayerNorm(d_model)

        svd_feat_dim = self.svd_observer.feature_dim
        self.svd_to_gamma = nn.Linear(svd_feat_dim, d_model)
        self.svd_to_beta = nn.Linear(svd_feat_dim, d_model)
        # Initialize FiLM at identity: gamma ≡ 1 (zero weight, ones bias),
        # beta ≡ 0, so SVD modulation starts as a no-op and is learned.
        nn.init.zeros_(self.svd_to_gamma.weight); nn.init.ones_(self.svd_to_gamma.bias)
        nn.init.zeros_(self.svd_to_beta.weight); nn.init.zeros_(self.svd_to_beta.bias)

        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)
        # +1 position for the CLS token.
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.n_patches + 1, d_model) * 0.02)

    def forward(self, x):
        """Embed images into (B, n_patches+1, d_model) tokens + SVD state."""
        B = x.shape[0]
        feat = self.conv_frontend(x)
        S, Vh, svd_features, novelty = self.svd_observer(feat)
        tokens = self.patch_proj(feat)
        # (B, d_model, H', W') -> (B, n_patches, d_model)
        tokens = tokens.flatten(2).transpose(1, 2)
        tokens = self.patch_norm(tokens)
        # FiLM modulation shared across all patch positions (unsqueeze over seq dim).
        gamma = self.svd_to_gamma(svd_features).unsqueeze(1)
        beta = self.svd_to_beta(svd_features).unsqueeze(1)
        tokens = gamma * tokens + beta
        cls = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)
        tokens = tokens + self.pos_embed
        svd_state = {
            'singular_values': S, 'Vh': Vh,
            'svd_features': svd_features, 'novelty': novelty,
        }
        # EMA bookkeeping only during training (mirrors BatchNorm convention).
        if self.training:
            self.svd_observer.update_ema(S, Vh)
        return tokens, svd_state
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 253 |
+
# CLASSIFIER ( uses GeometricTransformer with CM gates)
|
| 254 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 255 |
+
|
| 256 |
+
class GeoViTClassifier(BaseTower):
    """Geometric Vision Transformer for classification.

    Wraps ConvSVDPatchEmbedding + GeometricTransformer + task head.
    Exposes geometric_losses() for regularization during training.

    NOTE(review): components are registered via BaseTower.attach and looked
    up with self['name'] — assumes BaseTower implements __getitem__; confirm
    against the geolip_core BaseTower API.
    """

    def __init__(self, name, config):
        super().__init__(name)
        self.config = config

        self.attach('patch_embed', ConvSVDPatchEmbedding(
            'patch_embed', img_size=config['img_size'],
            patch_size=config['patch_size'], in_channels=config['in_channels'],
            conv_channels=config['conv_channels'], d_model=config['d_model'],
            svd_rank=config['svd_rank'],
        ))
        self.attach('transformer', GeometricTransformer(
            'geo_cifar', d_model=config['d_model'], n_heads=config['n_heads'],
            n_layers=config['n_layers'], n_anchors=config['n_anchors'],
            manifold_dim=config['manifold_dim'], n_comp=config['n_comp'],
            d_comp=config['d_comp'], context_dim=config['context_dim'],
            quat_dim=config['quat_dim'], dropout=config['dropout'],
            cm_neighbors=config.get('cm_neighbors', 3),
            nce_bank_size=config.get('nce_bank_size', 4096),
            nce_temperature=config.get('nce_temperature', 0.1),
        ))
        # Classification head: LN -> MLP -> logits over num_classes.
        self.attach('head', nn.Sequential(
            nn.LayerNorm(config['d_model']),
            nn.Linear(config['d_model'], config['d_model']),
            nn.GELU(), nn.Dropout(config['dropout']),
            nn.Linear(config['d_model'], config['num_classes']),
        ))

    def forward(self, x, return_geo_state=False):
        """Classify images.

        Returns logits, or (logits, geo_states, svd_state) when
        return_geo_state=True — the extra states feed the analysis battery.
        """
        tokens, svd_state = self['patch_embed'](x)
        if return_geo_state:
            features, geo_states = self['transformer'](tokens, return_geo_state=True)
        else:
            features = self['transformer'](tokens)
        # CLS token (position 0) carries the pooled representation.
        cls_out = features[:, 0]
        logits = self['head'](cls_out)
        if return_geo_state:
            return logits, geo_states, svd_state
        return logits

    def geometric_losses(self):
        """Delegate to transformer's built-in geometric regularization.

        Defaults mirror CONFIG: cv_target 0.215 (pentachoron band center),
        cv_weight 0.1, spread_weight 0.01.
        """
        return self['transformer'].geometric_losses(
            cv_target=self.config.get('cv_target', 0.215),
            cv_weight=self.config.get('cv_weight', 0.1),
            spread_weight=self.config.get('spread_weight', 0.01),
        )

    def infonce_loss(self):
        """InfoNCE contrastive loss on CLS token's geometric residual.
        Uses cached residual from last forward pass."""
        return self['transformer'].infonce_loss()

    def update_nce_bank(self):
        """Enqueue current batch's residuals. Call AFTER backward."""
        self['transformer'].update_nce_bank()

    def anchor_diagnostics(self):
        """Delegate to transformer's anchor diagnostics."""
        return self['transformer'].anchor_diagnostics()
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 324 |
+
# GEOMETRIC ANALYSIS BATTERY ( includes CM diagnostics)
|
| 325 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 326 |
+
|
| 327 |
+
@torch.no_grad()
def compute_cv(points):
    """Coefficient of variation on S^(d-1).

    CV = std(pairwise_cosine_distances) / mean(pairwise_cosine_distances)
    Pentachoron band: CV ∈ [0.20, 0.23].

    Returns (cv, mean_distance, std_distance) as Python floats.
    """
    unit = F.normalize(points.float(), dim=-1)
    sims = torch.mm(unit, unit.t())
    count = unit.size(0)
    # Strict upper triangle selects each unordered pair exactly once.
    upper = torch.triu(
        torch.ones(count, count, dtype=torch.bool, device=unit.device),
        diagonal=1)
    dists = 1.0 - sims[upper]
    d_mean = dists.mean()
    d_std = dists.std()
    return (d_std / (d_mean + 1e-8)).item(), d_mean.item(), d_std.item()
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def log_geometric_analysis(model, writer, epoch, test_loader, device, config):
|
| 346 |
+
"""Full geometric analysis battery with CM diagnostics."""
|
| 347 |
+
model.eval()
|
| 348 |
+
|
| 349 |
+
images, labels = next(iter(test_loader))
|
| 350 |
+
images = images[:min(64, images.shape[0])].to(device)
|
| 351 |
+
labels = labels[:min(64, labels.shape[0])].to(device)
|
| 352 |
+
|
| 353 |
+
logits, geo_states, svd_state = model(images, return_geo_state=True)
|
| 354 |
+
|
| 355 |
+
n_layers = len(geo_states)
|
| 356 |
+
pred = logits.argmax(1)
|
| 357 |
+
batch_acc = (pred == labels).float().mean().item()
|
| 358 |
+
writer.add_scalar('analysis/batch_accuracy', batch_acc, epoch)
|
| 359 |
+
|
| 360 |
+
# ─── SVD Input Stage ───
|
| 361 |
+
S = svd_state['singular_values']
|
| 362 |
+
s_norm = S / (S.sum(dim=-1, keepdim=True) + 1e-8)
|
| 363 |
+
s_ent = -(s_norm * torch.log(s_norm.clamp(min=1e-8))).sum(-1)
|
| 364 |
+
novelty = svd_state['novelty']
|
| 365 |
+
|
| 366 |
+
writer.add_scalar('svd/entropy_mean', s_ent.mean().item(), epoch)
|
| 367 |
+
writer.add_scalar('svd/entropy_std', s_ent.std().item(), epoch)
|
| 368 |
+
writer.add_scalar('svd/novelty_norm', novelty.norm(dim=-1).mean().item(), epoch)
|
| 369 |
+
writer.add_scalar('svd/top1_ratio', (S[:, 0] / (S.sum(-1) + 1e-8)).mean().item(), epoch)
|
| 370 |
+
writer.add_scalar('svd/condition_number',
|
| 371 |
+
(S[:, 0] / (S[:, -1].clamp(min=1e-8))).mean().item(), epoch)
|
| 372 |
+
for k in range(min(S.shape[1], 5)):
|
| 373 |
+
writer.add_scalar(f'svd/S_{k}', S[:, k].mean().item(), epoch)
|
| 374 |
+
|
| 375 |
+
# SVD FiLM deviation
|
| 376 |
+
pe = model['patch_embed']
|
| 377 |
+
writer.add_scalar('svd_film/gamma_weight_norm', pe.svd_to_gamma.weight.data.norm().item(), epoch)
|
| 378 |
+
writer.add_scalar('svd_film/gamma_bias_dev_from_1',
|
| 379 |
+
(pe.svd_to_gamma.bias.data - 1.0).abs().mean().item(), epoch)
|
| 380 |
+
writer.add_scalar('svd_film/beta_weight_norm', pe.svd_to_beta.weight.data.norm().item(), epoch)
|
| 381 |
+
writer.add_scalar('svd_film/beta_bias_norm', pe.svd_to_beta.bias.data.abs().mean().item(), epoch)
|
| 382 |
+
|
| 383 |
+
# ─── Anchor Diagnostics (built-in) ───
|
| 384 |
+
anchor_diag = model.anchor_diagnostics()
|
| 385 |
+
for layer_name, d in anchor_diag.items():
|
| 386 |
+
for k, v in d.items():
|
| 387 |
+
writer.add_scalar(f'anchor_diag/{layer_name}_{k}', v, epoch)
|
| 388 |
+
|
| 389 |
+
# ─── Per-Layer Geometric Analysis ───
|
| 390 |
+
for i, gs in enumerate(geo_states):
|
| 391 |
+
prefix = f'layer_{i}'
|
| 392 |
+
|
| 393 |
+
# === CV — pentachoron band metric ===
|
| 394 |
+
emb = gs['embedding']
|
| 395 |
+
# Anchor CV
|
| 396 |
+
transformer = model['transformer']
|
| 397 |
+
layer = transformer[f'layer_{i}']
|
| 398 |
+
anchors = F.normalize(
|
| 399 |
+
layer['observer'].association.constellation.anchors, dim=-1)
|
| 400 |
+
cv_anchors, mean_d_anchors, std_d_anchors = compute_cv(anchors)
|
| 401 |
+
writer.add_scalar(f'{prefix}/cv_anchors', cv_anchors, epoch)
|
| 402 |
+
writer.add_scalar(f'{prefix}/anchor_mean_dist', mean_d_anchors, epoch)
|
| 403 |
+
writer.add_scalar(f'{prefix}/anchor_std_dist', std_d_anchors, epoch)
|
| 404 |
+
|
| 405 |
+
# Embedding CV
|
| 406 |
+
emb_flat = emb.reshape(-1, emb.shape[-1])
|
| 407 |
+
n_sample = min(512, emb_flat.shape[0])
|
| 408 |
+
idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
|
| 409 |
+
cv_emb, mean_d_emb, std_d_emb = compute_cv(emb_flat[idx])
|
| 410 |
+
writer.add_scalar(f'{prefix}/cv_embeddings', cv_emb, epoch)
|
| 411 |
+
writer.add_scalar(f'{prefix}/embedding_mean_dist', mean_d_emb, epoch)
|
| 412 |
+
|
| 413 |
+
# === CM Gate Diagnostics ===
|
| 414 |
+
gate_info = gs.get('gate_info', {})
|
| 415 |
+
gate_values = gs.get('gate_values')
|
| 416 |
+
cm_quality = gs.get('cm_quality')
|
| 417 |
+
|
| 418 |
+
if gate_info:
|
| 419 |
+
writer.add_scalar(f'{prefix}/cm_active_anchors',
|
| 420 |
+
gate_info.get('active', 0), epoch)
|
| 421 |
+
writer.add_scalar(f'{prefix}/cm_gate_mean',
|
| 422 |
+
gate_info.get('gate_mean', 0), epoch)
|
| 423 |
+
writer.add_scalar(f'{prefix}/cm_positive_frac',
|
| 424 |
+
gate_info.get('cm_positive_frac', 0), epoch)
|
| 425 |
+
|
| 426 |
+
if gate_values is not None:
|
| 427 |
+
gv = gate_values
|
| 428 |
+
writer.add_scalar(f'{prefix}/gate_values_min', gv.min().item(), epoch)
|
| 429 |
+
writer.add_scalar(f'{prefix}/gate_values_max', gv.max().item(), epoch)
|
| 430 |
+
writer.add_scalar(f'{prefix}/gate_values_std', gv.std().item(), epoch)
|
| 431 |
+
# Per-anchor gate mean (which anchors are consistently open/closed)
|
| 432 |
+
gv_per_anchor = gv.mean(dim=0).mean(dim=0) # average over B and L
|
| 433 |
+
writer.add_scalar(f'{prefix}/gate_anchor_spread',
|
| 434 |
+
gv_per_anchor.std().item(), epoch)
|
| 435 |
+
# Fraction of positions with >50% anchors open
|
| 436 |
+
if gv.dim() == 3:
|
| 437 |
+
pos_open_frac = (gv.mean(dim=-1) > 0.5).float().mean().item()
|
| 438 |
+
else:
|
| 439 |
+
pos_open_frac = (gv > 0.5).float().mean().item()
|
| 440 |
+
writer.add_scalar(f'{prefix}/gate_positions_open_frac', pos_open_frac, epoch)
|
| 441 |
+
|
| 442 |
+
if cm_quality is not None:
|
| 443 |
+
writer.add_scalar(f'{prefix}/cm_quality_mean', cm_quality.mean().item(), epoch)
|
| 444 |
+
writer.add_scalar(f'{prefix}/cm_quality_std', cm_quality.std().item(), epoch)
|
| 445 |
+
writer.add_scalar(f'{prefix}/cm_quality_min', cm_quality.min().item(), epoch)
|
| 446 |
+
|
| 447 |
+
# === Stream Agreement ===
|
| 448 |
+
content = gs['content']
|
| 449 |
+
geometric = gs['geometric']
|
| 450 |
+
agreement = F.cosine_similarity(
|
| 451 |
+
content.reshape(-1, content.shape[-1]),
|
| 452 |
+
geometric.reshape(-1, geometric.shape[-1]), dim=-1)
|
| 453 |
+
writer.add_scalar(f'{prefix}/stream_agreement_mean', agreement.mean().item(), epoch)
|
| 454 |
+
writer.add_scalar(f'{prefix}/stream_agreement_std', agreement.std().item(), epoch)
|
| 455 |
+
|
| 456 |
+
writer.add_scalar(f'{prefix}/content_norm', content.norm(dim=-1).mean().item(), epoch)
|
| 457 |
+
writer.add_scalar(f'{prefix}/geometric_norm', geometric.norm(dim=-1).mean().item(), epoch)
|
| 458 |
+
|
| 459 |
+
# === Disagreement arm analysis ===
|
| 460 |
+
disagree = content - geometric
|
| 461 |
+
agree = content * geometric
|
| 462 |
+
writer.add_scalar(f'{prefix}/disagree_norm', disagree.norm(dim=-1).mean().item(), epoch)
|
| 463 |
+
writer.add_scalar(f'{prefix}/agree_norm', agree.norm(dim=-1).mean().item(), epoch)
|
| 464 |
+
|
| 465 |
+
# === Anchor Utilization ===
|
| 466 |
+
tri = gs['triangulation']
|
| 467 |
+
assignment = gs['assignment']
|
| 468 |
+
nearest = gs['nearest']
|
| 469 |
+
n_anchors = tri.shape[-1]
|
| 470 |
+
|
| 471 |
+
nearest_flat = nearest.reshape(-1)
|
| 472 |
+
counts = torch.bincount(nearest_flat, minlength=n_anchors).float()
|
| 473 |
+
total_assignments = counts.sum()
|
| 474 |
+
|
| 475 |
+
probs = counts / (total_assignments + 1e-8)
|
| 476 |
+
anchor_entropy = -(probs * torch.log(probs.clamp(min=1e-8))).sum().item()
|
| 477 |
+
max_entropy = math.log(n_anchors)
|
| 478 |
+
writer.add_scalar(f'{prefix}/anchor_entropy_normalized',
|
| 479 |
+
anchor_entropy / (max_entropy + 1e-8), epoch)
|
| 480 |
+
active = (counts > 0).sum().item()
|
| 481 |
+
writer.add_scalar(f'{prefix}/anchors_active', active, epoch)
|
| 482 |
+
writer.add_scalar(f'{prefix}/anchors_active_frac', active / n_anchors, epoch)
|
| 483 |
+
dead = (counts == 0).sum().item()
|
| 484 |
+
writer.add_scalar(f'{prefix}/anchors_dead', dead, epoch)
|
| 485 |
+
|
| 486 |
+
# === Triangulation Statistics ===
|
| 487 |
+
writer.add_scalar(f'{prefix}/tri_mean', tri.mean().item(), epoch)
|
| 488 |
+
writer.add_scalar(f'{prefix}/tri_std', tri.std().item(), epoch)
|
| 489 |
+
|
| 490 |
+
# === Soft Assignment Statistics ===
|
| 491 |
+
assign_ent = -(assignment * torch.log(assignment.clamp(min=1e-8))).sum(-1)
|
| 492 |
+
writer.add_scalar(f'{prefix}/assignment_entropy_mean', assign_ent.mean().item(), epoch)
|
| 493 |
+
writer.add_scalar(f'{prefix}/assignment_max_prob',
|
| 494 |
+
assignment.max(dim=-1).values.mean().item(), epoch)
|
| 495 |
+
|
| 496 |
+
# === Patchwork Statistics (now from CM-validated triangulation) ===
|
| 497 |
+
pw = gs['patchwork']
|
| 498 |
+
writer.add_scalar(f'{prefix}/patchwork_norm', pw.norm(dim=-1).mean().item(), epoch)
|
| 499 |
+
writer.add_scalar(f'{prefix}/patchwork_std', pw.std().item(), epoch)
|
| 500 |
+
pw_sparsity = (pw.abs() < 0.01).float().mean().item()
|
| 501 |
+
writer.add_scalar(f'{prefix}/patchwork_sparsity', pw_sparsity, epoch)
|
| 502 |
+
|
| 503 |
+
# === Bridge Consistency ===
|
| 504 |
+
bridge = gs['bridge']
|
| 505 |
+
bridge_soft = F.softmax(bridge, dim=-1)
|
| 506 |
+
bridge_assign_kl = F.kl_div(
|
| 507 |
+
bridge_soft.log().reshape(-1, n_anchors),
|
| 508 |
+
assignment.reshape(-1, n_anchors),
|
| 509 |
+
reduction='batchmean', log_target=False)
|
| 510 |
+
writer.add_scalar(f'{prefix}/bridge_assignment_kl', bridge_assign_kl.item(), epoch)
|
| 511 |
+
|
| 512 |
+
# === Quaternion Composition ===
|
| 513 |
+
composed = gs['composed']
|
| 514 |
+
writer.add_scalar(f'{prefix}/composed_norm', composed.norm(dim=-1).mean().item(), epoch)
|
| 515 |
+
|
| 516 |
+
# === Geo Context ===
|
| 517 |
+
geo_ctx = gs['geo_ctx']
|
| 518 |
+
writer.add_scalar(f'{prefix}/geo_ctx_norm', geo_ctx.norm(dim=-1).mean().item(), epoch)
|
| 519 |
+
|
| 520 |
+
# === Geometric Residual Stream (CM-conditioned) ===
|
| 521 |
+
geo_res = gs.get('geo_residual')
|
| 522 |
+
if geo_res is not None:
|
| 523 |
+
res_norms = geo_res.norm(dim=-1)
|
| 524 |
+
writer.add_scalar(f'{prefix}/geo_res_norm', res_norms.mean().item(), epoch)
|
| 525 |
+
writer.add_scalar(f'{prefix}/geo_res_std', geo_res.std().item(), epoch)
|
| 526 |
+
writer.add_scalar(f'{prefix}/geo_res_sparsity',
|
| 527 |
+
(geo_res.abs() < 0.01).float().mean().item(), epoch)
|
| 528 |
+
# Cross-position consistency
|
| 529 |
+
geo_res_flat = geo_res.reshape(-1, geo_res.shape[-1])
|
| 530 |
+
n_s = min(256, geo_res_flat.shape[0])
|
| 531 |
+
idx_s = torch.randperm(geo_res_flat.shape[0], device=geo_res.device)[:n_s]
|
| 532 |
+
sampled = F.normalize(geo_res_flat[idx_s], dim=-1)
|
| 533 |
+
cos_mat = sampled @ sampled.T
|
| 534 |
+
triu = torch.triu_indices(n_s, n_s, offset=1, device=geo_res.device)
|
| 535 |
+
writer.add_scalar(f'{prefix}/geo_res_consistency',
|
| 536 |
+
cos_mat[triu[0], triu[1]].mean().item(), epoch)
|
| 537 |
+
|
| 538 |
+
# ─── Cayley Rotation Analysis ───
|
| 539 |
+
for name, mod in model.named_modules():
|
| 540 |
+
if isinstance(mod, CayleyOrthogonal):
|
| 541 |
+
R = mod.get_rotation()
|
| 542 |
+
I = torch.eye(R.shape[0], device=R.device)
|
| 543 |
+
r_dist = (R - I).norm().item()
|
| 544 |
+
clean_name = name.replace('.', '_')
|
| 545 |
+
writer.add_scalar(f'cayley/{clean_name}_R_minus_I', r_dist, epoch)
|
| 546 |
+
|
| 547 |
+
# ─── FiLM Layer Analysis ───
|
| 548 |
+
film_idx = 0
|
| 549 |
+
for name, mod in model.named_modules():
|
| 550 |
+
if isinstance(mod, FiLMLayer):
|
| 551 |
+
g_b = mod.to_gamma.bias.data
|
| 552 |
+
b_b = mod.to_beta.bias.data
|
| 553 |
+
writer.add_scalar(f'film/{film_idx}_gamma_dev',
|
| 554 |
+
(g_b - 1.0).abs().mean().item(), epoch)
|
| 555 |
+
writer.add_scalar(f'film/{film_idx}_beta_dev',
|
| 556 |
+
b_b.abs().mean().item(), epoch)
|
| 557 |
+
film_idx += 1
|
| 558 |
+
|
| 559 |
+
# ─── Cross-Layer Trajectories ───
|
| 560 |
+
cv_trajectory = []
|
| 561 |
+
cm_quality_trajectory = []
|
| 562 |
+
res_norms = []
|
| 563 |
+
bridge_kls = []
|
| 564 |
+
|
| 565 |
+
for i, gs in enumerate(geo_states):
|
| 566 |
+
# CV
|
| 567 |
+
emb = gs['embedding']
|
| 568 |
+
emb_flat = emb.reshape(-1, emb.shape[-1])
|
| 569 |
+
n_sample = min(512, emb_flat.shape[0])
|
| 570 |
+
idx = torch.randperm(emb_flat.shape[0], device=device)[:n_sample]
|
| 571 |
+
cv, _, _ = compute_cv(emb_flat[idx])
|
| 572 |
+
cv_trajectory.append(cv)
|
| 573 |
+
|
| 574 |
+
# CM quality
|
| 575 |
+
cm_q = gs.get('cm_quality')
|
| 576 |
+
if cm_q is not None:
|
| 577 |
+
cm_quality_trajectory.append(cm_q.mean().item())
|
| 578 |
+
|
| 579 |
+
# Geo residual norms
|
| 580 |
+
geo_res = gs.get('geo_residual')
|
| 581 |
+
if geo_res is not None:
|
| 582 |
+
res_norms.append(geo_res.norm(dim=-1).mean().item())
|
| 583 |
+
|
| 584 |
+
# Bridge KL
|
| 585 |
+
n_anchors = gs['assignment'].shape[-1]
|
| 586 |
+
bridge_soft = F.softmax(gs['bridge'], dim=-1)
|
| 587 |
+
bkl = F.kl_div(
|
| 588 |
+
bridge_soft.log().reshape(-1, n_anchors),
|
| 589 |
+
gs['assignment'].reshape(-1, n_anchors),
|
| 590 |
+
reduction='batchmean', log_target=False).item()
|
| 591 |
+
bridge_kls.append(bkl)
|
| 592 |
+
|
| 593 |
+
# CV trajectory
|
| 594 |
+
writer.add_scalar('cv/trajectory_mean', np.mean(cv_trajectory), epoch)
|
| 595 |
+
writer.add_scalar('cv/trajectory_std', np.std(cv_trajectory), epoch)
|
| 596 |
+
in_band = sum(1 for cv in cv_trajectory if 0.20 <= cv <= 0.23)
|
| 597 |
+
writer.add_scalar('cv/layers_in_pentachoron_band', in_band, epoch)
|
| 598 |
+
writer.add_scalar('cv/layers_in_band_frac', in_band / len(cv_trajectory), epoch)
|
| 599 |
+
|
| 600 |
+
# CM quality trajectory
|
| 601 |
+
if cm_quality_trajectory:
|
| 602 |
+
writer.add_scalar('cm/quality_trajectory_mean',
|
| 603 |
+
np.mean(cm_quality_trajectory), epoch)
|
| 604 |
+
writer.add_scalar('cm/quality_trajectory_std',
|
| 605 |
+
np.std(cm_quality_trajectory), epoch)
|
| 606 |
+
writer.add_scalar('cm/quality_min_layer',
|
| 607 |
+
np.min(cm_quality_trajectory), epoch)
|
| 608 |
+
writer.add_scalar('cm/quality_max_layer',
|
| 609 |
+
np.max(cm_quality_trajectory), epoch)
|
| 610 |
+
|
| 611 |
+
# Geometric residual trajectory
|
| 612 |
+
if res_norms:
|
| 613 |
+
writer.add_scalar('geo_res/trajectory_start', res_norms[0], epoch)
|
| 614 |
+
writer.add_scalar('geo_res/trajectory_end', res_norms[-1], epoch)
|
| 615 |
+
writer.add_scalar('geo_res/accumulation_ratio',
|
| 616 |
+
res_norms[-1] / (res_norms[0] + 1e-8), epoch)
|
| 617 |
+
growth = [res_norms[j+1] - res_norms[j] for j in range(len(res_norms)-1)]
|
| 618 |
+
writer.add_scalar('geo_res/growth_mean', np.mean(growth), epoch)
|
| 619 |
+
writer.add_scalar('geo_res/growth_std', np.std(growth), epoch)
|
| 620 |
+
|
| 621 |
+
# Cooperation analysis (includes CM quality)
|
| 622 |
+
if len(res_norms) >= 4:
|
| 623 |
+
cv_corr = float(np.corrcoef(res_norms, cv_trajectory)[0, 1])
|
| 624 |
+
bkl_corr = float(np.corrcoef(res_norms, bridge_kls)[0, 1])
|
| 625 |
+
writer.add_scalar('cooperation/geo_res_vs_cv', cv_corr, epoch)
|
| 626 |
+
writer.add_scalar('cooperation/geo_res_vs_bridge_kl', bkl_corr, epoch)
|
| 627 |
+
|
| 628 |
+
if len(cm_quality_trajectory) == len(res_norms):
|
| 629 |
+
cm_corr = float(np.corrcoef(
|
| 630 |
+
res_norms, cm_quality_trajectory)[0, 1])
|
| 631 |
+
writer.add_scalar('cooperation/geo_res_vs_cm_quality', cm_corr, epoch)
|
| 632 |
+
# CM vs CV: do layers with better CM quality also have better CV?
|
| 633 |
+
cm_cv_corr = float(np.corrcoef(
|
| 634 |
+
cm_quality_trajectory, cv_trajectory)[0, 1])
|
| 635 |
+
writer.add_scalar('cooperation/cm_quality_vs_cv', cm_cv_corr, epoch)
|
| 636 |
+
|
| 637 |
+
return {
|
| 638 |
+
'batch_acc': batch_acc,
|
| 639 |
+
'cv_trajectory': cv_trajectory,
|
| 640 |
+
'cm_quality_trajectory': cm_quality_trajectory,
|
| 641 |
+
'res_norms': res_norms,
|
| 642 |
+
'bridge_kls': bridge_kls,
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
|
| 646 |
+
@torch.no_grad()
def log_gradient_norms(model, writer, epoch):
    """Log per-component gradient-norm statistics to TensorBoard.

    Each parameter with a populated ``.grad`` is bucketed into a component
    type by substring matching on its name; per-bucket mean/max gradient
    norms plus the global gradient L2 norm are written under 'grad_norm/*'
    (includes the cm_gate bucket).

    Args:
        model: module whose parameters (after backward()) are inspected.
        writer: TensorBoard SummaryWriter (anything with .add_scalar).
        epoch: step index used as the x-axis for the logged scalars.
    """
    type_grads = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # frozen or unused parameters carry no gradient
        grad_norm = param.grad.norm().item()
        # First match wins, so order matters: 'cm_gate' must precede the
        # generic 'gate' test, 'geometric' excludes 'film', etc.
        # NOTE: the original condition was `'projection' in name and
        # 'proj' in name`; the second clause is always true whenever the
        # first is ('proj' is a substring of 'projection'), so it is dropped.
        if 'projection' in name:
            key = 'manifold_proj'
        elif 'cm_gate' in name:
            key = 'cm_gate'
        elif 'observer' in name or 'constellation' in name or 'anchor' in name:
            key = 'constellation'
        elif 'context' in name:
            key = 'geo_context'
        elif 'content' in name:
            key = 'content_attn'
        elif 'geometric' in name and 'film' not in name:
            key = 'geo_attn'
        elif 'film' in name:
            key = 'film'
        elif 'rotation' in name or 'cayley' in name or 'A_upper' in name:
            key = 'cayley'
        elif 'compose' in name or 'quat' in name or 'proj_w' in name:
            key = 'quaternion'
        elif 'decode' in name:
            key = 'decode'
        elif 'gate' in name:
            key = 'gate'
        elif 'conv' in name or 'patch' in name:
            key = 'input_stage'
        elif 'head' in name:
            key = 'head'
        elif 'svd' in name:
            key = 'svd'
        elif 'geo_proj' in name:
            key = 'geo_residual_proj'
        else:
            key = 'other'
        type_grads.setdefault(key, []).append(grad_norm)

    for key, norms in type_grads.items():
        writer.add_scalar(f'grad_norm/{key}_mean', np.mean(norms), epoch)
        writer.add_scalar(f'grad_norm/{key}_max', np.max(norms), epoch)

    # Global L2 norm over all parameter gradients (sqrt of sum of squares).
    total = sum(p.grad.norm().item() ** 2
                for p in model.parameters() if p.grad is not None) ** 0.5
    writer.add_scalar('grad_norm/total', total, epoch)
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
@torch.no_grad()
def log_weight_norms(model, writer, epoch):
    """Log weight norms per component type."""
    # Only the Cayley generator parameters ('A_upper') are tracked here;
    # everything else is ignored.
    cayley_params = (
        (pname, p) for pname, p in model.named_parameters() if 'A_upper' in pname
    )
    for pname, p in cayley_params:
        tag = pname.replace('.', '_')
        writer.add_scalar(f'weights/{tag}_norm', p.norm().item(), epoch)
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 709 |
+
# DATA
|
| 710 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 711 |
+
|
| 712 |
+
def get_dataloaders(config):
    """Build CIFAR-100 train/test DataLoaders.

    Train-side augmentation is tuned for the geometric transformer:
      * TrivialAugmentWide — continuous severity spectrum of geometric and
        photometric transforms; exercises the CM gate across the full
        quality range (mild distortion keeps CM high, severe distortion
        creates partially-degenerate simplices).
      * RandomErasing — creates degenerate manifold projections
        (zero-volume CM simplices) so the CM gate learns to close on
        corrupted regions.
      * CutMix is applied at batch level in train_epoch, not here.

    Returns:
        (train_loader, test_loader) over torchvision CIFAR-100.
    """
    import torchvision
    import torchvision.transforms as T

    # Standard CIFAR-100 channel statistics, shared by both pipelines.
    cifar100_mean = (0.4914, 0.4822, 0.4465)
    cifar100_std = (0.2470, 0.2435, 0.2616)

    train_transform = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.TrivialAugmentWide(),
        T.ToTensor(),
        T.Normalize(cifar100_mean, cifar100_std),
        T.RandomErasing(p=config.get('random_erasing_p', 0.25),
                        scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    ])
    test_transform = T.Compose([
        T.ToTensor(),
        T.Normalize(cifar100_mean, cifar100_std),
    ])

    train_ds = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=train_transform)
    test_ds = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=test_transform)

    # Loader settings common to both splits.
    shared_kwargs = dict(batch_size=config['batch_size'],
                         num_workers=config['num_workers'], pin_memory=True)
    train_loader = torch.utils.data.DataLoader(
        train_ds, shuffle=True, drop_last=True, **shared_kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_ds, shuffle=False, **shared_kwargs)

    return train_loader, test_loader
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 753 |
+
# CUTMIX — batch-level augmentation for CM gate boundary training
|
| 754 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 755 |
+
|
| 756 |
+
def cutmix_batch(images, labels, alpha=1.0):
    """Apply CutMix to a batch. Returns mixed images + label pairs + lambda.

    CutMix replaces a rectangular region of image A with image B.
    Positions inside each region have coherent geometry — valid CM simplices.
    The boundary between regions has mixed geometric context — the CM gate
    should learn to suppress these positions.

    Args:
        images: (B, C, H, W) batch
        labels: (B,) integer labels
        alpha: Beta distribution parameter (1.0 = uniform box sizes)

    Returns:
        images: (B, C, H, W) mixed batch (modified in-place)
        labels_a: (B,) labels for region A
        labels_b: (B,) labels for region B
        lam: float, fraction of image A remaining
    """
    # RNG draw order: Beta -> randperm -> randint(W) -> randint(H).
    lam = np.random.beta(alpha, alpha)
    batch = images.size(0)
    perm = torch.randperm(batch, device=images.device)

    height, width = images.shape[2], images.shape[3]
    # Box half-extents chosen so the box covers ~(1 - lam) of the area.
    cut = (1.0 - lam) ** 0.5
    half_w = int(width * cut) // 2
    half_h = int(height * cut) // 2
    cx = np.random.randint(width)
    cy = np.random.randint(height)
    x1, x2 = max(cx - half_w, 0), min(cx + half_w, width)
    y1, y2 = max(cy - half_h, 0), min(cy + half_h, height)

    # Paste the permuted batch's box over the original (in-place).
    images[:, :, y1:y2, x1:x2] = images[perm, :, y1:y2, x1:x2]
    # Recompute lambda from the clipped box so the label mix matches
    # the actual pixel fractions.
    lam_actual = 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
    return images, labels, labels[perm], lam_actual
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 794 |
+
# TRAINING (geometric losses + CutMix integrated)
|
| 795 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 796 |
+
|
| 797 |
+
def train_epoch(model, loader, optimizer, scheduler, epoch, config, writer):
    """Run one training epoch with CutMix + geometric + InfoNCE losses.

    Args:
        model: classifier exposing geometric_losses(), infonce_loss() and
            update_nce_bank() in addition to forward().
        loader: training DataLoader yielding (images, labels).
        optimizer: optimizer over model.parameters().
        scheduler: per-step LR scheduler, or None.
        epoch: current epoch index (used for periodic gradient logging).
        config: dict of hyperparameters (cutmix/label-smoothing/NCE knobs).
        writer: TensorBoard SummaryWriter for gradient-norm logging.

    Returns:
        (avg_ce, avg_geo, avg_nce, accuracy) — per-sample averages over
        the epoch; accuracy is lam-weighted on CutMix batches.

    NOTE(review): relies on a module-level `device` global — confirm it is
    defined at file scope.
    """
    model.train()
    # Running per-sample sums; divided by `total` at the end.
    total_loss = 0
    total_geo_loss = 0
    total_nce_loss = 0
    correct = 0
    total = 0

    cutmix_alpha = config.get('cutmix_alpha', 1.0)
    cutmix_prob = config.get('cutmix_prob', 0.5)
    label_smoothing = config.get('label_smoothing', 0.1)
    nce_weight = config.get('nce_weight', 0.1)
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    for batch_idx, (images, labels) in enumerate(loader):
        images = images.to(device)
        labels = labels.to(device)

        # CutMix: applied probabilistically per batch
        use_cutmix = np.random.rand() < cutmix_prob
        if use_cutmix:
            images, labels_a, labels_b, lam = cutmix_batch(
                images, labels, alpha=cutmix_alpha)
            logits = model(images)
            # CE is the lam-weighted mix of losses against both label sets.
            ce_loss = lam * criterion(logits, labels_a) + \
                      (1.0 - lam) * criterion(logits, labels_b)
            # Accuracy: count correct if matches either label
            # (weighted by each region's pixel fraction).
            pred = logits.argmax(1)
            correct += (lam * (pred == labels_a).float() +
                        (1.0 - lam) * (pred == labels_b).float()).sum().item()
        else:
            logits = model(images)
            ce_loss = criterion(logits, labels)
            correct += (logits.argmax(1) == labels).sum().item()

        # Geometric regularization — CV target + anchor spread.
        # Falls back to a zero tensor if the model reports no 'geo_total'.
        geo_losses = model.geometric_losses()
        geo_loss = geo_losses.get('geo_total', torch.tensor(0.0, device=device))

        # InfoNCE on geometric residual — discriminative pressure
        nce_losses = model.infonce_loss()
        nce_loss = nce_losses.get('nce', torch.tensor(0.0, device=device))

        loss = ce_loss + geo_loss + nce_weight * nce_loss

        optimizer.zero_grad()
        loss.backward()

        # Enqueue AFTER backward — detached residuals go into bank
        model.update_nce_bank()

        # Log gradient norms periodically (first batch of logging epochs,
        # BEFORE clipping so raw gradient magnitudes are recorded).
        if epoch % config['log_grads_every'] == 0 and batch_idx == 0:
            log_gradient_norms(model, writer, epoch)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()  # per-step (not per-epoch) LR schedule

        # Accumulate per-sample sums (batch mean * batch size).
        total_loss += ce_loss.item() * images.size(0)
        total_geo_loss += geo_loss.item() * images.size(0)
        total_nce_loss += nce_loss.item() * images.size(0)
        total += images.size(0)

    avg_ce = total_loss / total
    avg_geo = total_geo_loss / total
    avg_nce = total_nce_loss / total
    return avg_ce, avg_geo, avg_nce, correct / total
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
@torch.no_grad()
def evaluate(model, loader):
    """Return top-1 accuracy of `model` over `loader` (gradient-free)."""
    model.eval()
    n_correct, n_seen = 0, 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        preds = model(batch_images).argmax(1)
        n_correct += (preds == batch_labels).sum().item()
        n_seen += batch_images.size(0)
    return n_correct / n_seen
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
def main():
    """Train the geometric ViT on CIFAR-100 end to end.

    Orchestrates: config banner, data loading, model construction,
    AdamW + warmup-cosine schedule, the epoch loop (train, eval,
    TensorBoard logging, best-checkpoint saving, periodic geometric
    analysis), and a final summary.

    NOTE(review): depends on file-level globals — CONFIG, device,
    SummaryWriter, GeoViTClassifier, tqdm, Path, json, time, np —
    confirm they are imported/defined at module scope.
    """
    config = CONFIG.copy()

    # ── Config banner ──
    print("=" * 60)
    print(" Geometric Transformer — CIFAR-100 (CM-Validated)")
    print(f" Input: conv({config['in_channels']}→{config['conv_channels']}) + "
          f"SVD(rank={config['svd_rank']}) + "
          f"{config['patch_size']}×{config['patch_size']} patches = "
          f"{(config['img_size']//config['patch_size'])**2} tokens + CLS")
    print(f" Model: d={config['d_model']}, heads={config['n_heads']}, "
          f"layers={config['n_layers']}, anchors={config['n_anchors']}")
    print(f" CM: neighbors={config['cm_neighbors']}, "
          f"cv_target={config['cv_target']}, "
          f"cv_weight={config['cv_weight']}, "
          f"spread_weight={config['spread_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(α={config['cutmix_alpha']}, "
          f"p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" NCE: bank={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, "
          f"weight={config['nce_weight']}")
    print("=" * 60)

    # TensorBoard writer; config is stored as pretty-printed JSON text.
    writer = SummaryWriter(config['log_dir'])
    writer.add_text('config', json.dumps(config, indent=2))

    print("\nLoading CIFAR-100...")
    train_loader, test_loader = get_dataloaders(config)
    print(f" Train: {len(train_loader.dataset):,} | Test: {len(test_loader.dataset):,}")

    model = GeoViTClassifier('geo_vit_cifar100', config)
    # Prefer the model's own device-placement helper when available.
    if hasattr(model, 'network_to'):
        model.network_to(device=device, strict=False)
    else:
        model = model.to(device)

    # ── Parameter accounting ──
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n Total params: {n_params:,}")
    print(f" Trainable params: {n_trainable:,}")

    # Per-submodule parameter counts (top-level children only).
    for name, module in model.named_children():
        n = sum(p.numel() for p in module.parameters())
        if n > 0:
            print(f" {name:<20s}: {n:,}")

    writer.add_scalar('model/total_params', n_params, 0)

    # Initial anchor diagnostics
    print(f"\n Initial anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"min_dist={d['min_pairwise_dist']:.4f}")

    # Optimizer + scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    total_steps = config['epochs'] * len(train_loader)
    warmup_steps = config['warmup_epochs'] * len(train_loader)

    def lr_lambda(step):
        # Linear warmup followed by cosine decay to zero.
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # ── Training banner ──
    print(f"\n{'━'*60}")
    print(f" Training for {config['epochs']} epochs")
    print(f" Warmup: {config['warmup_epochs']} epochs, "
          f"LR: {config['lr']}, WD: {config['weight_decay']}")
    print(f" Geo reg: cv_w={config['cv_weight']}, spread_w={config['spread_weight']}")
    print(f" NCE bank: size={config['nce_bank_size']}, "
          f"temp={config['nce_temperature']}, weight={config['nce_weight']}")
    print(f" Aug: TrivialAugmentWide + CutMix(p={config['cutmix_prob']}) + "
          f"RandomErasing(p={config['random_erasing_p']})")
    print(f" TensorBoard: {config['log_dir']}")
    print(f" Geo analysis every {config['log_geo_every']} epochs")
    print(f"{'━'*60}\n")

    best_acc = 0
    save_dir = Path('geo_cifar100'); save_dir.mkdir(exist_ok=True)

    # ── Epoch loop ──
    for epoch in tqdm(range(config['epochs']), desc="Epochs"):
        t0 = time.time()

        ce_loss, geo_loss, nce_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scheduler, epoch, config, writer)

        test_acc = evaluate(model, test_loader)
        elapsed = time.time() - t0

        # Scalar logging (losses, accuracies, LR, timing, train/test gap).
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('train/ce_loss', ce_loss, epoch)
        writer.add_scalar('train/geo_loss', geo_loss, epoch)
        writer.add_scalar('train/nce_loss', nce_loss, epoch)
        writer.add_scalar('train/total_loss', ce_loss + geo_loss + nce_loss, epoch)
        writer.add_scalar('train/accuracy', train_acc, epoch)
        writer.add_scalar('test/accuracy', test_acc, epoch)
        writer.add_scalar('train/lr', lr, epoch)
        writer.add_scalar('train/epoch_time', elapsed, epoch)
        writer.add_scalar('gap/train_test', train_acc - test_acc, epoch)

        log_weight_norms(model, writer, epoch)

        # Save best checkpoint (weights moved to CPU for portability).
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({
                'state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
                'epoch': epoch,
                'test_acc': test_acc,
                'config': config,
            }, save_dir / 'best.pt')

        # Full geometric analysis periodically
        if epoch % config['log_geo_every'] == 0 or epoch == config['epochs'] - 1:
            geo_info = log_geometric_analysis(
                model, writer, epoch, test_loader, device, config)

            cv_str = ', '.join(f'{cv:.3f}' for cv in geo_info['cv_trajectory'])
            cm_str = ', '.join(f'{q:.3f}' for q in geo_info.get('cm_quality_trajectory', []))
            res_str = ', '.join(f'{r:.3f}' for r in geo_info.get('res_norms', []))
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s"
                f"\n CV=[{cv_str}]"
                f"\n CM=[{cm_str}]"
                f"\n GR=[{res_str}]")
        elif epoch % 5 == 0:
            # Lightweight progress line between full analyses.
            tqdm.write(
                f" E{epoch:>3d} ce={ce_loss:.4f} geo={geo_loss:.4f} "
                f"nce={nce_loss:.4f} "
                f"train={train_acc:.4f} test={test_acc:.4f} "
                f"best={best_acc:.4f} {elapsed:.1f}s")

    # Final summary
    print(f"\n{'═'*60}")
    print(f" CIFAR-100 RESULTS (CM-Validated)")
    print(f"{'═'*60}")
    print(f" Best test accuracy: {best_acc:.4f} ({best_acc*100:.2f}%)")
    print(f" Parameters: {n_params:,}")
    print(f" Checkpoint: {save_dir}/best.pt")
    print(f" TensorBoard: {config['log_dir']}")

    # Final geometric state + anchor diagnostics
    print(f"\n Final geometric state:")
    geo_info = log_geometric_analysis(
        model, writer, config['epochs'], test_loader, device, config)

    print(f"\n Final anchor diagnostics:")
    diag = model.anchor_diagnostics()
    for layer_name, d in diag.items():
        print(f" {layer_name}: cv={d['anchor_cv']:.4f}, "
              f"cm_pos={d['cm_positive_frac']:.3f}, "
              f"cm_mean={d['cm_mean']:.4f}")

    writer.close()
    print(f"\nDone.")
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
# Script entry point: run training only when executed directly (not on import).
if __name__ == '__main__':
    main()
|