AbstractPhil committed on
Commit
8f9bbd4
·
verified ·
1 Parent(s): 1ff23e0

Upload 7 files

Browse files
spectral/notebooks/experiment_1_signal_decomposition.ipynb CHANGED
@@ -83,8 +83,16 @@
83
  "from collections import defaultdict\n",
84
  "\n",
85
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
86
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
87
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
88
  "\n",
89
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
90
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -454,8 +462,10 @@
454
  " return torch.stack(vols)\n",
455
  "\n",
456
  "\n",
457
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
458
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
459
  " if emb.shape[0] < n_points:\n",
460
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
461
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -668,8 +678,9 @@
668
  " l_spread = spread_loss(constellation.anchors)\n",
669
  " ld['spread'] = l_spread\n",
670
  "\n",
671
- " # \u2500\u2500 kNN \u2500\u2500\n",
672
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
673
  "\n",
674
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
675
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -962,19 +973,22 @@
962
  " w_assign=0.5, w_assign_nce=0.25,\n",
963
  " w_nce_tri=0.5, w_attract=0.25,\n",
964
  " w_cv=0.01, w_spread=0.01,\n",
965
- " cv_batched=True):\n",
966
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
967
  " Returns:\n",
968
  " total_loss, loss_dict\n",
969
  " \"\"\"\n",
970
  " ld = {}\n",
971
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
972
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
973
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
974
  " ld['ce'], ld['acc'] = l_ce, acc\n",
975
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
976
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
977
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
978
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
979
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
980
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -982,7 +996,7 @@
982
  " output['bridge1'], output['bridge2'],\n",
983
  " output['assign1'], output['assign2'])\n",
984
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
985
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
986
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
987
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
988
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -993,12 +1007,14 @@
993
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
994
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
995
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
996
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
997
  " ld['cv'] = l_cv\n",
998
  " l_spread = spread_loss(self.constellation.anchors)\n",
999
  " ld['spread'] = l_spread\n",
1000
- " # \u2500\u2500 kNN \u2500\u2500\n",
1001
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
1002
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1003
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1004
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1335,9 +1351,11 @@
1335
  "\n",
1336
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1337
  " num_workers=num_workers, pin_memory=True,\n",
1338
- " drop_last=True, collate_fn=paired_collate)\n",
 
1339
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1340
- " num_workers=num_workers, pin_memory=True)\n",
 
1341
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1342
  " return train_loader, val_loader\n",
1343
  "\n",
@@ -1468,7 +1486,7 @@
1468
  " optimizer.zero_grad()\n",
1469
  "\n",
1470
  " output = model.forward_paired(v1, v2)\n",
1471
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1472
  "\n",
1473
  " loss.backward()\n",
1474
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
83
  "from collections import defaultdict\n",
84
  "\n",
85
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
86
+ "\n",
87
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
88
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
89
+ "torch.backends.cudnn.allow_tf32 = True\n",
90
+ "torch.backends.cudnn.benchmark = True\n",
91
+ "\n",
92
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
93
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
94
+ "if device.type == \"cuda\":\n",
95
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
96
  "\n",
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
98
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
462
  " return torch.stack(vols)\n",
463
  "\n",
464
  "\n",
465
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
466
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
467
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
468
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
469
  " if emb.shape[0] < n_points:\n",
470
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
471
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
678
  " l_spread = spread_loss(constellation.anchors)\n",
679
  " ld['spread'] = l_spread\n",
680
  "\n",
681
+ " # \u2500\u2500 kNN (computed only when targets are given and the batch has at most 512 samples) \u2500\u2500\n",
682
+ " if targets is not None and emb1.shape[0] <= 512:\n",
683
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
684
  "\n",
685
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
686
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
973
  " w_assign=0.5, w_assign_nce=0.25,\n",
974
  " w_nce_tri=0.5, w_attract=0.25,\n",
975
  " w_cv=0.01, w_spread=0.01,\n",
976
+ " cv_batched=True, compute_knn=False):\n",
977
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
978
+ " Args:\n",
979
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
980
+ " Set True during validation or every N steps.\n",
981
  " Returns:\n",
982
  " total_loss, loss_dict\n",
983
  " \"\"\"\n",
984
  " ld = {}\n",
985
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
986
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
987
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
988
  " ld['ce'], ld['acc'] = l_ce, acc\n",
989
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
990
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
991
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
992
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
993
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
994
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
996
  " output['bridge1'], output['bridge2'],\n",
997
  " output['assign1'], output['assign2'])\n",
998
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
999
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
1000
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
1001
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
1002
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
1007
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1008
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1009
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1010
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1011
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1012
  " ld['cv'] = l_cv\n",
1013
  " l_spread = spread_loss(self.constellation.anchors)\n",
1014
  " ld['spread'] = l_spread\n",
