# geolip-cv-experiments / geolip_loss_profiler.py
# Uploaded by AbstractPhil ("Create geolip_loss_profiler.py", commit 9b40529, verified)
"""
Loss Spectrum Profiler β€” Standalone
=====================================
Builds its own model + noise data. Profiles every loss computation
in the GeoLIP pipeline with CUDA-synced microsecond timing.
Zero external dependencies beyond torch. Single cell.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
# Prefer GPU; everything below places tensors/modules on DEVICE explicitly.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Allow TF32 matmul/conv kernels (Ampere+): faster, slightly lower precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# ═══════════════════════════════════════════════════════════════
# Config — matches our architecture
# ═══════════════════════════════════════════════════════════════
DIM = 256          # embedding dimension
N_ANCHORS = 256    # number of learnable anchor directions
N_COMP = 8         # patchwork component count
D_COMP = 64        # per-component output width
BATCH = 256        # profiling batch size
NUM_CLASSES = 100  # fake classification-head size
# ═══════════════════════════════════════════════════════════════
# Minimal model components (self-contained, no imports)
# ═══════════════════════════════════════════════════════════════
class ProfileEncoder(nn.Module):
def __init__(self, dim=256):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
nn.MaxPool2d(2),
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
nn.MaxPool2d(2),
nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
nn.MaxPool2d(2),
nn.Conv2d(256, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),
nn.Conv2d(384, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),
nn.MaxPool2d(2),
nn.AdaptiveAvgPool2d(1), nn.Flatten(),
)
self.proj = nn.Sequential(nn.Linear(384, dim), nn.LayerNorm(dim))
def forward(self, x):
feat = self.features(x)
return F.normalize(self.proj(feat), dim=-1), feat[:, :1] # emb, fake raw_mag
class ProfilePatchwork(nn.Module):
def __init__(self, n_anchors=256, n_comp=8, d_comp=64):
super().__init__()
apc = n_anchors // n_comp
self.n_comp = n_comp
self.comps = nn.ModuleList([
nn.Sequential(nn.Linear(apc, d_comp * 2), nn.GELU(), nn.Linear(d_comp * 2, d_comp))
for _ in range(n_comp)
])
def forward(self, tri):
apc = tri.shape[1] // self.n_comp
parts = []
for k in range(self.n_comp):
parts.append(self.comps[k](tri[:, k*apc:(k+1)*apc]))
return torch.cat(parts, dim=-1)
# Build all components (module-level singletons shared by every timing below).
print("Building profile model...")
encoder = ProfileEncoder(DIM).to(DEVICE)
# Learnable anchor directions, unit-norm at init.
anchors = nn.Parameter(F.normalize(torch.randn(N_ANCHORS, DIM, device=DEVICE), dim=-1))
patchwork = ProfilePatchwork(N_ANCHORS, N_COMP, D_COMP).to(DEVICE)
# Maps patchwork features back to anchor-assignment logits.
bridge = nn.Linear(N_COMP * D_COMP, N_ANCHORS).to(DEVICE)
# Classifier over [soft assignments | patchwork features | embedding].
task_head = nn.Sequential(
    nn.Linear(N_ANCHORS + N_COMP * D_COMP + DIM, N_COMP * D_COMP),
    nn.GELU(), nn.LayerNorm(N_COMP * D_COMP), nn.Dropout(0.1),
    nn.Linear(N_COMP * D_COMP, NUM_CLASSES),
).to(DEVICE)
# Fake batch — CIFAR-shaped random noise standing in for two augmented views.
v1 = torch.randn(BATCH, 3, 32, 32, device=DEVICE)
v2 = torch.randn(BATCH, 3, 32, 32, device=DEVICE)
targets = torch.randint(0, NUM_CLASSES, (BATCH,), device=DEVICE)
labels_nce = torch.arange(BATCH, device=DEVICE)  # InfoNCE diagonal positives
# Pre-compute intermediates once, WITHOUT a graph: the forward-only loss
# timings reuse these, so none of them can backprop into the encoder.
with torch.no_grad():
    emb1, raw_mag1 = encoder(v1)
    emb2, raw_mag2 = encoder(v2)
    anchors_n = F.normalize(anchors, dim=-1)
    cos1 = emb1 @ anchors_n.T
    cos2 = emb2 @ anchors_n.T
    tri1 = 1.0 - cos1  # "triangulation" = cosine distance to every anchor
    tri2 = 1.0 - cos2
    assign1 = F.softmax(cos1 / 0.1, dim=-1)  # sharp soft anchor assignment
    assign2 = F.softmax(cos2 / 0.1, dim=-1)
    pw1 = patchwork(tri1)
    pw2 = patchwork(tri2)
    bridge1 = bridge(pw1)
    feat1 = torch.cat([assign1, pw1, emb1], dim=-1)
    logits1 = task_head(feat1)
all_params = (list(encoder.parameters()) + [anchors] +
              list(patchwork.parameters()) + list(bridge.parameters()) +
              list(task_head.parameters()))
print(f" Device: {DEVICE}")
# NOTE(review): "Γ—" below looks like mojibake for "×" — confirm file encoding.
print(f" Batch: {BATCH}, Dim: {DIM}, Anchors: {N_ANCHORS}, Comp: {N_COMP}Γ—{D_COMP}")
n_params = sum(p.numel() for p in all_params)
print(f" Parameters: {n_params:,}")
# ═══════════════════════════════════════════════════════════════
# Timer
# ═══════════════════════════════════════════════════════════════
def timed(name, fn, n_runs=30, warmup=5):
    """Time fn() and return (last result, average milliseconds).

    Fix: the original called torch.cuda.synchronize() unconditionally, which
    raises on a CPU-only machine even though DEVICE falls back to "cpu".
    Synchronization is now a no-op when CUDA is unavailable, so GPU timings
    are unchanged and the profiler also runs on CPU.

    `name` is unused here; it is kept for signature symmetry with record().
    """
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    r = None
    for _ in range(warmup):
        r = fn()
    sync()
    times = []
    for _ in range(n_runs):
        sync()  # drain pending GPU work before starting the clock
        t0 = time.perf_counter()
        r = fn()
        sync()  # ensure fn()'s kernels actually finished before stopping it
        times.append((time.perf_counter() - t0) * 1000)
    avg = sum(times) / len(times)
    return r, avg
results = []  # accumulated (name, avg_ms) pairs from every section
def record(name, fn, **kw):
    """Profile fn via timed() and append (name, avg_ms) to the global results."""
    avg_ms = timed(name, fn, **kw)[1]
    results.append((name, avg_ms))
    return avg_ms
# ═══════════════════════════════════════════════════════════════
# SECTION 1: Forward Components
# ═══════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("SECTION 1: FORWARD PASS COMPONENTS")
print(f"{'='*80}\n")
# Stage-by-stage forward costs, each fed the precomputed (no-grad) inputs.
record("encoder(v1)", lambda: encoder(v1))
record("triangulation (emb@A.T)", lambda: emb1 @ anchors_n.T)
record("soft_assign (softmax)", lambda: F.softmax(cos1 / 0.1, dim=-1))
record("patchwork(tri)", lambda: patchwork(tri1))
record("bridge(pw)", lambda: bridge(pw1))
record("task_head(feat)", lambda: task_head(feat1))
def _full_fwd():
    # Full pipeline over both views. a2 and b1 are intentionally computed and
    # discarded so their cost is included in the "FULL forward" measurement.
    e1, _ = encoder(v1)
    e2, _ = encoder(v2)
    an = F.normalize(anchors, dim=-1)
    c1 = e1 @ an.T; c2 = e2 @ an.T
    t1 = 1 - c1; t2 = 1 - c2
    a1 = F.softmax(c1/0.1, dim=-1); a2 = F.softmax(c2/0.1, dim=-1)
    p1 = patchwork(t1); p2 = patchwork(t2)
    b1 = bridge(p1)
    f1 = torch.cat([a1, p1, e1], -1)
    return task_head(f1)
record("FULL forward (both views)", _full_fwd)
# ═══════════════════════════════════════════════════════════════
# SECTION 2: Individual Loss Terms (forward only)
# ═══════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("SECTION 2: INDIVIDUAL LOSS TERMS (forward only)")
print(f"{'='*80}\n")
# Supervised CE on the precomputed task-head logits.
record("CE (cross_entropy)", lambda: F.cross_entropy(logits1, targets))
# InfoNCE on raw embeddings: B×B similarity, matching view = positive (temp 0.07).
record("NCE_emb (BΓ—B + CE)", lambda: F.cross_entropy(
    emb1 @ emb2.T / 0.07, labels_nce))
# Same contrastive pattern on L2-normalized patchwork features (temp 0.1).
record("NCE_pw (norm + BΓ—B + CE)", lambda: F.cross_entropy(
    F.normalize(pw1, dim=-1) @ F.normalize(pw2, dim=-1).T / 0.1, labels_nce))
# ... and on normalized triangulation (anchor-distance) vectors.
record("NCE_tri (norm + BΓ—B + CE)", lambda: F.cross_entropy(
    F.normalize(tri1, dim=-1) @ F.normalize(tri2, dim=-1).T / 0.1, labels_nce))
# ... and on soft anchor assignments (already distributions; no re-norm).
record("NCE_assign (BΓ—B + CE)", lambda: F.cross_entropy(
    assign1 @ assign2.T / 0.1, labels_nce))
def _bridge_loss():
    # Soft cross-entropy: bridge logits against detached soft assignments.
    at = assign1.detach()
    return -(at * F.log_softmax(bridge1, dim=-1)).sum(-1).mean()
record("Bridge (soft CE)", _bridge_loss)
def _assign_bce():
    # BCE pushing soft assignments toward a one-hot of the nearest anchor.
    nearest = cos1.argmax(dim=-1)
    hard = torch.zeros_like(assign1)
    hard.scatter_(1, nearest.unsqueeze(1), 1.0)
    # clamp keeps BCE finite at exactly-0/1 probabilities
    return F.binary_cross_entropy(assign1.float().clamp(1e-7, 1-1e-7), hard.float())
record("Assign BCE", _assign_bce)
# Pulls each embedding toward its best-matching anchor.
record("Attraction (max + mean)", lambda: (1.0 - cos1.max(dim=1).values).mean())
def _spread():
    # Penalizes positive cosine similarity between distinct anchors.
    a = F.normalize(anchors, dim=-1)
    sim = a @ a.T
    mask = ~torch.eye(N_ANCHORS, dtype=torch.bool, device=DEVICE)
    return F.relu(sim[mask]).mean()
record("Spread (AΓ—A + relu)", _spread)
# 1-NN label-agreement proxy; fill_diagonal_(-1) excludes self-matches (the
# matmul result is a fresh tensor, so the in-place fill mutates nothing shared).
record("kNN (BΓ—B + argmax)", lambda: (
    targets[(emb1 @ emb1.T).fill_diagonal_(-1).argmax(1)] == targets).float().mean())
# ═══════════════════════════════════════════════════════════════
# SECTION 3: CV Loss — Old vs Batched
# ═══════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("SECTION 3: CV LOSS β€” OLD SEQUENTIAL vs BATCHED")
print(f"{'='*80}\n")
# Old sequential
def _cv_old(n_samples=64):
vols = []
for _ in range(n_samples):
idx = torch.randperm(min(BATCH, 256), device=DEVICE)[:5]
pts = emb1[idx].unsqueeze(0)
gram = torch.bmm(pts, pts.transpose(1, 2))
norms = torch.diagonal(gram, dim1=1, dim2=2)
d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=pts.dtype)
cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
v2 = pf * torch.linalg.det(cm.float())
if v2[0].item() > 1e-20:
vols.append(v2[0].sqrt())
if len(vols) < 5:
return torch.tensor(0.0, device=DEVICE)
vt = torch.stack(vols)
return ((vt.std() / (vt.mean() + 1e-8)) - 0.22).pow(2)
# Batched
def _cv_batched(n_samples=64):
pool = min(BATCH, 256)
rand_keys = torch.rand(n_samples, pool, device=DEVICE)
indices = rand_keys.argsort(dim=1)[:, :5]
pts = emb1[:pool][indices]
gram = torch.bmm(pts, pts.transpose(1, 2))
norms = torch.diagonal(gram, dim1=1, dim2=2)
d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
cm = torch.zeros(n_samples, 6, 6, device=DEVICE, dtype=pts.dtype)
cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2
pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
dets = pf * torch.linalg.det(cm.float())
valid = dets > 1e-20
vols = dets[valid].sqrt()
if vols.shape[0] < 5:
return torch.tensor(0.0, device=DEVICE)
return ((vols.std() / (vols.mean() + 1e-8)) - 0.22).pow(2)
# Time both CV variants at several sample counts (fewer runs: det() is slow).
for ns in [32, 64, 128, 200]:
    # `ns=ns` binds the loop value at definition time (avoids late-binding capture).
    record(f"CV OLD n={ns}", lambda ns=ns: _cv_old(ns), n_runs=10)
    record(f"CV BATCH n={ns}", lambda ns=ns: _cv_batched(ns), n_runs=10)
# Non-differentiable metric versions: no_grad wrappers measuring inference-only cost.
def _cv_metric_old(n_samples=200):
    with torch.no_grad():
        return _cv_old(n_samples)
def _cv_metric_batch(n_samples=200):
    with torch.no_grad():
        return _cv_batched(n_samples)
# record() calls these with no arguments, so the default n_samples=200 applies.
record("CV metric OLD n=200", _cv_metric_old, n_runs=10)
record("CV metric BATCH n=200", _cv_metric_batch, n_runs=10)
# ═══════════════════════════════════════════════════════════════
# SECTION 4: Backward costs
# ═══════════════════════════════════════════════════════════════
print(f"\n{'='*80}")
print("SECTION 4: BACKWARD COSTS (forward + backward)")
print(f"{'='*80}\n")
def _bwd(loss_fn):
    """Clear stale gradients, evaluate loss_fn, and backprop when it has a graph."""
    for param in (p for p in all_params if p.grad is not None):
        param.grad.zero_()
    loss = loss_fn()
    differentiable = torch.is_tensor(loss) and loss.requires_grad
    if differentiable:
        loss.backward()
    return loss
# Each fwd+bwd probe rebuilds its graph from scratch — backward needs a fresh forward.
def _fwd_bwd_ce():
    """Fresh forward through the full pipeline, then classification-CE backward."""
    emb, _ = encoder(v1)
    anc = F.normalize(anchors, dim=-1)
    cos = emb @ anc.T
    tri = 1 - cos
    assign = F.softmax(cos / 0.1, dim=-1)
    pw = patchwork(tri)
    feats = torch.cat([assign, pw, emb], dim=-1)
    return _bwd(lambda: F.cross_entropy(task_head(feats), targets))
def _fwd_bwd_nce_emb():
    """Re-encode both views, then backprop the embedding InfoNCE loss."""
    emb_a, _ = encoder(v1)
    emb_b, _ = encoder(v2)
    return _bwd(lambda: F.cross_entropy(emb_a @ emb_b.T / 0.07, labels_nce))
def _fwd_bwd_nce_pw():
    """Fresh forward to patchwork features on both views, then NCE backward."""
    emb_a, _ = encoder(v1)
    emb_b, _ = encoder(v2)
    anc = F.normalize(anchors, dim=-1)
    pw_a = patchwork(1 - emb_a @ anc.T)
    pw_b = patchwork(1 - emb_b @ anc.T)
    def nce_loss():
        sim = F.normalize(pw_a, dim=-1) @ F.normalize(pw_b, dim=-1).T
        return F.cross_entropy(sim / 0.1, labels_nce)
    return _bwd(nce_loss)
def _fwd_bwd_cv_old():
    # NOTE(review): `e` is unused — _cv_old() reads the module-level emb1,
    # which was computed under torch.no_grad(). The loss therefore carries no
    # graph, _bwd() skips backward, and this row actually times encoder
    # forward + CV forward only. Confirm whether CV backward was intended.
    e, _ = encoder(v1)
    return _bwd(lambda: _cv_old(64))
def _fwd_bwd_cv_batch():
    # NOTE(review): same caveat as the "CV old" probe — _cv_batched() uses the
    # no-grad emb1 rather than this fresh `e`, so no backward pass runs here;
    # the measurement is encoder forward + CV forward only. Confirm intent.
    e, _ = encoder(v1)
    return _bwd(lambda: _cv_batched(64))
def _fwd_bwd_bridge():
    """Fresh forward to bridge logits, then soft-CE backward vs detached assignments."""
    emb, _ = encoder(v1)
    anc = F.normalize(anchors, dim=-1)
    cos = emb @ anc.T
    tri = 1 - cos
    assign_target = F.softmax(cos / 0.1, dim=-1).detach()
    logits = bridge(patchwork(tri))
    def soft_ce():
        return -(assign_target * F.log_softmax(logits, dim=-1)).sum(-1).mean()
    return _bwd(soft_ce)
# Full fwd+bwd timings; fewer runs/warmups since backward roughly doubles cost.
record("fwd+bwd CE", _fwd_bwd_ce, n_runs=10, warmup=3)
record("fwd+bwd NCE_emb", _fwd_bwd_nce_emb, n_runs=10, warmup=3)
record("fwd+bwd NCE_pw", _fwd_bwd_nce_pw, n_runs=10, warmup=3)
record("fwd+bwd CV old", _fwd_bwd_cv_old, n_runs=10, warmup=3)
record("fwd+bwd CV batch", _fwd_bwd_cv_batch, n_runs=10, warmup=3)
record("fwd+bwd Bridge", _fwd_bwd_bridge, n_runs=10, warmup=3)
# ═══════════════════════════════════════════════════════════════
# REPORT
# ═══════════════════════════════════════════════════════════════
print(f"\n\n{'='*80}")
print("FULL TIMING REPORT (sorted by cost)")
print(f"{'='*80}\n")
# Render each timing as a bar scaled to its share of the summed time.
# (The sum double-counts — components overlap with FULL/fwd+bwd rows — so
# percentages are relative weights, not a wall-clock budget.)
total = sum(ms for _, ms in results)
for name, ms in sorted(results, key=lambda x: -x[1]):
    pct = 100 * ms / total if total > 0 else 0
    bar_len = int(pct / 2)  # one cell per 2%, 40 cells max at 80%
    # Fix: bar glyphs were mojibake ("β–ˆ"/"β–‘") from a bad encoding round-trip.
    bar = "█" * bar_len + "░" * (40 - bar_len)
    print(f" {name:35s} {ms:>9.3f}ms {bar} {pct:>5.1f}%")
print(f" {'─'*90}")
print(f" {'SUM':35s} {total:>9.3f}ms")
# CV speedup summary — pairs OLD vs BATCH timings at matching sample counts.
print(f"\n{'='*80}")
print("CV SPEEDUP SUMMARY")
print(f"{'='*80}")
cv_pairs = {}
for name, ms in results:
    if name.startswith("CV "):
        key = name.split("n=")[1] if "n=" in name else "?"
        # Fix: "CV metric * n=200" rows also start with "CV " and previously
        # collided with — and overwrote — the grad-enabled "CV * n=200" pair.
        # Keep them under a distinct key so both speedups are reported.
        if "metric" in name:
            key = "metric " + key
        tag = "old" if "OLD" in name else "batch"
        cv_pairs.setdefault(key, {})[tag] = ms
for k in sorted(cv_pairs.keys()):
    p = cv_pairs[k]
    if 'old' in p and 'batch' in p:
        speedup = p['old'] / p['batch'] if p['batch'] > 0 else 0
        # Fix: the arrow was mojibake ("β†’") in the original output string.
        print(f" n={k:>4s}: {p['old']:>8.2f}ms → {p['batch']:>8.2f}ms ({speedup:.1f}x speedup)")
# Per-step estimate
print(f"\n{'='*80}")
print("PER-STEP ESTIMATE")
print(f"{'='*80}")
# Pull headline numbers back out of the results list by exact row name;
# default to 0 if a row is missing so the report still prints.
fwd_time = next((ms for n, ms in results if n == "FULL forward (both views)"), 0)
bwd_ce = next((ms for n, ms in results if n == "fwd+bwd CE"), 0)
# NOTE(review): the "fwd+bwd CV" probes may exclude backward (their loss is
# built from no-grad emb1 — see _fwd_bwd_cv_old), so treat these as lower bounds.
bwd_cv_old = next((ms for n, ms in results if n == "fwd+bwd CV old"), 0)
bwd_cv_new = next((ms for n, ms in results if n == "fwd+bwd CV batch"), 0)
print(f" Forward (both views): {fwd_time:.2f}ms")
print(f" fwd+bwd CE: {bwd_ce:.2f}ms")
print(f" fwd+bwd CV (old): {bwd_cv_old:.2f}ms")
print(f" fwd+bwd CV (batched): {bwd_cv_new:.2f}ms")
if bwd_cv_old > 0 and bwd_cv_new > 0:
    saved = bwd_cv_old - bwd_cv_new
    print(f" CV savings per step: {saved:.2f}ms ({saved/bwd_cv_old*100:.0f}%)")
print(f"\n{'='*80}")
print("PROFILING COMPLETE")
print(f"{'='*80}")