| """ |
| Loss Spectrum Profiler β Standalone |
| ===================================== |
| Builds its own model + noise data. Profiles every loss computation |
| in the GeoLIP pipeline with CUDA-synced microsecond timing. |
| |
| Zero external dependencies beyond torch. Single cell. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import time |
| import math |
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
|
|
| |
| |
| |
| DIM = 256 |
| N_ANCHORS = 256 |
| N_COMP = 8 |
| D_COMP = 64 |
| BATCH = 256 |
| NUM_CLASSES = 100 |
|
|
|
|
| |
| |
| |
|
|
class ProfileEncoder(nn.Module):
    """VGG-style conv encoder: four conv stages, global pool, linear projection.

    forward(x) returns a tuple:
      * the L2-normalized projection, shape (B, dim)
      * the first channel of the pooled feature vector, shape (B, 1)
    """

    def __init__(self, dim=256):
        super().__init__()
        # Each stage: conv-BN-GELU twice, then 2x2 max pooling. Widths double-ish
        # per stage; layer ordering matches a hand-unrolled Sequential exactly.
        stages = []
        width_in = 3
        for width_out in (64, 128, 256, 384):
            stages += [
                nn.Conv2d(width_in, width_out, 3, padding=1),
                nn.BatchNorm2d(width_out),
                nn.GELU(),
                nn.Conv2d(width_out, width_out, 3, padding=1),
                nn.BatchNorm2d(width_out),
                nn.GELU(),
                nn.MaxPool2d(2),
            ]
            width_in = width_out
        stages += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]
        self.features = nn.Sequential(*stages)
        self.proj = nn.Sequential(nn.Linear(384, dim), nn.LayerNorm(dim))

    def forward(self, x):
        pooled = self.features(x)
        embedding = F.normalize(self.proj(pooled), dim=-1)
        return embedding, pooled[:, :1]
|
|
|
|
class ProfilePatchwork(nn.Module):
    """Independent 2-layer MLPs over equal contiguous slices of the input.

    The (B, n_anchors) triangulation vector is cut into n_comp slices of
    n_anchors // n_comp columns each; slice k is fed through its own MLP and
    the d_comp-wide outputs are concatenated to shape (B, n_comp * d_comp).
    """

    def __init__(self, n_anchors=256, n_comp=8, d_comp=64):
        super().__init__()
        slice_width = n_anchors // n_comp
        self.n_comp = n_comp
        self.comps = nn.ModuleList(
            nn.Sequential(
                nn.Linear(slice_width, d_comp * 2),
                nn.GELU(),
                nn.Linear(d_comp * 2, d_comp),
            )
            for _ in range(n_comp)
        )

    def forward(self, tri):
        w = tri.shape[1] // self.n_comp
        outs = [mlp(tri[:, i * w:(i + 1) * w]) for i, mlp in enumerate(self.comps)]
        return torch.cat(outs, dim=-1)
|
|
|
|
| |
| print("Building profile model...") |
| encoder = ProfileEncoder(DIM).to(DEVICE) |
| anchors = nn.Parameter(F.normalize(torch.randn(N_ANCHORS, DIM, device=DEVICE), dim=-1)) |
| patchwork = ProfilePatchwork(N_ANCHORS, N_COMP, D_COMP).to(DEVICE) |
| bridge = nn.Linear(N_COMP * D_COMP, N_ANCHORS).to(DEVICE) |
| task_head = nn.Sequential( |
| nn.Linear(N_ANCHORS + N_COMP * D_COMP + DIM, N_COMP * D_COMP), |
| nn.GELU(), nn.LayerNorm(N_COMP * D_COMP), nn.Dropout(0.1), |
| nn.Linear(N_COMP * D_COMP, NUM_CLASSES), |
| ).to(DEVICE) |
|
|
| |
| v1 = torch.randn(BATCH, 3, 32, 32, device=DEVICE) |
| v2 = torch.randn(BATCH, 3, 32, 32, device=DEVICE) |
| targets = torch.randint(0, NUM_CLASSES, (BATCH,), device=DEVICE) |
| labels_nce = torch.arange(BATCH, device=DEVICE) |
|
|
| |
| with torch.no_grad(): |
| emb1, raw_mag1 = encoder(v1) |
| emb2, raw_mag2 = encoder(v2) |
| anchors_n = F.normalize(anchors, dim=-1) |
| cos1 = emb1 @ anchors_n.T |
| cos2 = emb2 @ anchors_n.T |
| tri1 = 1.0 - cos1 |
| tri2 = 1.0 - cos2 |
| assign1 = F.softmax(cos1 / 0.1, dim=-1) |
| assign2 = F.softmax(cos2 / 0.1, dim=-1) |
| pw1 = patchwork(tri1) |
| pw2 = patchwork(tri2) |
| bridge1 = bridge(pw1) |
| feat1 = torch.cat([assign1, pw1, emb1], dim=-1) |
| logits1 = task_head(feat1) |
|
|
| all_params = (list(encoder.parameters()) + [anchors] + |
| list(patchwork.parameters()) + list(bridge.parameters()) + |
| list(task_head.parameters())) |
|
|
| print(f" Device: {DEVICE}") |
| print(f" Batch: {BATCH}, Dim: {DIM}, Anchors: {N_ANCHORS}, Comp: {N_COMP}Γ{D_COMP}") |
| n_params = sum(p.numel() for p in all_params) |
| print(f" Parameters: {n_params:,}") |
|
|
|
|
| |
| |
| |
|
|
def timed(name, fn, n_runs=30, warmup=5):
    """CUDA-synced timing of fn(). Returns (last_result, avg_ms).

    Args:
        name: label for the measurement (kept for call-site readability; unused).
        fn: zero-argument callable to benchmark.
        n_runs: number of timed repetitions averaged into avg_ms.
        warmup: untimed preliminary runs (kernel compilation, cudnn autotune).
    """
    # Bug fix: torch.cuda.synchronize() raises on CPU-only builds, yet DEVICE
    # deliberately falls back to "cpu" above — only synchronize when CUDA exists.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        r = fn()
    sync()
    times = []
    for _ in range(n_runs):
        sync()  # drain pending kernels so t0 marks only fn()'s work
        t0 = time.perf_counter()
        r = fn()
        sync()  # wait for fn()'s kernels before reading the clock
        times.append((time.perf_counter() - t0) * 1000)
    avg = sum(times) / len(times)
    return r, avg
|
|
# (name, avg_ms) tuples accumulated by record() and summarized at the end.
results = []


def record(name, fn, **kw):
    """Benchmark fn via timed(), log it under `name`, return the average ms."""
    avg_ms = timed(name, fn, **kw)[1]
    results.append((name, avg_ms))
    return avg_ms
|
|
|
|
| |
| |
| |
|
|
| print(f"\n{'='*80}") |
| print("SECTION 1: FORWARD PASS COMPONENTS") |
| print(f"{'='*80}\n") |
|
|
| record("encoder(v1)", lambda: encoder(v1)) |
| record("triangulation (emb@A.T)", lambda: emb1 @ anchors_n.T) |
| record("soft_assign (softmax)", lambda: F.softmax(cos1 / 0.1, dim=-1)) |
| record("patchwork(tri)", lambda: patchwork(tri1)) |
| record("bridge(pw)", lambda: bridge(pw1)) |
| record("task_head(feat)", lambda: task_head(feat1)) |
|
|
| def _full_fwd(): |
| e1, _ = encoder(v1) |
| e2, _ = encoder(v2) |
| an = F.normalize(anchors, dim=-1) |
| c1 = e1 @ an.T; c2 = e2 @ an.T |
| t1 = 1 - c1; t2 = 1 - c2 |
| a1 = F.softmax(c1/0.1, dim=-1); a2 = F.softmax(c2/0.1, dim=-1) |
| p1 = patchwork(t1); p2 = patchwork(t2) |
| b1 = bridge(p1) |
| f1 = torch.cat([a1, p1, e1], -1) |
| return task_head(f1) |
|
|
| record("FULL forward (both views)", _full_fwd) |
|
|
|
|
| |
| |
| |
|
|
| print(f"\n{'='*80}") |
| print("SECTION 2: INDIVIDUAL LOSS TERMS (forward only)") |
| print(f"{'='*80}\n") |
|
|
| record("CE (cross_entropy)", lambda: F.cross_entropy(logits1, targets)) |
|
|
| record("NCE_emb (BΓB + CE)", lambda: F.cross_entropy( |
| emb1 @ emb2.T / 0.07, labels_nce)) |
|
|
| record("NCE_pw (norm + BΓB + CE)", lambda: F.cross_entropy( |
| F.normalize(pw1, dim=-1) @ F.normalize(pw2, dim=-1).T / 0.1, labels_nce)) |
|
|
| record("NCE_tri (norm + BΓB + CE)", lambda: F.cross_entropy( |
| F.normalize(tri1, dim=-1) @ F.normalize(tri2, dim=-1).T / 0.1, labels_nce)) |
|
|
| record("NCE_assign (BΓB + CE)", lambda: F.cross_entropy( |
| assign1 @ assign2.T / 0.1, labels_nce)) |
|
|
| def _bridge_loss(): |
| at = assign1.detach() |
| return -(at * F.log_softmax(bridge1, dim=-1)).sum(-1).mean() |
| record("Bridge (soft CE)", _bridge_loss) |
|
|
| def _assign_bce(): |
| nearest = cos1.argmax(dim=-1) |
| hard = torch.zeros_like(assign1) |
| hard.scatter_(1, nearest.unsqueeze(1), 1.0) |
| return F.binary_cross_entropy(assign1.float().clamp(1e-7, 1-1e-7), hard.float()) |
| record("Assign BCE", _assign_bce) |
|
|
| record("Attraction (max + mean)", lambda: (1.0 - cos1.max(dim=1).values).mean()) |
|
|
| def _spread(): |
| a = F.normalize(anchors, dim=-1) |
| sim = a @ a.T |
| mask = ~torch.eye(N_ANCHORS, dtype=torch.bool, device=DEVICE) |
| return F.relu(sim[mask]).mean() |
| record("Spread (AΓA + relu)", _spread) |
|
|
| record("kNN (BΓB + argmax)", lambda: ( |
| targets[(emb1 @ emb1.T).fill_diagonal_(-1).argmax(1)] == targets).float().mean()) |
|
|
|
|
| |
| |
| |
|
|
| print(f"\n{'='*80}") |
| print("SECTION 3: CV LOSS β OLD SEQUENTIAL vs BATCHED") |
| print(f"{'='*80}\n") |
|
|
| |
def _cv_old(n_samples=64):
    """Coefficient-of-variation loss over simplex volumes — slow reference version.

    Draws n_samples random 5-point subsets of emb1, computes each subset's
    4-simplex volume via the Cayley-Menger determinant, then penalizes the
    squared deviation of the volumes' coefficient of variation from 0.22.
    One Python iteration (with a per-sample .item() GPU sync) per subset is
    what makes this the slow baseline being profiled.
    """
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(min(BATCH, 256), device=DEVICE)[:5]
        pts = emb1[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances: ||a||^2 + ||b||^2 - 2<a,b>, clamped at 0.
        d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
        # Cayley-Menger matrix: bordered with ones, squared distances inside.
        cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=pts.dtype)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        # Prefactor for 5 points (a 4-simplex): (-1)^5 / (2^4 * (4!)^2).
        pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
        v2 = pf * torch.linalg.det(cm.float())
        if v2[0].item() > 1e-20:  # .item() forces a GPU->CPU sync every sample
            vols.append(v2[0].sqrt())
    if len(vols) < 5:  # too few non-degenerate simplices -> no loss signal
        return torch.tensor(0.0, device=DEVICE)
    vt = torch.stack(vols)
    return ((vt.std() / (vt.mean() + 1e-8)) - 0.22).pow(2)
|
|
| |
def _cv_batched(n_samples=64):
    """Batched CV loss: all Cayley-Menger determinants in one GPU call.

    Same quantity as _cv_old in spirit, but all n_samples 5-point subsets are
    sampled at once (argsort of random keys) and evaluated with a single
    batched torch.linalg.det — no Python loop, no per-sample GPU sync.
    """
    pool = min(BATCH, 256)
    # Per-row sampling without replacement: argsort of uniform keys yields a
    # random permutation; keep the first 5 indices of each row.
    rand_keys = torch.rand(n_samples, pool, device=DEVICE)
    indices = rand_keys.argsort(dim=1)[:, :5]
    pts = emb1[:pool][indices]
    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Squared pairwise distances, clamped at 0 for numerical safety.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
    # Batched Cayley-Menger matrices.
    cm = torch.zeros(n_samples, 6, 6, device=DEVICE, dtype=pts.dtype)
    cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2
    # Prefactor for 5 points (a 4-simplex): (-1)^5 / (2^4 * (4!)^2).
    pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
    dets = pf * torch.linalg.det(cm.float())
    valid = dets > 1e-20  # drop degenerate simplices before sqrt
    vols = dets[valid].sqrt()
    if vols.shape[0] < 5:
        return torch.tensor(0.0, device=DEVICE)
    return ((vols.std() / (vols.mean() + 1e-8)) - 0.22).pow(2)
|
|
# Compare old vs batched CV loss across sample counts.
# `lambda ns=ns: ...` binds ns at definition time, avoiding the late-binding
# closure pitfall inside the loop.
for ns in [32, 64, 128, 200]:
    record(f"CV OLD n={ns}", lambda ns=ns: _cv_old(ns), n_runs=10)
    record(f"CV BATCH n={ns}", lambda ns=ns: _cv_batched(ns), n_runs=10)


# Metric-mode variants: the same computations under no_grad, as they would run
# during evaluation.
def _cv_metric_old(n_samples=200):
    with torch.no_grad():
        return _cv_old(n_samples)
def _cv_metric_batch(n_samples=200):
    with torch.no_grad():
        return _cv_batched(n_samples)


record("CV metric OLD n=200", _cv_metric_old, n_runs=10)
record("CV metric BATCH n=200", _cv_metric_batch, n_runs=10)
|
|
|
|
| |
| |
| |
|
|
| print(f"\n{'='*80}") |
| print("SECTION 4: BACKWARD COSTS (forward + backward)") |
| print(f"{'='*80}\n") |
|
|
| def _bwd(loss_fn): |
| for p in all_params: |
| if p.grad is not None: |
| p.grad.zero_() |
| loss = loss_fn() |
| if torch.is_tensor(loss) and loss.requires_grad: |
| loss.backward() |
| return loss |
|
|
| |
| def _fwd_bwd_ce(): |
| e, _ = encoder(v1) |
| an = F.normalize(anchors, dim=-1) |
| c = e @ an.T; t = 1 - c |
| a = F.softmax(c/0.1, dim=-1) |
| p = patchwork(t) |
| f = torch.cat([a, p, e], -1) |
| return _bwd(lambda: F.cross_entropy(task_head(f), targets)) |
|
|
| def _fwd_bwd_nce_emb(): |
| e1, _ = encoder(v1); e2, _ = encoder(v2) |
| return _bwd(lambda: F.cross_entropy(e1 @ e2.T / 0.07, labels_nce)) |
|
|
| def _fwd_bwd_nce_pw(): |
| e1, _ = encoder(v1); e2, _ = encoder(v2) |
| an = F.normalize(anchors, dim=-1) |
| t1 = 1 - e1 @ an.T; t2 = 1 - e2 @ an.T |
| p1 = patchwork(t1); p2 = patchwork(t2) |
| return _bwd(lambda: F.cross_entropy( |
| F.normalize(p1, dim=-1) @ F.normalize(p2, dim=-1).T / 0.1, labels_nce)) |
|
|
| def _fwd_bwd_cv_old(): |
| e, _ = encoder(v1) |
| return _bwd(lambda: _cv_old(64)) |
|
|
| def _fwd_bwd_cv_batch(): |
| e, _ = encoder(v1) |
| return _bwd(lambda: _cv_batched(64)) |
|
|
| def _fwd_bwd_bridge(): |
| e, _ = encoder(v1) |
| an = F.normalize(anchors, dim=-1) |
| c = e @ an.T; t = 1 - c |
| a = F.softmax(c/0.1, dim=-1) |
| p = patchwork(t); b = bridge(p) |
| at = a.detach() |
| return _bwd(lambda: -(at * F.log_softmax(b, dim=-1)).sum(-1).mean()) |
|
|
| record("fwd+bwd CE", _fwd_bwd_ce, n_runs=10, warmup=3) |
| record("fwd+bwd NCE_emb", _fwd_bwd_nce_emb, n_runs=10, warmup=3) |
| record("fwd+bwd NCE_pw", _fwd_bwd_nce_pw, n_runs=10, warmup=3) |
| record("fwd+bwd CV old", _fwd_bwd_cv_old, n_runs=10, warmup=3) |
| record("fwd+bwd CV batch", _fwd_bwd_cv_batch, n_runs=10, warmup=3) |
| record("fwd+bwd Bridge", _fwd_bwd_bridge, n_runs=10, warmup=3) |
|
|
|
|
| |
| |
| |
|
|
| print(f"\n\n{'='*80}") |
| print("FULL TIMING REPORT (sorted by cost)") |
| print(f"{'='*80}\n") |
|
|
| total = sum(ms for _, ms in results) |
| for name, ms in sorted(results, key=lambda x: -x[1]): |
| pct = 100 * ms / total if total > 0 else 0 |
| bar_len = int(pct / 2) |
| bar = "β" * bar_len + "β" * (40 - bar_len) |
| print(f" {name:35s} {ms:>9.3f}ms {bar} {pct:>5.1f}%") |
|
|
| print(f" {'β'*90}") |
| print(f" {'SUM':35s} {total:>9.3f}ms") |
|
|
| |
| print(f"\n{'='*80}") |
| print("CV SPEEDUP SUMMARY") |
| print(f"{'='*80}") |
|
|
| cv_pairs = {} |
| for name, ms in results: |
| if name.startswith("CV "): |
| key = name.split("n=")[1] if "n=" in name else "?" |
| tag = "old" if "OLD" in name else "batch" |
| cv_pairs.setdefault(key, {})[tag] = ms |
|
|
| for k in sorted(cv_pairs.keys()): |
| p = cv_pairs[k] |
| if 'old' in p and 'batch' in p: |
| speedup = p['old'] / p['batch'] if p['batch'] > 0 else 0 |
| print(f" n={k:>4s}: {p['old']:>8.2f}ms β {p['batch']:>8.2f}ms ({speedup:.1f}x speedup)") |
|
|
| |
| print(f"\n{'='*80}") |
| print("PER-STEP ESTIMATE") |
| print(f"{'='*80}") |
|
|
| fwd_time = next((ms for n, ms in results if n == "FULL forward (both views)"), 0) |
| bwd_ce = next((ms for n, ms in results if n == "fwd+bwd CE"), 0) |
| bwd_cv_old = next((ms for n, ms in results if n == "fwd+bwd CV old"), 0) |
| bwd_cv_new = next((ms for n, ms in results if n == "fwd+bwd CV batch"), 0) |
|
|
| print(f" Forward (both views): {fwd_time:.2f}ms") |
| print(f" fwd+bwd CE: {bwd_ce:.2f}ms") |
| print(f" fwd+bwd CV (old): {bwd_cv_old:.2f}ms") |
| print(f" fwd+bwd CV (batched): {bwd_cv_new:.2f}ms") |
| if bwd_cv_old > 0 and bwd_cv_new > 0: |
| saved = bwd_cv_old - bwd_cv_new |
| print(f" CV savings per step: {saved:.2f}ms ({saved/bwd_cv_old*100:.0f}%)") |
|
|
| print(f"\n{'='*80}") |
| print("PROFILING COMPLETE") |
| print(f"{'='*80}") |