# geolip-cv-experiments / colab_cv_loss_sweep_sequential.py
# (Hugging Face file-viewer header captured with the source:
#  "AbstractPhil's picture",
#  "Rename colab_loss_sweep_sequential.py to colab_cv_loss_sweep_sequential.py",
#  commit 2807f28, verified)
"""
CV Loss Sweep β€” Pure Noise Prediction
========================================
Random inputs β†’ MLP encoder β†’ S^(d-1) β†’ constellation β†’ predict 10 random labels.
No dataset. No structure. No signal. The model memorizes random noise→label
mappings. Any geometric regularity (CV convergence) is purely from:
- The unit hypersphere S^(d-1)
- The smooth optimizer (AdamW)
- The CV loss pressure (or lack thereof)
If CV β‰ˆ 0.20 with zero CV loss on pure noise, the constant is the sphere's
property, not a training artifact and not a data property.
Each run: 200 steps, ~2 seconds. Full sweep: ~1 minute.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import json
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ═══════════════════════════════════════════════════════════════
# Noise Dataset β€” pure random, zero structure
# ═══════════════════════════════════════════════════════════════
class NoiseDataset(torch.utils.data.Dataset):
    """Pure-noise dataset: Gaussian features paired with uniform random labels.

    There is no relationship between inputs and targets by construction; any
    structure a model exhibits must come from the model/optimizer, not the data.
    """

    def __init__(self, n_samples=5000, input_dim=128, num_classes=10, seed=0):
        # Seed so identical seeds reproduce the identical noise table.
        torch.manual_seed(seed)
        self.data = torch.randn(n_samples, input_dim)
        self.labels = torch.randint(0, num_classes, (n_samples,))

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
# ═══════════════════════════════════════════════════════════════
# Minimal MLP Encoder β†’ S^(d-1)
# ═══════════════════════════════════════════════════════════════
class NoiseEncoder(nn.Module):
    """Plain MLP mapping inputs onto the unit hypersphere S^(d-1).

    Linear → GELU → Linear → GELU → Linear → LayerNorm, then L2-normalization,
    so every embedding row has unit norm. No convolutions, no structure.
    """

    def __init__(self, input_dim=128, hidden_dim=256, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
            nn.LayerNorm(output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Project onto the sphere: unit-length rows regardless of input scale.
        return F.normalize(self.net(x), dim=-1)
# ═══════════════════════════════════════════════════════════════
# Minimal Constellation + Classifier
# ═══════════════════════════════════════════════════════════════
class NoiseConstellation(nn.Module):
    """Learnable anchors + per-group patchwork MLPs + one linear classifier.

    Deliberately minimal: no bridge, no push, no magnitude terms. Embeddings
    are compared against unit anchors; each of the n_comp anchor groups feeds
    a small MLP, and the concatenated features plus the raw embedding go
    through a single linear classifier.
    """

    def __init__(self, dim=128, n_anchors=64, n_comp=8, num_classes=10):
        super().__init__()
        self.n_anchors = n_anchors
        self.n_comp = n_comp
        # Anchor directions initialized uniformly on the sphere, then learned.
        self.anchors = nn.Parameter(F.normalize(torch.randn(n_anchors, dim), dim=-1))
        anchors_per_comp = n_anchors // n_comp
        self.patchwork = nn.ModuleList(
            nn.Sequential(nn.Linear(anchors_per_comp, 64), nn.GELU(), nn.Linear(64, 64))
            for _ in range(n_comp)
        )
        self.classifier = nn.Linear(n_comp * 64 + dim, num_classes)

    def forward(self, emb):
        # Cosine similarities between embeddings and (re-normalized) anchors.
        unit_anchors = F.normalize(self.anchors, dim=-1)
        sims = emb @ unit_anchors.T
        chunk = self.n_anchors // self.n_comp
        features = [
            mlp(sims[:, k * chunk:(k + 1) * chunk])
            for k, mlp in enumerate(self.patchwork)
        ]
        logits = self.classifier(torch.cat(features + [emb], dim=-1))
        return logits, emb
# ═══════════════════════════════════════════════════════════════
# CV Computation
# ═══════════════════════════════════════════════════════════════
def cv_loss(emb, target=0.22, n_samples=32, n_points=5):
    """Differentiable CV loss.

    Repeatedly samples `n_points` embeddings, computes the simplex volume of
    each sample via the Cayley-Menger determinant, and penalizes the squared
    gap between the coefficient of variation (std/mean) of those volumes and
    `target`.

    Args:
        emb: (B, d) embedding tensor; rows assumed on S^(d-1) — TODO confirm.
        target: desired CV of the sampled simplex volumes.
        n_samples: number of random point sets to draw.
        n_points: vertices per simplex (volume of an (n_points-1)-simplex).

    Returns:
        Scalar tensor; a grad-capable zero when too few points or too few
        non-degenerate volumes are available.
    """
    B = emb.shape[0]
    if B < n_points:
        # Not enough points to form a simplex — return zero that still
        # participates in autograd so callers can backward() unconditionally.
        return torch.tensor(0.0, device=emb.device, requires_grad=True)
    vols = []
    for _ in range(n_samples):
        # Subsample from at most the first 256 rows for speed; batches are
        # shuffled upstream, so this is still effectively a random subset.
        idx = torch.randperm(min(B, 256), device=emb.device)[:n_points]
        pts = emb[idx].unsqueeze(0)
        # Pairwise squared distances via the Gram matrix identity.
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)  # clamp tiny negatives caused by float error
        N = n_points
        # Cayley-Menger matrix: border of ones, squared distances inside.
        cm = torch.zeros(1, N+1, N+1, device=emb.device, dtype=emb.dtype)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        k = N - 1
        # Prefactor such that pf * det(cm) = (simplex volume)^2.
        pf = ((-1.0) ** (k+1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
        v2 = pf * torch.linalg.det(cm.float())
        # .item() only gates degenerate simplices; gradients flow through sqrt.
        if v2[0].item() > 1e-20:
            vols.append(v2[0].to(emb.dtype).sqrt())
    if len(vols) < 5:
        # Too few non-degenerate samples to estimate a CV.
        return torch.tensor(0.0, device=emb.device, requires_grad=True)
    vt = torch.stack(vols)
    cv = vt.std() / (vt.mean() + 1e-8)
    return (cv - target).pow(2)
def cv_metric(emb, n_samples=200, n_points=5):
    """Non-differentiable CV for monitoring.

    Monte-Carlo estimate of the coefficient of variation (std/mean) of
    Cayley-Menger simplex volumes over random `n_points`-subsets of `emb`.

    Args:
        emb: (B, d) embedding tensor.
        n_samples: number of random simplices to sample.
        n_points: vertices per simplex.

    Returns:
        float CV, or 0.0 when there are too few rows to form a simplex or
        fewer than 10 non-degenerate volumes were collected.
    """
    # BUGFIX: with fewer than n_points rows, randperm yields a short index
    # set, d2 comes out smaller than N x N, and the (N+1)x(N+1) Cayley-Menger
    # assignment below raises. Bail out the same way the "too few volumes"
    # path does.
    if emb.shape[0] < n_points:
        return 0.0
    vols = []
    with torch.no_grad():
        for _ in range(n_samples):
            idx = torch.randperm(emb.shape[0])[:n_points]
            pts = emb[idx].unsqueeze(0)
            # Pairwise squared distances via the Gram matrix identity.
            gram = torch.bmm(pts, pts.transpose(1, 2))
            norms = torch.diagonal(gram, dim1=1, dim2=2)
            d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
            d2 = F.relu(d2)  # clamp tiny negatives caused by float error
            N = n_points
            cm = torch.zeros(1, N+1, N+1, device=emb.device, dtype=emb.dtype)
            cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
            k = N - 1
            # Prefactor such that pf * det(cm) = (simplex volume)^2.
            pf = ((-1.0) ** (k+1)) / ((2.0 ** k) * (math.factorial(k) ** 2))
            v2 = pf * torch.linalg.det(cm.float())
            if v2[0].item() > 1e-20:
                vols.append(v2[0].sqrt())
    if len(vols) < 10:
        return 0.0
    vols_t = torch.stack(vols)
    return (vols_t.std() / (vols_t.mean() + 1e-8)).item()
# ═══════════════════════════════════════════════════════════════
# Single Run
# ═══════════════════════════════════════════════════════════════
def _collect_embeddings(encoder, loader, batch_size, limit=2000):
    """Encode batches from `loader` until at least `limit` embeddings exist; return the first `limit`."""
    chunks = []
    for d_batch, _ in loader:
        chunks.append(encoder(d_batch.to(DEVICE)))
        if len(chunks) * batch_size >= limit:
            break
    return torch.cat(chunks)[:limit]


def _tail_mean(values, n=20):
    """Mean of the last `n` entries (or of all entries when fewer exist); 0.0 for empty input."""
    tail = values[-n:]
    return sum(tail) / len(tail) if tail else 0.0


def run_experiment(cv_weight, cv_target, n_steps=200, dim=128, n_anchors=64,
                   batch_size=256, n_samples=5000, seed=42, pure_cv=False):
    """One configuration. Returns results dict. ~2 seconds.

    Trains a noise encoder + constellation on pure random data for `n_steps`
    optimizer steps, optionally adding the CV loss (weight `cv_weight`,
    target `cv_target`). When `pure_cv` is True the cross-entropy term is
    dropped and only the geometric (CV) pressure drives the optimizer.

    Args:
        cv_weight: multiplier on the CV loss; 0 disables it entirely.
        cv_target: target coefficient of variation passed to cv_loss.
        n_steps: number of optimizer steps.
        dim: input and embedding dimension.
        n_anchors: anchor count for the constellation.
        batch_size: DataLoader batch size (drop_last=True).
        n_samples: size of the noise dataset.
        seed: controls model init; seed+1000 controls the dataset.
        pure_cv: drop the CE term and optimize cv_weight * cv_loss only.

    Returns:
        dict with the config, final CV / effective dimension, tail-averaged
        CE and accuracy, and the CV trajectory sampled every 50 steps.
    """
    torch.manual_seed(seed)
    ds = NoiseDataset(n_samples=n_samples, input_dim=dim, num_classes=10, seed=seed + 1000)
    loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
    encoder = NoiseEncoder(input_dim=dim, hidden_dim=256, output_dim=dim).to(DEVICE)
    constellation = NoiseConstellation(dim=dim, n_anchors=n_anchors).to(DEVICE)
    params = list(encoder.parameters()) + list(constellation.parameters())
    optimizer = torch.optim.AdamW(params, lr=0.001, weight_decay=0.05)
    step = 0
    cv_history = []
    ce_history = []
    acc_history = []
    while step < n_steps:
        for data, labels in loader:
            if step >= n_steps:
                break
            data, labels = data.to(DEVICE), labels.to(DEVICE)
            emb = encoder(data)
            logits, _ = constellation(emb)
            l_ce = F.cross_entropy(logits, labels)
            if cv_weight > 0:
                l_cv = cv_loss(emb, target=cv_target, n_samples=32)
            else:
                l_cv = torch.tensor(0.0, device=DEVICE)
            if pure_cv:
                loss = cv_weight * l_cv  # NO CE, pure geometric pressure
            else:
                loss = l_ce + cv_weight * l_cv
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 1.0)
            optimizer.step()
            acc = (logits.argmax(-1) == labels).float().mean().item()
            # Measure CV every 50 steps (and on the final step).
            if step % 50 == 0 or step == n_steps - 1:
                with torch.no_grad():
                    all_emb = _collect_embeddings(encoder, loader, batch_size)
                    v_cv = cv_metric(all_emb, n_samples=200)
                    # Effective (participation-ratio) dimension from the
                    # singular-value spectrum of centered embeddings.
                    centered = all_emb[:1000] - all_emb[:1000].mean(0)
                    s = torch.linalg.svdvals(centered.float())
                    s_n = s / s.sum()
                    eff_dim = (1.0 / (s_n ** 2).sum()).item()
                cv_history.append({'step': step, 'cv': round(v_cv, 4), 'eff_dim': round(eff_dim, 1)})
            ce_history.append(l_ce.item())
            acc_history.append(acc)
            step += 1
    # Final measurement over a fresh 2000-sample embedding set.
    with torch.no_grad():
        all_emb = _collect_embeddings(encoder, loader, batch_size)
        final_cv = cv_metric(all_emb, n_samples=300)
        centered = all_emb[:1000] - all_emb[:1000].mean(0)
        s = torch.linalg.svdvals(centered.float())
        s_n = s / s.sum()
        final_dim = (1.0 / (s_n ** 2).sum()).item()
    return {
        'cv_weight': cv_weight,
        'cv_target': cv_target,
        'pure_cv': pure_cv,
        'seed': seed,
        'n_steps': n_steps,
        'dim': dim,
        'final_cv': round(final_cv, 4),
        'final_dim': round(final_dim, 1),
        # BUGFIX: average over the actual tail length instead of a hard-coded
        # 20 (the original divided by 20 even when fewer entries existed,
        # which misreports final_ce/final_acc for short runs).
        'final_ce': round(_tail_mean(ce_history), 4),
        'final_acc': round(_tail_mean(acc_history) * 100, 1),
        'cv_trajectory': cv_history,
    }
# ═══════════════════════════════════════════════════════════════
# SWEEP
# ═══════════════════════════════════════════════════════════════
# Announce the sweep before running anything (output strings carry the
# file's original encoding artifacts and are left byte-identical).
print("=" * 80)
print("CV LOSS SWEEP β€” PURE NOISE PREDICTION")
print(" Random inputs β†’ MLP β†’ S^(d-1) β†’ constellation β†’ 10 random labels")
print(" No data structure. No signal. Pure sphere geometry + optimizer.")
print(" 200 steps per run, ~2s each.")
print("=" * 80)
# Sweep configurations as (cv_weight, cv_target, label, seed) tuples.
# Labels double as switches: "pure_cv*" entries drop the CE term and "dim*"
# entries override the embedding dimension (both handled after this list).
configs = [
    # ── NO CV LOSS — baseline, five seeds ──
    (0.0, 0.0, "no_cv", 42),
    (0.0, 0.0, "no_cv_s2", 123),
    (0.0, 0.0, "no_cv_s3", 456),
    (0.0, 0.0, "no_cv_s4", 789),
    (0.0, 0.0, "no_cv_s5", 1337),
    # ── CORRECT TARGET, VARYING WEIGHT ──
    (0.001, 0.22, "w0.001_t0.22", 42),
    (0.01, 0.22, "w0.01_t0.22", 42),
    (0.1, 0.22, "w0.1_t0.22", 42),
    (0.5, 0.22, "w0.5_t0.22", 42),
    (1.0, 0.22, "w1.0_t0.22", 42),
    (5.0, 0.22, "w5.0_t0.22", 42),
    (10.0, 0.22, "w10_t0.22", 42),
    (50.0, 0.22, "w50_t0.22", 42),
    (100.0, 0.22, "w100_t0.22", 42),
    # ── WRONG TARGETS, LOW WEIGHT (gentle push) ──
    (0.01, 0.00, "w0.01_t0.00", 42),
    (0.01, 0.05, "w0.01_t0.05", 42),
    (0.01, 0.10, "w0.01_t0.10", 42),
    (0.01, 0.30, "w0.01_t0.30", 42),
    (0.01, 0.50, "w0.01_t0.50", 42),
    (0.01, 0.80, "w0.01_t0.80", 42),
    (0.01, 1.00, "w0.01_t1.00", 42),
    (0.01, 2.00, "w0.01_t2.00", 42),
    # ── WRONG TARGETS, MEDIUM WEIGHT (strong push) ──
    (1.0, 0.00, "w1_t0.00", 42),
    (1.0, 0.05, "w1_t0.05", 42),
    (1.0, 0.50, "w1_t0.50", 42),
    (1.0, 0.80, "w1_t0.80", 42),
    (1.0, 1.00, "w1_t1.00", 42),
    # ── WRONG TARGETS, EXTREME WEIGHT (maximum force) ──
    (100.0, 0.00, "w100_t0.00", 42),
    (100.0, 0.05, "w100_t0.05", 42),
    (100.0, 0.10, "w100_t0.10", 42),
    (100.0, 0.50, "w100_t0.50", 42),
    (100.0, 0.80, "w100_t0.80", 42),
    (100.0, 1.00, "w100_t1.00", 42),
    # ── CV LOSS ONLY, NO CE (pure geometric pressure) ──
    (1.0, 0.22, "pure_cv_t0.22", 42),  # mark for CE override
    (1.0, 0.05, "pure_cv_t0.05", 42),
    (1.0, 0.50, "pure_cv_t0.50", 42),
    (1.0, 0.80, "pure_cv_t0.80", 42),
    (1.0, 1.00, "pure_cv_t1.00", 42),
    # ── DIMENSION SWEEP (does dim change the constant?) ──
    (0.0, 0.0, "dim16", 42),  # mark for dim override
    (0.0, 0.0, "dim32", 42),
    (0.0, 0.0, "dim64", 42),
    (0.0, 0.0, "dim256", 42),
    (0.0, 0.0, "dim512", 42),
]
# Special handling: labels select the pure-CV mode and per-run dimension.
pure_cv_labels = {l for _, _, l, _ in configs if l.startswith("pure_cv")}
dim_overrides = {"dim16": 16, "dim32": 32, "dim64": 64, "dim256": 256, "dim512": 512}
all_results = []
total = len(configs)
print(f"\n Running {total} configurations, 200 steps each")
print(f" Estimated time: ~{total * 3}s\n")
for i, (cv_w, cv_t, label, seed) in enumerate(configs):
    t0 = time.time()
    # Default to d=128 unless the label requests a dimension override.
    dim = dim_overrides.get(label, 128)
    is_pure_cv = label in pure_cv_labels
    print(f"[{i+1:2d}/{total}] {label:20s} w={cv_w:<8.3f} t={cv_t:<5.2f} d={dim:<4d}", end=" ", flush=True)
    result = run_experiment(
        cv_weight=cv_w,
        cv_target=cv_t,
        n_steps=200,
        dim=dim,
        seed=seed,
        pure_cv=is_pure_cv,
    )
    # Attach the human-readable label so the analysis sections can filter on it.
    result['label'] = label
    elapsed = time.time() - t0
    print(f"β†’ CV={result['final_cv']:.4f} dim={result['final_dim']:.0f} "
          f"acc={result['final_acc']:.0f}% ({elapsed:.1f}s)")
    all_results.append(result)
# ═══════════════════════════════════════════════════════════════
# SUMMARY TABLE
# ═══════════════════════════════════════════════════════════════
# One row per run. The mark after FINAL_CV flags proximity to the
# hypothesized sphere constant: inner band [0.17, 0.24], outer band
# [0.15, 0.27], outside. (The mark strings carry the file's original
# encoding artifacts — ✓/✗ — and are left byte-identical.)
print(f"\n\n{'='*90}")
print(f"{'LABEL':20s} {'CV_W':>8s} {'CV_T':>6s} {'DIM':>5s} {'FINAL_CV':>9s} {'EFF_DIM':>8s} {'ACC%':>6s} {'CE':>8s}")
print(f"{'─'*90}")
for r in all_results:
    cv_mark = "βœ“" if 0.17 <= r['final_cv'] <= 0.24 else "~" if 0.15 <= r['final_cv'] <= 0.27 else "βœ—"
    print(f"{r['label']:20s} {r['cv_weight']:>8.3f} {r['cv_target']:>6.2f} {r['dim']:>5d} "
          f"{r['final_cv']:>8.4f}{cv_mark} {r['final_dim']:>7.0f} {r['final_acc']:>5.0f}% {r['final_ce']:>8.4f}")
# ═══════════════════════════════════════════════════════════════
# ANALYSIS
# ═══════════════════════════════════════════════════════════════
print(f"\n\n{'='*90}")
print("ANALYSIS")
print(f"{'='*90}")
# [1] Baseline: with zero CV loss at d=128, where does CV land on its own?
no_cv = [r for r in all_results if r['cv_weight'] == 0 and r['dim'] == 128]
if no_cv:
    cvs = [r['final_cv'] for r in no_cv]
    dims = [r['final_dim'] for r in no_cv]
    print(f"\n [1] NO CV LOSS, PURE NOISE (d=128, {len(no_cv)} seeds):")
    print(f" CV: mean={sum(cvs)/len(cvs):.4f} min={min(cvs):.4f} max={max(cvs):.4f} spread={max(cvs)-min(cvs):.4f}")
    print(f" Dim: mean={sum(dims)/len(dims):.1f}")
    within_band = sum(1 for c in cvs if 0.17 <= c <= 0.24)
    print(f" Within [0.17, 0.24]: {within_band}/{len(cvs)}")
# [2] Weight sweep at the "correct" target 0.22.
weight_sweep = [r for r in all_results if r['cv_target'] == 0.22 and r['dim'] == 128 and not r['pure_cv']]
if weight_sweep:
    print(f"\n [2] WEIGHT SWEEP (target=0.22, d=128):")
    for r in sorted(weight_sweep, key=lambda x: x['cv_weight']):
        print(f" w={r['cv_weight']:>8.3f} β†’ CV={r['final_cv']:.4f} acc={r['final_acc']:.0f}%")
# [3] Target sweep at fixed weights. Float equality is safe here because the
# weights compared against are the exact literals used in `configs`.
for w in [0.01, 1.0, 100.0]:
    target_runs = [r for r in all_results if r['cv_weight'] == w and r['dim'] == 128 and not r['pure_cv']]
    if len(target_runs) > 2:
        print(f"\n [3] TARGET SWEEP (w={w}, d=128):")
        for r in sorted(target_runs, key=lambda x: x['cv_target']):
            cv_mark = "βœ“" if 0.17 <= r['final_cv'] <= 0.24 else "βœ—"
            print(f" target={r['cv_target']:.2f} β†’ CV={r['final_cv']:.4f}{cv_mark} acc={r['final_acc']:.0f}%")
# [4] Dimension sweep (no CV loss): does the constant depend on d?
dim_runs = [r for r in all_results if r['label'].startswith('dim')]
if dim_runs:
    print(f"\n [4] DIMENSION SWEEP (no CV loss):")
    for r in sorted(dim_runs, key=lambda x: x['dim']):
        print(f" d={r['dim']:>4d} β†’ CV={r['final_cv']:.4f} eff_dim={r['final_dim']:.0f}")
# [5] Key question: can extreme weight (>=100) force CV away from ~0.20?
extreme = [r for r in all_results if r['cv_weight'] >= 100 and r['dim'] == 128]
if extreme:
    print(f"\n [5] EXTREME FORCE (wβ‰₯100, d=128):")
    for r in sorted(extreme, key=lambda x: x['cv_target']):
        delta = abs(r['final_cv'] - 0.20)
        print(f" target={r['cv_target']:.2f} β†’ CV={r['final_cv']:.4f} (Ξ” from 0.20: {delta:.4f}) acc={r['final_acc']:.0f}%")
# [6] Trajectories: does CV start elsewhere and converge? (first 5 runs only)
print(f"\n [6] CV TRAJECTORIES (step 0 β†’ step 200):")
for r in all_results[:5]:  # first 5 runs
    traj = r.get('cv_trajectory', [])
    if len(traj) >= 2:
        first = traj[0]['cv']
        last = traj[-1]['cv']
        print(f" {r['label']:20s}: {first:.4f} β†’ {last:.4f} (Ξ”={last-first:+.4f})")
# Persist raw results. NOTE(review): default=str silently stringifies any
# non-JSON-serializable value — acceptable for a throwaway experiment log.
with open('cv_sweep_results.json', 'w') as f:
    json.dump(all_results, f, indent=2, default=str)
print(f"\n Raw results saved to cv_sweep_results.json")
print(f"\n{'='*80}")
print("CV SWEEP COMPLETE")
print(f"{'='*80}")