1015
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1016
+ " if compute_knn:\n",
1017
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1018
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1019
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1020
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1351
  "\n",
1352
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1353
  " num_workers=num_workers, pin_memory=True,\n",
1354
+ " drop_last=True, collate_fn=paired_collate,\n",
1355
+ " persistent_workers=(num_workers > 0))\n",
1356
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1357
+ " num_workers=num_workers, pin_memory=True,\n",
1358
+ " persistent_workers=(num_workers > 0))\n",
1359
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1360
  " return train_loader, val_loader\n",
1361
  "\n",
 
1486
  " optimizer.zero_grad()\n",
1487
  "\n",
1488
  " output = model.forward_paired(v1, v2)\n",
1489
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1490
  "\n",
1491
  " loss.backward()\n",
1492
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_2_manifold_structures.ipynb CHANGED
@@ -75,8 +75,16 @@
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
78
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
79
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
80
  "\n",
81
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
82
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -446,8 +454,10 @@
446
  " return torch.stack(vols)\n",
447
  "\n",
448
  "\n",
449
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
450
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
451
  " if emb.shape[0] < n_points:\n",
452
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
453
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -660,8 +670,9 @@
660
  " l_spread = spread_loss(constellation.anchors)\n",
661
  " ld['spread'] = l_spread\n",
662
  "\n",
663
- " # \u2500\u2500 kNN \u2500\u2500\n",
664
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
665
  "\n",
666
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
667
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -954,19 +965,22 @@
954
  " w_assign=0.5, w_assign_nce=0.25,\n",
955
  " w_nce_tri=0.5, w_attract=0.25,\n",
956
  " w_cv=0.01, w_spread=0.01,\n",
957
- " cv_batched=True):\n",
958
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
959
  " Returns:\n",
960
  " total_loss, loss_dict\n",
961
  " \"\"\"\n",
962
  " ld = {}\n",
963
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
964
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
965
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
966
  " ld['ce'], ld['acc'] = l_ce, acc\n",
967
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
968
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
969
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
970
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
971
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
972
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -974,7 +988,7 @@
974
  " output['bridge1'], output['bridge2'],\n",
975
  " output['assign1'], output['assign2'])\n",
976
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
977
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
978
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
979
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
980
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -985,12 +999,14 @@
985
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
986
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
987
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
988
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
989
  " ld['cv'] = l_cv\n",
990
  " l_spread = spread_loss(self.constellation.anchors)\n",
991
  " ld['spread'] = l_spread\n",
992
- " # \u2500\u2500 kNN \u2500\u2500\n",
993
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
994
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
995
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
996
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1327,9 +1343,11 @@
1327
  "\n",
1328
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1329
  " num_workers=num_workers, pin_memory=True,\n",
1330
- " drop_last=True, collate_fn=paired_collate)\n",
 
1331
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1332
- " num_workers=num_workers, pin_memory=True)\n",
 
1333
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1334
  " return train_loader, val_loader\n",
1335
  "\n",
@@ -1460,7 +1478,7 @@
1460
  " optimizer.zero_grad()\n",
1461
  "\n",
1462
  " output = model.forward_paired(v1, v2)\n",
1463
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1464
  "\n",
1465
  " loss.backward()\n",
1466
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
78
+ "\n",
79
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
80
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
81
+ "torch.backends.cudnn.allow_tf32 = True\n",
82
+ "torch.backends.cudnn.benchmark = True\n",
83
+ "\n",
84
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
85
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
86
+ "if device.type == \"cuda\":\n",
87
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
88
  "\n",
89
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
90
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
454
  " return torch.stack(vols)\n",
455
  "\n",
456
  "\n",
457
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
458
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
459
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
460
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
461
  " if emb.shape[0] < n_points:\n",
462
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
463
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
670
  " l_spread = spread_loss(constellation.anchors)\n",
671
  " ld['spread'] = l_spread\n",
672
  "\n",
673
+ " # \u2500\u2500 kNN (computed only when targets are given and the batch has at most 512 samples) \u2500\u2500\n",
674
+ " if targets is not None and emb1.shape[0] <= 512:\n",
675
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
676
  "\n",
677
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
678
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
965
  " w_assign=0.5, w_assign_nce=0.25,\n",
966
  " w_nce_tri=0.5, w_attract=0.25,\n",
967
  " w_cv=0.01, w_spread=0.01,\n",
968
+ " cv_batched=True, compute_knn=False):\n",
969
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
970
+ " Args:\n",
971
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
972
+ " Set True during validation or every N steps.\n",
973
  " Returns:\n",
974
  " total_loss, loss_dict\n",
975
  " \"\"\"\n",
976
  " ld = {}\n",
977
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
978
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
979
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
980
  " ld['ce'], ld['acc'] = l_ce, acc\n",
