AbstractPhil committed on
Commit
9b40529
Β·
verified Β·
1 Parent(s): 6aaaf01

Create geolip_loss_profiler.py

Browse files
Files changed (1) hide show
  1. geolip_loss_profiler.py +410 -0
geolip_loss_profiler.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Loss Spectrum Profiler — Standalone
3
+ =====================================
4
+ Builds its own model + noise data. Profiles every loss computation
5
+ in the GeoLIP pipeline with CUDA-synced microsecond timing.
6
+
7
+ Zero external dependencies beyond torch. Single cell.
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import time
14
+ import math
15
+
16
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
17
+ torch.backends.cuda.matmul.allow_tf32 = True
18
+ torch.backends.cudnn.allow_tf32 = True
19
+
20
+ # ═══════════════════════════════════════════════════════════════
21
+ # Config β€” matches our architecture
22
+ # ═══════════════════════════════════════════════════════════════
23
+ DIM = 256
24
+ N_ANCHORS = 256
25
+ N_COMP = 8
26
+ D_COMP = 64
27
+ BATCH = 256
28
+ NUM_CLASSES = 100
29
+
30
+
31
+ # ═══════════════════════════════════════════════════════════════
32
+ # Minimal model components (self-contained, no imports)
33
+ # ═══════════════════════════════════════════════════════════════
34
+
35
+ class ProfileEncoder(nn.Module):
36
+ def __init__(self, dim=256):
37
+ super().__init__()
38
+ self.features = nn.Sequential(
39
+ nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
40
+ nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),
41
+ nn.MaxPool2d(2),
42
+ nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
43
+ nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),
44
+ nn.MaxPool2d(2),
45
+ nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
46
+ nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),
47
+ nn.MaxPool2d(2),
48
+ nn.Conv2d(256, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),
49
+ nn.Conv2d(384, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),
50
+ nn.MaxPool2d(2),
51
+ nn.AdaptiveAvgPool2d(1), nn.Flatten(),
52
+ )
53
+ self.proj = nn.Sequential(nn.Linear(384, dim), nn.LayerNorm(dim))
54
+
55
+ def forward(self, x):
56
+ feat = self.features(x)
57
+ return F.normalize(self.proj(feat), dim=-1), feat[:, :1] # emb, fake raw_mag
58
+
59
+
60
+ class ProfilePatchwork(nn.Module):
61
+ def __init__(self, n_anchors=256, n_comp=8, d_comp=64):
62
+ super().__init__()
63
+ apc = n_anchors // n_comp
64
+ self.n_comp = n_comp
65
+ self.comps = nn.ModuleList([
66
+ nn.Sequential(nn.Linear(apc, d_comp * 2), nn.GELU(), nn.Linear(d_comp * 2, d_comp))
67
+ for _ in range(n_comp)
68
+ ])
69
+
70
+ def forward(self, tri):
71
+ apc = tri.shape[1] // self.n_comp
72
+ parts = []
73
+ for k in range(self.n_comp):
74
+ parts.append(self.comps[k](tri[:, k*apc:(k+1)*apc]))
75
+ return torch.cat(parts, dim=-1)
76
+
77
+
78
+ # Build all components
79
+ print("Building profile model...")
80
+ encoder = ProfileEncoder(DIM).to(DEVICE)
81
+ anchors = nn.Parameter(F.normalize(torch.randn(N_ANCHORS, DIM, device=DEVICE), dim=-1))
82
+ patchwork = ProfilePatchwork(N_ANCHORS, N_COMP, D_COMP).to(DEVICE)
83
+ bridge = nn.Linear(N_COMP * D_COMP, N_ANCHORS).to(DEVICE)
84
+ task_head = nn.Sequential(
85
+ nn.Linear(N_ANCHORS + N_COMP * D_COMP + DIM, N_COMP * D_COMP),
86
+ nn.GELU(), nn.LayerNorm(N_COMP * D_COMP), nn.Dropout(0.1),
87
+ nn.Linear(N_COMP * D_COMP, NUM_CLASSES),
88
+ ).to(DEVICE)
89
+
90
+ # Fake batch β€” random images + labels
91
+ v1 = torch.randn(BATCH, 3, 32, 32, device=DEVICE)
92
+ v2 = torch.randn(BATCH, 3, 32, 32, device=DEVICE)
93
+ targets = torch.randint(0, NUM_CLASSES, (BATCH,), device=DEVICE)
94
+ labels_nce = torch.arange(BATCH, device=DEVICE)
95
+
96
+ # Pre-compute intermediates
97
+ with torch.no_grad():
98
+ emb1, raw_mag1 = encoder(v1)
99
+ emb2, raw_mag2 = encoder(v2)
100
+ anchors_n = F.normalize(anchors, dim=-1)
101
+ cos1 = emb1 @ anchors_n.T
102
+ cos2 = emb2 @ anchors_n.T
103
+ tri1 = 1.0 - cos1
104
+ tri2 = 1.0 - cos2
105
+ assign1 = F.softmax(cos1 / 0.1, dim=-1)
106
+ assign2 = F.softmax(cos2 / 0.1, dim=-1)
107
+ pw1 = patchwork(tri1)
108
+ pw2 = patchwork(tri2)
109
+ bridge1 = bridge(pw1)
110
+ feat1 = torch.cat([assign1, pw1, emb1], dim=-1)
111
+ logits1 = task_head(feat1)
112
+
113
+ all_params = (list(encoder.parameters()) + [anchors] +
114
+ list(patchwork.parameters()) + list(bridge.parameters()) +
115
+ list(task_head.parameters()))
116
+
117
+ print(f" Device: {DEVICE}")
118
+ print(f" Batch: {BATCH}, Dim: {DIM}, Anchors: {N_ANCHORS}, Comp: {N_COMP}Γ—{D_COMP}")
119
+ n_params = sum(p.numel() for p in all_params)
120
+ print(f" Parameters: {n_params:,}")
121
+
122
+
123
+ # ═══════════════════════════════════════════════════════════════
124
+ # Timer
125
+ # ═══════════════════════════════════════════════════════════════
126
+
127
+ def timed(name, fn, n_runs=30, warmup=5):
128
+ """CUDA-synced timing. Returns (result, avg_ms)."""
129
+ for _ in range(warmup):
130
+ r = fn()
131
+ torch.cuda.synchronize()
132
+ times = []
133
+ for _ in range(n_runs):
134
+ torch.cuda.synchronize()
135
+ t0 = time.perf_counter()
136
+ r = fn()
137
+ torch.cuda.synchronize()
138
+ times.append((time.perf_counter() - t0) * 1000)
139
+ avg = sum(times) / len(times)
140
+ return r, avg
141
+
142
+ results = []
143
+
144
+ def record(name, fn, **kw):
145
+ _, ms = timed(name, fn, **kw)
146
+ results.append((name, ms))
147
+ return ms
148
+
149
+
150
+ # ═══════════════════════════════════════════════════════════════
151
+ # SECTION 1: Forward Components
152
+ # ═══════════════════════════════════════════════════════════════
153
+
154
+ print(f"\n{'='*80}")
155
+ print("SECTION 1: FORWARD PASS COMPONENTS")
156
+ print(f"{'='*80}\n")
157
+
158
+ record("encoder(v1)", lambda: encoder(v1))
159
+ record("triangulation (emb@A.T)", lambda: emb1 @ anchors_n.T)
160
+ record("soft_assign (softmax)", lambda: F.softmax(cos1 / 0.1, dim=-1))
161
+ record("patchwork(tri)", lambda: patchwork(tri1))
162
+ record("bridge(pw)", lambda: bridge(pw1))
163
+ record("task_head(feat)", lambda: task_head(feat1))
164
+
165
+ def _full_fwd():
166
+ e1, _ = encoder(v1)
167
+ e2, _ = encoder(v2)
168
+ an = F.normalize(anchors, dim=-1)
169
+ c1 = e1 @ an.T; c2 = e2 @ an.T
170
+ t1 = 1 - c1; t2 = 1 - c2
171
+ a1 = F.softmax(c1/0.1, dim=-1); a2 = F.softmax(c2/0.1, dim=-1)
172
+ p1 = patchwork(t1); p2 = patchwork(t2)
173
+ b1 = bridge(p1)
174
+ f1 = torch.cat([a1, p1, e1], -1)
175
+ return task_head(f1)
176
+
177
+ record("FULL forward (both views)", _full_fwd)
178
+
179
+
180
+ # ═══════════════════════════════════════════════════════════════
181
+ # SECTION 2: Individual Loss Terms (forward only)
182
+ # ═══════════════════════════════════════════════════════════════
183
+
184
+ print(f"\n{'='*80}")
185
+ print("SECTION 2: INDIVIDUAL LOSS TERMS (forward only)")
186
+ print(f"{'='*80}\n")
187
+
188
+ record("CE (cross_entropy)", lambda: F.cross_entropy(logits1, targets))
189
+
190
+ record("NCE_emb (BΓ—B + CE)", lambda: F.cross_entropy(
191
+ emb1 @ emb2.T / 0.07, labels_nce))
192
+
193
+ record("NCE_pw (norm + BΓ—B + CE)", lambda: F.cross_entropy(
194
+ F.normalize(pw1, dim=-1) @ F.normalize(pw2, dim=-1).T / 0.1, labels_nce))
195
+
196
+ record("NCE_tri (norm + BΓ—B + CE)", lambda: F.cross_entropy(
197
+ F.normalize(tri1, dim=-1) @ F.normalize(tri2, dim=-1).T / 0.1, labels_nce))
198
+
199
+ record("NCE_assign (BΓ—B + CE)", lambda: F.cross_entropy(
200
+ assign1 @ assign2.T / 0.1, labels_nce))
201
+
202
+ def _bridge_loss():
203
+ at = assign1.detach()
204
+ return -(at * F.log_softmax(bridge1, dim=-1)).sum(-1).mean()
205
+ record("Bridge (soft CE)", _bridge_loss)
206
+
207
+ def _assign_bce():
208
+ nearest = cos1.argmax(dim=-1)
209
+ hard = torch.zeros_like(assign1)
210
+ hard.scatter_(1, nearest.unsqueeze(1), 1.0)
211
+ return F.binary_cross_entropy(assign1.float().clamp(1e-7, 1-1e-7), hard.float())
212
+ record("Assign BCE", _assign_bce)
213
+
214
+ record("Attraction (max + mean)", lambda: (1.0 - cos1.max(dim=1).values).mean())
215
+
216
+ def _spread():
217
+ a = F.normalize(anchors, dim=-1)
218
+ sim = a @ a.T
219
+ mask = ~torch.eye(N_ANCHORS, dtype=torch.bool, device=DEVICE)
220
+ return F.relu(sim[mask]).mean()
221
+ record("Spread (AΓ—A + relu)", _spread)
222
+
223
+ record("kNN (BΓ—B + argmax)", lambda: (
224
+ targets[(emb1 @ emb1.T).fill_diagonal_(-1).argmax(1)] == targets).float().mean())
225
+
226
+
227
+ # ═══════════════════════════════════════════════════════════════
228
+ # SECTION 3: CV Loss β€” Old vs Batched
229
+ # ═══════════════════════════════════════════════════════════════
230
+
231
+ print(f"\n{'='*80}")
232
+ print("SECTION 3: CV LOSS β€” OLD SEQUENTIAL vs BATCHED")
233
+ print(f"{'='*80}\n")
234
+
235
+ # Old sequential
236
+ def _cv_old(n_samples=64):
237
+ vols = []
238
+ for _ in range(n_samples):
239
+ idx = torch.randperm(min(BATCH, 256), device=DEVICE)[:5]
240
+ pts = emb1[idx].unsqueeze(0)
241
+ gram = torch.bmm(pts, pts.transpose(1, 2))
242
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
243
+ d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
244
+ cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=pts.dtype)
245
+ cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
246
+ pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
247
+ v2 = pf * torch.linalg.det(cm.float())
248
+ if v2[0].item() > 1e-20:
249
+ vols.append(v2[0].sqrt())
250
+ if len(vols) < 5:
251
+ return torch.tensor(0.0, device=DEVICE)
252
+ vt = torch.stack(vols)
253
+ return ((vt.std() / (vt.mean() + 1e-8)) - 0.22).pow(2)
254
+
255
+ # Batched
256
+ def _cv_batched(n_samples=64):
257
+ pool = min(BATCH, 256)
258
+ rand_keys = torch.rand(n_samples, pool, device=DEVICE)
259
+ indices = rand_keys.argsort(dim=1)[:, :5]
260
+ pts = emb1[:pool][indices]
261
+ gram = torch.bmm(pts, pts.transpose(1, 2))
262
+ norms = torch.diagonal(gram, dim1=1, dim2=2)
263
+ d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)
264
+ cm = torch.zeros(n_samples, 6, 6, device=DEVICE, dtype=pts.dtype)
265
+ cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2
266
+ pf = ((-1)**5) / ((2**4) * (math.factorial(4)**2))
267
+ dets = pf * torch.linalg.det(cm.float())
268
+ valid = dets > 1e-20
269
+ vols = dets[valid].sqrt()
270
+ if vols.shape[0] < 5:
271
+ return torch.tensor(0.0, device=DEVICE)
272
+ return ((vols.std() / (vols.mean() + 1e-8)) - 0.22).pow(2)
273
+
274
+ for ns in [32, 64, 128, 200]:
275
+ record(f"CV OLD n={ns}", lambda ns=ns: _cv_old(ns), n_runs=10)
276
+ record(f"CV BATCH n={ns}", lambda ns=ns: _cv_batched(ns), n_runs=10)
277
+
278
+ # Non-differentiable metric versions
279
+ def _cv_metric_old(n_samples=200):
280
+ with torch.no_grad():
281
+ return _cv_old(n_samples)
282
+ def _cv_metric_batch(n_samples=200):
283
+ with torch.no_grad():
284
+ return _cv_batched(n_samples)
285
+
286
+ record("CV metric OLD n=200", _cv_metric_old, n_runs=10)
287
+ record("CV metric BATCH n=200", _cv_metric_batch, n_runs=10)
288
+
289
+
290
+ # ═══════════════════════════════════════════════════════════════
291
+ # SECTION 4: Backward costs
292
+ # ═══════════════════════════════════════════════════════════════
293
+
294
+ print(f"\n{'='*80}")
295
+ print("SECTION 4: BACKWARD COSTS (forward + backward)")
296
+ print(f"{'='*80}\n")
297
+
298
+ def _bwd(loss_fn):
299
+ for p in all_params:
300
+ if p.grad is not None:
301
+ p.grad.zero_()
302
+ loss = loss_fn()
303
+ if torch.is_tensor(loss) and loss.requires_grad:
304
+ loss.backward()
305
+ return loss
306
+
307
+ # Need fresh forward for each backward
308
+ def _fwd_bwd_ce():
309
+ e, _ = encoder(v1)
310
+ an = F.normalize(anchors, dim=-1)
311
+ c = e @ an.T; t = 1 - c
312
+ a = F.softmax(c/0.1, dim=-1)
313
+ p = patchwork(t)
314
+ f = torch.cat([a, p, e], -1)
315
+ return _bwd(lambda: F.cross_entropy(task_head(f), targets))
316
+
317
+ def _fwd_bwd_nce_emb():
318
+ e1, _ = encoder(v1); e2, _ = encoder(v2)
319
+ return _bwd(lambda: F.cross_entropy(e1 @ e2.T / 0.07, labels_nce))
320
+
321
+ def _fwd_bwd_nce_pw():
322
+ e1, _ = encoder(v1); e2, _ = encoder(v2)
323
+ an = F.normalize(anchors, dim=-1)
324
+ t1 = 1 - e1 @ an.T; t2 = 1 - e2 @ an.T
325
+ p1 = patchwork(t1); p2 = patchwork(t2)
326
+ return _bwd(lambda: F.cross_entropy(
327
+ F.normalize(p1, dim=-1) @ F.normalize(p2, dim=-1).T / 0.1, labels_nce))
328
+
329
+ def _fwd_bwd_cv_old():
330
+ e, _ = encoder(v1)
331
+ return _bwd(lambda: _cv_old(64))
332
+
333
+ def _fwd_bwd_cv_batch():
334
+ e, _ = encoder(v1)
335
+ return _bwd(lambda: _cv_batched(64))
336
+
337
+ def _fwd_bwd_bridge():
338
+ e, _ = encoder(v1)
339
+ an = F.normalize(anchors, dim=-1)
340
+ c = e @ an.T; t = 1 - c
341
+ a = F.softmax(c/0.1, dim=-1)
342
+ p = patchwork(t); b = bridge(p)
343
+ at = a.detach()
344
+ return _bwd(lambda: -(at * F.log_softmax(b, dim=-1)).sum(-1).mean())
345
+
346
+ record("fwd+bwd CE", _fwd_bwd_ce, n_runs=10, warmup=3)
347
+ record("fwd+bwd NCE_emb", _fwd_bwd_nce_emb, n_runs=10, warmup=3)
348
+ record("fwd+bwd NCE_pw", _fwd_bwd_nce_pw, n_runs=10, warmup=3)
349
+ record("fwd+bwd CV old", _fwd_bwd_cv_old, n_runs=10, warmup=3)
350
+ record("fwd+bwd CV batch", _fwd_bwd_cv_batch, n_runs=10, warmup=3)
351
+ record("fwd+bwd Bridge", _fwd_bwd_bridge, n_runs=10, warmup=3)
352
+
353
+
354
+ # ═══════════════════════════════════════════════════════════════
355
+ # REPORT
356
+ # ═══════════════════════════════════════════════════════════════
357
+
358
+ print(f"\n\n{'='*80}")
359
+ print("FULL TIMING REPORT (sorted by cost)")
360
+ print(f"{'='*80}\n")
361
+
362
+ total = sum(ms for _, ms in results)
363
+ for name, ms in sorted(results, key=lambda x: -x[1]):
364
+ pct = 100 * ms / total if total > 0 else 0
365
+ bar_len = int(pct / 2)
366
+ bar = "β–ˆ" * bar_len + "β–‘" * (40 - bar_len)
367
+ print(f" {name:35s} {ms:>9.3f}ms {bar} {pct:>5.1f}%")
368
+
369
+ print(f" {'─'*90}")
370
+ print(f" {'SUM':35s} {total:>9.3f}ms")
371
+
372
+ # CV speedup summary
373
+ print(f"\n{'='*80}")
374
+ print("CV SPEEDUP SUMMARY")
375
+ print(f"{'='*80}")
376
+
377
+ cv_pairs = {}
378
+ for name, ms in results:
379
+ if name.startswith("CV "):
380
+ key = name.split("n=")[1] if "n=" in name else "?"
381
+ tag = "old" if "OLD" in name else "batch"
382
+ cv_pairs.setdefault(key, {})[tag] = ms
383
+
384
+ for k in sorted(cv_pairs.keys()):
385
+ p = cv_pairs[k]
386
+ if 'old' in p and 'batch' in p:
387
+ speedup = p['old'] / p['batch'] if p['batch'] > 0 else 0
388
+ print(f" n={k:>4s}: {p['old']:>8.2f}ms β†’ {p['batch']:>8.2f}ms ({speedup:.1f}x speedup)")
389
+
390
+ # Per-step estimate
391
+ print(f"\n{'='*80}")
392
+ print("PER-STEP ESTIMATE")
393
+ print(f"{'='*80}")
394
+
395
+ fwd_time = next((ms for n, ms in results if n == "FULL forward (both views)"), 0)
396
+ bwd_ce = next((ms for n, ms in results if n == "fwd+bwd CE"), 0)
397
+ bwd_cv_old = next((ms for n, ms in results if n == "fwd+bwd CV old"), 0)
398
+ bwd_cv_new = next((ms for n, ms in results if n == "fwd+bwd CV batch"), 0)
399
+
400
+ print(f" Forward (both views): {fwd_time:.2f}ms")
401
+ print(f" fwd+bwd CE: {bwd_ce:.2f}ms")
402
+ print(f" fwd+bwd CV (old): {bwd_cv_old:.2f}ms")
403
+ print(f" fwd+bwd CV (batched): {bwd_cv_new:.2f}ms")
404
+ if bwd_cv_old > 0 and bwd_cv_new > 0:
405
+ saved = bwd_cv_old - bwd_cv_new
406
+ print(f" CV savings per step: {saved:.2f}ms ({saved/bwd_cv_old*100:.0f}%)")
407
+
408
+ print(f"\n{'='*80}")
409
+ print("PROFILING COMPLETE")
410
+ print(f"{'='*80}")