981
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
982
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
983
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
984
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
985
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
986
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
988
  " output['bridge1'], output['bridge2'],\n",
989
  " output['assign1'], output['assign2'])\n",
990
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
991
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
992
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
993
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
994
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
999
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1000
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1001
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1002
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1003
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1004
  " ld['cv'] = l_cv\n",
1005
  " l_spread = spread_loss(self.constellation.anchors)\n",
1006
  " ld['spread'] = l_spread\n",
1007
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1008
+ " if compute_knn:\n",
1009
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1010
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1011
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1012
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1343
  "\n",
1344
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1345
  " num_workers=num_workers, pin_memory=True,\n",
1346
+ " drop_last=True, collate_fn=paired_collate,\n",
1347
+ " persistent_workers=(num_workers > 0))\n",
1348
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1349
+ " num_workers=num_workers, pin_memory=True,\n",
1350
+ " persistent_workers=(num_workers > 0))\n",
1351
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1352
  " return train_loader, val_loader\n",
1353
  "\n",
 
1478
  " optimizer.zero_grad()\n",
1479
  "\n",
1480
  " output = model.forward_paired(v1, v2)\n",
1481
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1482
  "\n",
1483
  " loss.backward()\n",
1484
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_3_compact_representations.ipynb CHANGED
@@ -74,8 +74,16 @@
74
  "from collections import defaultdict\n",
75
  "\n",
76
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
77
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
78
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
79
  "\n",
80
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
81
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -445,8 +453,10 @@
445
  " return torch.stack(vols)\n",
446
  "\n",
447
  "\n",
448
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
449
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
450
  " if emb.shape[0] < n_points:\n",
451
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
452
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -659,8 +669,9 @@
659
  " l_spread = spread_loss(constellation.anchors)\n",
660
  " ld['spread'] = l_spread\n",
661
  "\n",
662
- " # \u2500\u2500 kNN \u2500\u2500\n",
663
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
664
  "\n",
665
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
666
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -953,19 +964,22 @@
953
  " w_assign=0.5, w_assign_nce=0.25,\n",
954
  " w_nce_tri=0.5, w_attract=0.25,\n",
955
  " w_cv=0.01, w_spread=0.01,\n",
956
- " cv_batched=True):\n",
957
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
958
  " Returns:\n",
959
  " total_loss, loss_dict\n",
960
  " \"\"\"\n",
961
  " ld = {}\n",
962
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
963
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
964
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
965
  " ld['ce'], ld['acc'] = l_ce, acc\n",
966
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
967
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
968
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
969
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
970
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
971
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -973,7 +987,7 @@
973
  " output['bridge1'], output['bridge2'],\n",
974
  " output['assign1'], output['assign2'])\n",
975
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
976
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
977
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
978
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
979
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -984,12 +998,14 @@
984
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
985
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
986
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
987
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
988
  " ld['cv'] = l_cv\n",
989
  " l_spread = spread_loss(self.constellation.anchors)\n",
990
  " ld['spread'] = l_spread\n",
991
- " # \u2500\u2500 kNN \u2500\u2500\n",
992
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
993
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
994
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
995
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1326,9 +1342,11 @@
1326
  "\n",
1327
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1328
  " num_workers=num_workers, pin_memory=True,\n",
1329
- " drop_last=True, collate_fn=paired_collate)\n",
 
1330
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1331
- " num_workers=num_workers, pin_memory=True)\n",
 
1332
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1333
  " return train_loader, val_loader\n",
1334
  "\n",
@@ -1459,7 +1477,7 @@
1459
  " optimizer.zero_grad()\n",
1460
  "\n",
1461
  " output = model.forward_paired(v1, v2)\n",
1462
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1463
  "\n",
1464
  " loss.backward()\n",
1465
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
74
  "from collections import defaultdict\n",
75
  "\n",
76
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
77
+ "\n",
78
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
79
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
80
+ "torch.backends.cudnn.allow_tf32 = True\n",
81
+ "torch.backends.cudnn.benchmark = True\n",
82
+ "\n",
83
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
84
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
85
+ "if device.type == \"cuda\":\n",
86
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
87
  "\n",
88
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
89
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
453
  " return torch.stack(vols)\n",
454
  "\n",
455
  "\n",
456
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
457
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
458
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
459
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
460
  " if emb.shape[0] < n_points:\n",
461
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
462
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
669
  " l_spread = spread_loss(constellation.anchors)\n",
670
  " ld['spread'] = l_spread\n",
671
  "\n",
672
+ " # \u2500\u2500 kNN (computed only when targets are given and the batch has at most 512 samples) \u2500\u2500\n",
673
+ " if targets is not None and emb1.shape[0] <= 512:\n",
674
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
675
  "\n",
676
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
677
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
964
  " w_assign=0.5, w_assign_nce=0.25,\n",
965
  " w_nce_tri=0.5, w_attract=0.25,\n",
966
  " w_cv=0.01, w_spread=0.01,\n",
967
+ " cv_batched=True, compute_knn=False):\n",
968
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
969
+ " Args:\n",
970
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
971
+ " Set True during validation or every N steps.\n",
972
  " Returns:\n",
973
  " total_loss, loss_dict\n",
974
  " \"\"\"\n",
975
  " ld = {}\n",
976
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
977
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
978
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
979
  " ld['ce'], ld['acc'] = l_ce, acc\n",
980
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
981
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
982
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
983
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
984
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
985
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
987
  " output['bridge1'], output['bridge2'],\n",
988
  " output['assign1'], output['assign2'])\n",
989
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
990
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
991
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
992
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
993
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
998
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
999
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1000
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1001
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1002
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1003
  " ld['cv'] = l_cv\n",
1004
  " l_spread = spread_loss(self.constellation.anchors)\n",
1005
  " ld['spread'] = l_spread\n",
1006
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1007
+ " if compute_knn:\n",
1008
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1009
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1010
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1011
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1342
  "\n",
1343
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1344
  " num_workers=num_workers, pin_memory=True,\n",
1345
+ " drop_last=True, collate_fn=paired_collate,\n",
1346
+ " persistent_workers=(num_workers > 0))\n",
1347
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1348
+ " num_workers=num_workers, pin_memory=True,\n",
1349
+ " persistent_workers=(num_workers > 0))\n",
1350
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1351
  " return train_loader, val_loader\n",
1352
  "\n",
 
1477
  " optimizer.zero_grad()\n",
1478
  "\n",
1479
  " output = model.forward_paired(v1, v2)\n",
1480
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1481
  "\n",
1482
  " loss.backward()\n",
1483
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_4_invertible_transforms.ipynb CHANGED
@@ -75,8 +75,16 @@
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
78
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
79
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
80
  "\n",
81
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
82
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -446,8 +454,10 @@
446
  " return torch.stack(vols)\n",
447
  "\n",
448
  "\n",
449
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
450
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
451
  " if emb.shape[0] < n_points:\n",
452
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
453
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -660,8 +670,9 @@
660
  " l_spread = spread_loss(constellation.anchors)\n",
661
  " ld['spread'] = l_spread\n",
662
  "\n",
663
- " # \u2500\u2500 kNN \u2500\u2500\n",
664
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
665
  "\n",
666
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
667
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -954,19 +965,22 @@
954
  " w_assign=0.5, w_assign_nce=0.25,\n",
955
  " w_nce_tri=0.5, w_attract=0.25,\n",
956
  " w_cv=0.01, w_spread=0.01,\n",
957
- " cv_batched=True):\n",
958
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
959
  " Returns:\n",
960
  " total_loss, loss_dict\n",
961
  " \"\"\"\n",
962
  " ld = {}\n",
963
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
964
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
965
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
966
  " ld['ce'], ld['acc'] = l_ce, acc\n",
967
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
968
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
969
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
970
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
971
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
972
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -974,7 +988,7 @@
974
  " output['bridge1'], output['bridge2'],\n",
975
  " output['assign1'], output['assign2'])\n",
976
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
977
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
978
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
979
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
980
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -985,12 +999,14 @@
985
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
986
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
987
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
988
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
989
  " ld['cv'] = l_cv\n",
990
  " l_spread = spread_loss(self.constellation.anchors)\n",
991
  " ld['spread'] = l_spread\n",
992
- " # \u2500\u2500 kNN \u2500\u2500\n",
993
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
994
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
995
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
996
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1327,9 +1343,11 @@
1327
  "\n",
1328
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1329
  " num_workers=num_workers, pin_memory=True,\n",
1330
- " drop_last=True, collate_fn=paired_collate)\n",
 
1331
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1332
- " num_workers=num_workers, pin_memory=True)\n",
 
1333
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1334
  " return train_loader, val_loader\n",
1335
  "\n",
@@ -1460,7 +1478,7 @@
1460
  " optimizer.zero_grad()\n",
1461
  "\n",
1462
  " output = model.forward_paired(v1, v2)\n",
1463
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1464
  "\n",
1465
  " loss.backward()\n",
1466
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
78
+ "\n",
79
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
80
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
81
+ "torch.backends.cudnn.allow_tf32 = True\n",
82
+ "torch.backends.cudnn.benchmark = True\n",
83
+ "\n",
84
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
85
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
86
+ "if device.type == \"cuda\":\n",
87
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
88
  "\n",
89
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
90
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
454
  " return torch.stack(vols)\n",
455
  "\n",
456
  "\n",
457
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
458
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
459
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
460
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
461
  " if emb.shape[0] < n_points:\n",
462
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
463
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
670
  " l_spread = spread_loss(constellation.anchors)\n",
671
  " ld['spread'] = l_spread\n",
672
  "\n",
673
+ " # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
674
+ " if targets is not None and emb1.shape[0] <= 512:\n",
675
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
676
  "\n",
677
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
678
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
965
  " w_assign=0.5, w_assign_nce=0.25,\n",
966
  " w_nce_tri=0.5, w_attract=0.25,\n",
967
  " w_cv=0.01, w_spread=0.01,\n",
968
+ " cv_batched=True, compute_knn=False):\n",
969
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
970
+ " Args:\n",
971
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
972
+ " Set True during validation or every N steps.\n",
973
  " Returns:\n",
974
  " total_loss, loss_dict\n",
975
  " \"\"\"\n",
976
  " ld = {}\n",
977
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
978
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
979
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
980
  " ld['ce'], ld['acc'] = l_ce, acc\n",
981
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
982
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
983
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
984
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
985
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
986
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
988
  " output['bridge1'], output['bridge2'],\n",
989
  " output['assign1'], output['assign2'])\n",
990
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
991
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
992
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
993
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
994
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
999
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1000
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1001
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1002
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1003
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1004
  " ld['cv'] = l_cv\n",
1005
  " l_spread = spread_loss(self.constellation.anchors)\n",
1006
  " ld['spread'] = l_spread\n",
1007
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1008
+ " if compute_knn:\n",
1009
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1010
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1011
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1012
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1343
  "\n",
1344
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1345
  " num_workers=num_workers, pin_memory=True,\n",
1346
+ " drop_last=True, collate_fn=paired_collate,\n",
1347
+ " persistent_workers=(num_workers > 0))\n",
1348
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1349
+ " num_workers=num_workers, pin_memory=True,\n",
1350
+ " persistent_workers=(num_workers > 0))\n",
1351
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1352
  " return train_loader, val_loader\n",
1353
  "\n",
 
1478
  " optimizer.zero_grad()\n",
1479
  "\n",
1480
  " output = model.forward_paired(v1, v2)\n",
1481
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1482
  "\n",
1483
  " loss.backward()\n",
1484
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_5_matrix_decompositions.ipynb CHANGED
@@ -75,8 +75,16 @@
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
78
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
79
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
80
  "\n",
81
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
82
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -446,8 +454,10 @@
446
  " return torch.stack(vols)\n",
447
  "\n",
448
  "\n",
449
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
450
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
451
  " if emb.shape[0] < n_points:\n",
452
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
453
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -660,8 +670,9 @@
660
  " l_spread = spread_loss(constellation.anchors)\n",
661
  " ld['spread'] = l_spread\n",
662
  "\n",
663
- " # \u2500\u2500 kNN \u2500\u2500\n",
664
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
665
  "\n",
666
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
667
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -954,19 +965,22 @@
954
  " w_assign=0.5, w_assign_nce=0.25,\n",
955
  " w_nce_tri=0.5, w_attract=0.25,\n",
956
  " w_cv=0.01, w_spread=0.01,\n",
957
- " cv_batched=True):\n",
958
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
959
  " Returns:\n",
960
  " total_loss, loss_dict\n",
961
  " \"\"\"\n",
962
  " ld = {}\n",
963
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
964
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
965
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
966
  " ld['ce'], ld['acc'] = l_ce, acc\n",
967
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
968
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
969
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
970
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
971
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
972
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -974,7 +988,7 @@
974
  " output['bridge1'], output['bridge2'],\n",
975
  " output['assign1'], output['assign2'])\n",
976
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
977
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
978
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
979
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
980
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -985,12 +999,14 @@
985
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
986
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
987
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
988
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
989
  " ld['cv'] = l_cv\n",
990
  " l_spread = spread_loss(self.constellation.anchors)\n",
991
  " ld['spread'] = l_spread\n",
992
- " # \u2500\u2500 kNN \u2500\u2500\n",
993
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
994
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
995
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
996
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1327,9 +1343,11 @@
1327
  "\n",
1328
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1329
  " num_workers=num_workers, pin_memory=True,\n",
1330
- " drop_last=True, collate_fn=paired_collate)\n",
 
1331
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1332
- " num_workers=num_workers, pin_memory=True)\n",
 
1333
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1334
  " return train_loader, val_loader\n",
1335
  "\n",
@@ -1460,7 +1478,7 @@
1460
  " optimizer.zero_grad()\n",
1461
  "\n",
1462
  " output = model.forward_paired(v1, v2)\n",
1463
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1464
  "\n",
1465
  " loss.backward()\n",
1466
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
78
+ "\n",
79
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
80
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
81
+ "torch.backends.cudnn.allow_tf32 = True\n",
82
+ "torch.backends.cudnn.benchmark = True\n",
83
+ "\n",
84
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
85
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
86
+ "if device.type == \"cuda\":\n",
87
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
88
  "\n",
89
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
90
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
454
  " return torch.stack(vols)\n",
455
  "\n",
456
  "\n",
457
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
458
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
459
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
460
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
461
  " if emb.shape[0] < n_points:\n",
462
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
463
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
670
  " l_spread = spread_loss(constellation.anchors)\n",
671
  " ld['spread'] = l_spread\n",
672
  "\n",
673
+ " # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
674
+ " if targets is not None and emb1.shape[0] <= 512:\n",
675
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
676
  "\n",
677
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
678
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
965
  " w_assign=0.5, w_assign_nce=0.25,\n",
966
  " w_nce_tri=0.5, w_attract=0.25,\n",
967
  " w_cv=0.01, w_spread=0.01,\n",
968
+ " cv_batched=True, compute_knn=False):\n",
969
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
970
+ " Args:\n",
971
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
972
+ " Set True during validation or every N steps.\n",
973
  " Returns:\n",
974
  " total_loss, loss_dict\n",
975
  " \"\"\"\n",
976
  " ld = {}\n",
977
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
978
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
979
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
980
  " ld['ce'], ld['acc'] = l_ce, acc\n",
981
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
982
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
983
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
984
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
985
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
986
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
988
  " output['bridge1'], output['bridge2'],\n",
989
  " output['assign1'], output['assign2'])\n",
990
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
991
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
992
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
993
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
994
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
999
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1000
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1001
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1002
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1003
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1004
  " ld['cv'] = l_cv\n",
1005
  " l_spread = spread_loss(self.constellation.anchors)\n",
1006
  " ld['spread'] = l_spread\n",
1007
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1008
+ " if compute_knn:\n",
1009
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1010
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1011
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1012
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1343
  "\n",
1344
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1345
  " num_workers=num_workers, pin_memory=True,\n",
1346
+ " drop_last=True, collate_fn=paired_collate,\n",
1347
+ " persistent_workers=(num_workers > 0))\n",
1348
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1349
+ " num_workers=num_workers, pin_memory=True,\n",
1350
+ " persistent_workers=(num_workers > 0))\n",
1351
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1352
  " return train_loader, val_loader\n",
1353
  "\n",
 
1478
  " optimizer.zero_grad()\n",
1479
  "\n",
1480
  " output = model.forward_paired(v1, v2)\n",
1481
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1482
  "\n",
1483
  " loss.backward()\n",
1484
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_6_losses_and_anchors.ipynb CHANGED
@@ -77,8 +77,16 @@
77
  "from collections import defaultdict\n",
78
  "\n",
79
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
80
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
81
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
82
  "\n",
83
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
84
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -448,8 +456,10 @@
448
  " return torch.stack(vols)\n",
449
  "\n",
450
  "\n",
451
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
452
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
453
  " if emb.shape[0] < n_points:\n",
454
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
455
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -662,8 +672,9 @@
662
  " l_spread = spread_loss(constellation.anchors)\n",
663
  " ld['spread'] = l_spread\n",
664
  "\n",
665
- " # \u2500\u2500 kNN \u2500\u2500\n",
666
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
667
  "\n",
668
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
669
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -956,19 +967,22 @@
956
  " w_assign=0.5, w_assign_nce=0.25,\n",
957
  " w_nce_tri=0.5, w_attract=0.25,\n",
958
  " w_cv=0.01, w_spread=0.01,\n",
959
- " cv_batched=True):\n",
960
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
961
  " Returns:\n",
962
  " total_loss, loss_dict\n",
963
  " \"\"\"\n",
964
  " ld = {}\n",
965
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
966
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
967
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
968
  " ld['ce'], ld['acc'] = l_ce, acc\n",
969
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
970
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
971
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
972
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
973
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
974
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -976,7 +990,7 @@
976
  " output['bridge1'], output['bridge2'],\n",
977
  " output['assign1'], output['assign2'])\n",
978
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
979
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
980
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
981
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
982
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -987,12 +1001,14 @@
987
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
988
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
989
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
990
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
991
  " ld['cv'] = l_cv\n",
992
  " l_spread = spread_loss(self.constellation.anchors)\n",
993
  " ld['spread'] = l_spread\n",
994
- " # \u2500\u2500 kNN \u2500\u2500\n",
995
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
996
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
997
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
998
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1329,9 +1345,11 @@
1329
  "\n",
1330
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1331
  " num_workers=num_workers, pin_memory=True,\n",
1332
- " drop_last=True, collate_fn=paired_collate)\n",
 
1333
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1334
- " num_workers=num_workers, pin_memory=True)\n",
 
1335
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1336
  " return train_loader, val_loader\n",
1337
  "\n",
@@ -1462,7 +1480,7 @@
1462
  " optimizer.zero_grad()\n",
1463
  "\n",
1464
  " output = model.forward_paired(v1, v2)\n",
1465
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1466
  "\n",
1467
  " loss.backward()\n",
1468
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
77
  "from collections import defaultdict\n",
78
  "\n",
79
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
80
+ "\n",
81
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
82
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
83
+ "torch.backends.cudnn.allow_tf32 = True\n",
84
+ "torch.backends.cudnn.benchmark = True\n",
85
+ "\n",
86
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
87
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
88
+ "if device.type == \"cuda\":\n",
89
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
90
  "\n",
91
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
92
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
456
  " return torch.stack(vols)\n",
457
  "\n",
458
  "\n",
459
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
460
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
461
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
462
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
463
  " if emb.shape[0] < n_points:\n",
464
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
465
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
672
  " l_spread = spread_loss(constellation.anchors)\n",
673
  " ld['spread'] = l_spread\n",
674
  "\n",
675
+ " # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
676
+ " if targets is not None and emb1.shape[0] <= 512:\n",
677
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
678
  "\n",
679
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
680
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
967
  " w_assign=0.5, w_assign_nce=0.25,\n",
968
  " w_nce_tri=0.5, w_attract=0.25,\n",
969
  " w_cv=0.01, w_spread=0.01,\n",
970
+ " cv_batched=True, compute_knn=False):\n",
971
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
972
+ " Args:\n",
973
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
974
+ " Set True during validation or every N steps.\n",
975
  " Returns:\n",
976
  " total_loss, loss_dict\n",
977
  " \"\"\"\n",
978
  " ld = {}\n",
979
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
980
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
981
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
982
  " ld['ce'], ld['acc'] = l_ce, acc\n",
983
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
984
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
985
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
986
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
987
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
988
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
990
  " output['bridge1'], output['bridge2'],\n",
991
  " output['assign1'], output['assign2'])\n",
992
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
993
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
994
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
995
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
996
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
1001
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1002
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1003
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1004
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1005
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1006
  " ld['cv'] = l_cv\n",
1007
  " l_spread = spread_loss(self.constellation.anchors)\n",
1008
  " ld['spread'] = l_spread\n",
1009
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1010
+ " if compute_knn:\n",
1011
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1012
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1013
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1014
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1345
  "\n",
1346
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1347
  " num_workers=num_workers, pin_memory=True,\n",
1348
+ " drop_last=True, collate_fn=paired_collate,\n",
1349
+ " persistent_workers=(num_workers > 0))\n",
1350
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1351
+ " num_workers=num_workers, pin_memory=True,\n",
1352
+ " persistent_workers=(num_workers > 0))\n",
1353
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1354
  " return train_loader, val_loader\n",
1355
  "\n",
 
1480
  " optimizer.zero_grad()\n",
1481
  "\n",
1482
  " output = model.forward_paired(v1, v2)\n",
1483
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1484
  "\n",
1485
  " loss.backward()\n",
1486
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
spectral/notebooks/experiment_7_composite_pipelines.ipynb CHANGED
@@ -75,8 +75,16 @@
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
 
 
 
 
 
 
78
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
79
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
 
 
80
  "\n",
81
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
82
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
@@ -446,8 +454,10 @@
446
  " return torch.stack(vols)\n",
447
  "\n",
448
  "\n",
449
- "def cv_loss(emb, target=0.22, n_samples=64, n_points=5, batched=True):\n",
450
- " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\"\"\"\n",
 
 
451
  " if emb.shape[0] < n_points:\n",
452
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
453
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
@@ -660,8 +670,9 @@
660
  " l_spread = spread_loss(constellation.anchors)\n",
661
  " ld['spread'] = l_spread\n",
662
  "\n",
663
- " # \u2500\u2500 kNN \u2500\u2500\n",
664
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
665
  "\n",
666
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
667
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
@@ -954,19 +965,22 @@
954
  " w_assign=0.5, w_assign_nce=0.25,\n",
955
  " w_nce_tri=0.5, w_attract=0.25,\n",
956
  " w_cv=0.01, w_spread=0.01,\n",
957
- " cv_batched=True):\n",
958
- " \"\"\"Three-domain cooperative loss.\n",
 
 
 
959
  " Returns:\n",
960
  " total_loss, loss_dict\n",
961
  " \"\"\"\n",
962
  " ld = {}\n",
963
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
964
- " # \u2500\u2500 EXTERNAL \u2500\u2500\n",
965
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
966
  " ld['ce'], ld['acc'] = l_ce, acc\n",
967
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
968
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
969
- " # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
970
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
971
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
972
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
@@ -974,7 +988,7 @@
974
  " output['bridge1'], output['bridge2'],\n",
975
  " output['assign1'], output['assign2'])\n",
976
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
977
- " # \u2500\u2500 INTERNAL \u2500\u2500\n",
978
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
979
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
980
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
@@ -985,12 +999,14 @@
985
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
986
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
987
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
988
- " l_cv = cv_loss(emb1, target=self.cv_target, batched=cv_batched)\n",
 
989
  " ld['cv'] = l_cv\n",
990
  " l_spread = spread_loss(self.constellation.anchors)\n",
991
  " ld['spread'] = l_spread\n",
992
- " # \u2500\u2500 kNN \u2500\u2500\n",
993
- " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
 
994
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
995
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
996
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
@@ -1327,9 +1343,11 @@
1327
  "\n",
1328
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1329
  " num_workers=num_workers, pin_memory=True,\n",
1330
- " drop_last=True, collate_fn=paired_collate)\n",
 
1331
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1332
- " num_workers=num_workers, pin_memory=True)\n",
 
1333
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1334
  " return train_loader, val_loader\n",
1335
  "\n",
@@ -1460,7 +1478,7 @@
1460
  " optimizer.zero_grad()\n",
1461
  "\n",
1462
  " output = model.forward_paired(v1, v2)\n",
1463
- " loss, ld = model.compute_loss(output, labels, **lw)\n",
1464
  "\n",
1465
  " loss.backward()\n",
1466
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
 
75
  "from collections import defaultdict\n",
76
  "\n",
77
  "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
78
+ "\n",
79
+ "# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
80
+ "torch.backends.cuda.matmul.allow_tf32 = True\n",
81
+ "torch.backends.cudnn.allow_tf32 = True\n",
82
+ "torch.backends.cudnn.benchmark = True\n",
83
+ "\n",
84
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
85
  "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
86
+ "if device.type == \"cuda\":\n",
87
+ " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
88
  "\n",
89
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
90
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
 
454
  " return torch.stack(vols)\n",
455
  "\n",
456
  "\n",
457
+ "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
458
+ " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
459
+ " Default n_samples=32 for training speed (141x faster than sequential).\n",
460
+ " Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
461
  " if emb.shape[0] < n_points:\n",
462
  " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
463
  " vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
 
670
  " l_spread = spread_loss(constellation.anchors)\n",
671
  " ld['spread'] = l_spread\n",
672
  "\n",
673
+ " # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
674
+ " if targets is not None and emb1.shape[0] <= 512:\n",
675
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
676
  "\n",
677
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
678
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
 
965
  " w_assign=0.5, w_assign_nce=0.25,\n",
966
  " w_nce_tri=0.5, w_attract=0.25,\n",
967
  " w_cv=0.01, w_spread=0.01,\n",
968
+ " cv_batched=True, compute_knn=False):\n",
969
+ " \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
970
+ " Args:\n",
971
+ " compute_knn: if False (default), skip kNN during training for speed.\n",
972
+ " Set True during validation or every N steps.\n",
973
  " Returns:\n",
974
  " total_loss, loss_dict\n",
975
  " \"\"\"\n",
976
  " ld = {}\n",
977
  " emb1, emb2 = output['embedding'], output['embedding_aug']\n",
978
+ " # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
979
  " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
980
  " ld['ce'], ld['acc'] = l_ce, acc\n",
981
  " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
982
  " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
983
+ " # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
984
  " l_nce_pw, nce_pw_acc = nce_loss(\n",
985
  " output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
986
  " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
 
988
  " output['bridge1'], output['bridge2'],\n",
989
  " output['assign1'], output['assign2'])\n",
990
  " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
991
+ " # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
992
  " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
993
  " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
994
  " l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
 
999
  " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
1000
  " l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
1001
  " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
1002
+ " # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
1003
+ " l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
1004
  " ld['cv'] = l_cv\n",
1005
  " l_spread = spread_loss(self.constellation.anchors)\n",
1006
  " ld['spread'] = l_spread\n",
1007
+ " # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
1008
+ " if compute_knn:\n",
1009
+ " ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
1010
  " # \u2500\u2500 TOTAL \u2500\u2500\n",
1011
  " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
1012
  " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
 
1343
  "\n",
1344
  " train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
1345
  " num_workers=num_workers, pin_memory=True,\n",
1346
+ " drop_last=True, collate_fn=paired_collate,\n",
1347
+ " persistent_workers=(num_workers > 0))\n",
1348
  " val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
1349
+ " num_workers=num_workers, pin_memory=True,\n",
1350
+ " persistent_workers=(num_workers > 0))\n",
1351
  " print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
1352
  " return train_loader, val_loader\n",
1353
  "\n",
 
1478
  " optimizer.zero_grad()\n",
1479
  "\n",
1480
  " output = model.forward_paired(v1, v2)\n",
1481
+ " loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
1482
  "\n",
1483
  " loss.backward()\n",
1484
  " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",