Upload 7 files
Browse files- spectral/notebooks/experiment_1_signal_decomposition.ipynb +33 -15
- spectral/notebooks/experiment_2_manifold_structures.ipynb +33 -15
- spectral/notebooks/experiment_3_compact_representations.ipynb +33 -15
- spectral/notebooks/experiment_4_invertible_transforms.ipynb +33 -15
- spectral/notebooks/experiment_5_matrix_decompositions.ipynb +33 -15
- spectral/notebooks/experiment_6_losses_and_anchors.ipynb +33 -15
- spectral/notebooks/experiment_7_composite_pipelines.ipynb +33 -15
spectral/notebooks/experiment_1_signal_decomposition.ipynb
CHANGED
|
@@ -83,8 +83,16 @@
|
|
| 83 |
"from collections import defaultdict\n",
|
| 84 |
"\n",
|
| 85 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 87 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 88 |
"\n",
|
| 89 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 90 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -454,8 +462,10 @@
|
|
| 454 |
" return torch.stack(vols)\n",
|
| 455 |
"\n",
|
| 456 |
"\n",
|
| 457 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 458 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 459 |
" if emb.shape[0] < n_points:\n",
|
| 460 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 461 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -668,8 +678,9 @@
|
|
| 668 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 669 |
" ld['spread'] = l_spread\n",
|
| 670 |
"\n",
|
| 671 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 672 |
-
"
|
|
|
|
| 673 |
"\n",
|
| 674 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 675 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -962,19 +973,22 @@
|
|
| 962 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 963 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 964 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 965 |
-
" cv_batched=True):\n",
|
| 966 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 967 |
" Returns:\n",
|
| 968 |
" total_loss, loss_dict\n",
|
| 969 |
" \"\"\"\n",
|
| 970 |
" ld = {}\n",
|
| 971 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 972 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 973 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 974 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 975 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 976 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 977 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 978 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 979 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 980 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -982,7 +996,7 @@
|
|
| 982 |
" output['bridge1'], output['bridge2'],\n",
|
| 983 |
" output['assign1'], output['assign2'])\n",
|
| 984 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 985 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 986 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 987 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 988 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -993,12 +1007,14 @@
|
|
| 993 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 994 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 995 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 996 |
-
"
|
|
|
|
| 997 |
" ld['cv'] = l_cv\n",
|
| 998 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 999 |
" ld['spread'] = l_spread\n",
|
| 1000 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 1001 |
-
"
|
|
|
|
| 1002 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1003 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1004 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1335,9 +1351,11 @@
|
|
| 1335 |
"\n",
|
| 1336 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1337 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1338 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1339 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1340 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1341 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1342 |
" return train_loader, val_loader\n",
|
| 1343 |
"\n",
|
|
@@ -1468,7 +1486,7 @@
|
|
| 1468 |
" optimizer.zero_grad()\n",
|
| 1469 |
"\n",
|
| 1470 |
" output = model.forward_paired(v1, v2)\n",
|
| 1471 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1472 |
"\n",
|
| 1473 |
" loss.backward()\n",
|
| 1474 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 83 |
"from collections import defaultdict\n",
|
| 84 |
"\n",
|
| 85 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 88 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 89 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 90 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 91 |
+
"\n",
|
| 92 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 93 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 94 |
+
"if device.type == \"cuda\":\n",
|
| 95 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 96 |
"\n",
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 98 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 462 |
" return torch.stack(vols)\n",
|
| 463 |
"\n",
|
| 464 |
"\n",
|
| 465 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 466 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 467 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 468 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 469 |
" if emb.shape[0] < n_points:\n",
|
| 470 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 471 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 678 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 679 |
" ld['spread'] = l_spread\n",
|
| 680 |
"\n",
|
| 681 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 682 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 683 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 684 |
"\n",
|
| 685 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 686 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 973 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 974 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 975 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 976 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 977 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 978 |
+
" Args:\n",
|
| 979 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 980 |
+
" Set True during validation or every N steps.\n",
|
| 981 |
" Returns:\n",
|
| 982 |
" total_loss, loss_dict\n",
|
| 983 |
" \"\"\"\n",
|
| 984 |
" ld = {}\n",
|
| 985 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 986 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 987 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 988 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 989 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 990 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 991 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 992 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 993 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 994 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 996 |
" output['bridge1'], output['bridge2'],\n",
|
| 997 |
" output['assign1'], output['assign2'])\n",
|
| 998 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 999 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 1000 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 1001 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 1002 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 1007 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1008 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1009 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1010 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1011 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1012 |
" ld['cv'] = l_cv\n",
|
| 1013 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1014 |
" ld['spread'] = l_spread\n",
|
| 1015 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1016 |
+
" if compute_knn:\n",
|
| 1017 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1018 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1019 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1020 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1351 |
"\n",
|
| 1352 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1353 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1354 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1355 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1356 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1357 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1358 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1359 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1360 |
" return train_loader, val_loader\n",
|
| 1361 |
"\n",
|
|
|
|
| 1486 |
" optimizer.zero_grad()\n",
|
| 1487 |
"\n",
|
| 1488 |
" output = model.forward_paired(v1, v2)\n",
|
| 1489 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1490 |
"\n",
|
| 1491 |
" loss.backward()\n",
|
| 1492 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_2_manifold_structures.ipynb
CHANGED
|
@@ -75,8 +75,16 @@
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 79 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 80 |
"\n",
|
| 81 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 82 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -446,8 +454,10 @@
|
|
| 446 |
" return torch.stack(vols)\n",
|
| 447 |
"\n",
|
| 448 |
"\n",
|
| 449 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 450 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 451 |
" if emb.shape[0] < n_points:\n",
|
| 452 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 453 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -660,8 +670,9 @@
|
|
| 660 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 661 |
" ld['spread'] = l_spread\n",
|
| 662 |
"\n",
|
| 663 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 664 |
-
"
|
|
|
|
| 665 |
"\n",
|
| 666 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 667 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -954,19 +965,22 @@
|
|
| 954 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 955 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 956 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 957 |
-
" cv_batched=True):\n",
|
| 958 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 959 |
" Returns:\n",
|
| 960 |
" total_loss, loss_dict\n",
|
| 961 |
" \"\"\"\n",
|
| 962 |
" ld = {}\n",
|
| 963 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 964 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 965 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 966 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 967 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 968 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 969 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 970 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 971 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 972 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -974,7 +988,7 @@
|
|
| 974 |
" output['bridge1'], output['bridge2'],\n",
|
| 975 |
" output['assign1'], output['assign2'])\n",
|
| 976 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 977 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 978 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 979 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 980 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -985,12 +999,14 @@
|
|
| 985 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 986 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 987 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 988 |
-
"
|
|
|
|
| 989 |
" ld['cv'] = l_cv\n",
|
| 990 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 991 |
" ld['spread'] = l_spread\n",
|
| 992 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 993 |
-
"
|
|
|
|
| 994 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 995 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 996 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1327,9 +1343,11 @@
|
|
| 1327 |
"\n",
|
| 1328 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1329 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1330 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1331 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1332 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1333 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1334 |
" return train_loader, val_loader\n",
|
| 1335 |
"\n",
|
|
@@ -1460,7 +1478,7 @@
|
|
| 1460 |
" optimizer.zero_grad()\n",
|
| 1461 |
"\n",
|
| 1462 |
" output = model.forward_paired(v1, v2)\n",
|
| 1463 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1464 |
"\n",
|
| 1465 |
" loss.backward()\n",
|
| 1466 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 80 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 81 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 82 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 83 |
+
"\n",
|
| 84 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 85 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 86 |
+
"if device.type == \"cuda\":\n",
|
| 87 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 88 |
"\n",
|
| 89 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 90 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 454 |
" return torch.stack(vols)\n",
|
| 455 |
"\n",
|
| 456 |
"\n",
|
| 457 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 458 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 459 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 460 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 461 |
" if emb.shape[0] < n_points:\n",
|
| 462 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 463 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 670 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 671 |
" ld['spread'] = l_spread\n",
|
| 672 |
"\n",
|
| 673 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 674 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 675 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 676 |
"\n",
|
| 677 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 678 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 965 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 966 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 967 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 968 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 969 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 970 |
+
" Args:\n",
|
| 971 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 972 |
+
" Set True during validation or every N steps.\n",
|
| 973 |
" Returns:\n",
|
| 974 |
" total_loss, loss_dict\n",
|
| 975 |
" \"\"\"\n",
|
| 976 |
" ld = {}\n",
|
| 977 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 978 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 979 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 980 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 981 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 982 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 983 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 984 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 985 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 986 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 988 |
" output['bridge1'], output['bridge2'],\n",
|
| 989 |
" output['assign1'], output['assign2'])\n",
|
| 990 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 991 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 992 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 993 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 994 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 999 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1000 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1001 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1002 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1003 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1004 |
" ld['cv'] = l_cv\n",
|
| 1005 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1006 |
" ld['spread'] = l_spread\n",
|
| 1007 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1008 |
+
" if compute_knn:\n",
|
| 1009 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1010 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1011 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1012 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1343 |
"\n",
|
| 1344 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1345 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1346 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1347 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1348 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1349 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1350 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1351 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1352 |
" return train_loader, val_loader\n",
|
| 1353 |
"\n",
|
|
|
|
| 1478 |
" optimizer.zero_grad()\n",
|
| 1479 |
"\n",
|
| 1480 |
" output = model.forward_paired(v1, v2)\n",
|
| 1481 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1482 |
"\n",
|
| 1483 |
" loss.backward()\n",
|
| 1484 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_3_compact_representations.ipynb
CHANGED
|
@@ -74,8 +74,16 @@
|
|
| 74 |
"from collections import defaultdict\n",
|
| 75 |
"\n",
|
| 76 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 78 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 79 |
"\n",
|
| 80 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 81 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -445,8 +453,10 @@
|
|
| 445 |
" return torch.stack(vols)\n",
|
| 446 |
"\n",
|
| 447 |
"\n",
|
| 448 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 449 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 450 |
" if emb.shape[0] < n_points:\n",
|
| 451 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 452 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -659,8 +669,9 @@
|
|
| 659 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 660 |
" ld['spread'] = l_spread\n",
|
| 661 |
"\n",
|
| 662 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 663 |
-
"
|
|
|
|
| 664 |
"\n",
|
| 665 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 666 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -953,19 +964,22 @@
|
|
| 953 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 954 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 955 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 956 |
-
" cv_batched=True):\n",
|
| 957 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 958 |
" Returns:\n",
|
| 959 |
" total_loss, loss_dict\n",
|
| 960 |
" \"\"\"\n",
|
| 961 |
" ld = {}\n",
|
| 962 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 963 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 964 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 965 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 966 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 967 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 968 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 969 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 970 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 971 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -973,7 +987,7 @@
|
|
| 973 |
" output['bridge1'], output['bridge2'],\n",
|
| 974 |
" output['assign1'], output['assign2'])\n",
|
| 975 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 976 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 977 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 978 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 979 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -984,12 +998,14 @@
|
|
| 984 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 985 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 986 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 987 |
-
"
|
|
|
|
| 988 |
" ld['cv'] = l_cv\n",
|
| 989 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 990 |
" ld['spread'] = l_spread\n",
|
| 991 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 992 |
-
"
|
|
|
|
| 993 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 994 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 995 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1326,9 +1342,11 @@
|
|
| 1326 |
"\n",
|
| 1327 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1328 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1329 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1330 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1331 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1332 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1333 |
" return train_loader, val_loader\n",
|
| 1334 |
"\n",
|
|
@@ -1459,7 +1477,7 @@
|
|
| 1459 |
" optimizer.zero_grad()\n",
|
| 1460 |
"\n",
|
| 1461 |
" output = model.forward_paired(v1, v2)\n",
|
| 1462 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1463 |
"\n",
|
| 1464 |
" loss.backward()\n",
|
| 1465 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 74 |
"from collections import defaultdict\n",
|
| 75 |
"\n",
|
| 76 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 79 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 80 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 81 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 82 |
+
"\n",
|
| 83 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 84 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 85 |
+
"if device.type == \"cuda\":\n",
|
| 86 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 87 |
"\n",
|
| 88 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 89 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 453 |
" return torch.stack(vols)\n",
|
| 454 |
"\n",
|
| 455 |
"\n",
|
| 456 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 457 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 458 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 459 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 460 |
" if emb.shape[0] < n_points:\n",
|
| 461 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 462 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 669 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 670 |
" ld['spread'] = l_spread\n",
|
| 671 |
"\n",
|
| 672 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 673 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 674 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 675 |
"\n",
|
| 676 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 677 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 964 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 965 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 966 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 967 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 968 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 969 |
+
" Args:\n",
|
| 970 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 971 |
+
" Set True during validation or every N steps.\n",
|
| 972 |
" Returns:\n",
|
| 973 |
" total_loss, loss_dict\n",
|
| 974 |
" \"\"\"\n",
|
| 975 |
" ld = {}\n",
|
| 976 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 977 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 978 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 979 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 980 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 981 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 982 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 983 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 984 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 985 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 987 |
" output['bridge1'], output['bridge2'],\n",
|
| 988 |
" output['assign1'], output['assign2'])\n",
|
| 989 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 990 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 991 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 992 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 993 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 998 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 999 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1000 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1001 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1002 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1003 |
" ld['cv'] = l_cv\n",
|
| 1004 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1005 |
" ld['spread'] = l_spread\n",
|
| 1006 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1007 |
+
" if compute_knn:\n",
|
| 1008 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1009 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1010 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1011 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1342 |
"\n",
|
| 1343 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1344 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1345 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1346 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1347 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1348 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1349 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1350 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1351 |
" return train_loader, val_loader\n",
|
| 1352 |
"\n",
|
|
|
|
| 1477 |
" optimizer.zero_grad()\n",
|
| 1478 |
"\n",
|
| 1479 |
" output = model.forward_paired(v1, v2)\n",
|
| 1480 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1481 |
"\n",
|
| 1482 |
" loss.backward()\n",
|
| 1483 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_4_invertible_transforms.ipynb
CHANGED
|
@@ -75,8 +75,16 @@
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 79 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 80 |
"\n",
|
| 81 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 82 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -446,8 +454,10 @@
|
|
| 446 |
" return torch.stack(vols)\n",
|
| 447 |
"\n",
|
| 448 |
"\n",
|
| 449 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 450 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 451 |
" if emb.shape[0] < n_points:\n",
|
| 452 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 453 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -660,8 +670,9 @@
|
|
| 660 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 661 |
" ld['spread'] = l_spread\n",
|
| 662 |
"\n",
|
| 663 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 664 |
-
"
|
|
|
|
| 665 |
"\n",
|
| 666 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 667 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -954,19 +965,22 @@
|
|
| 954 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 955 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 956 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 957 |
-
" cv_batched=True):\n",
|
| 958 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 959 |
" Returns:\n",
|
| 960 |
" total_loss, loss_dict\n",
|
| 961 |
" \"\"\"\n",
|
| 962 |
" ld = {}\n",
|
| 963 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 964 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 965 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 966 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 967 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 968 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 969 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 970 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 971 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 972 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -974,7 +988,7 @@
|
|
| 974 |
" output['bridge1'], output['bridge2'],\n",
|
| 975 |
" output['assign1'], output['assign2'])\n",
|
| 976 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 977 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 978 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 979 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 980 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -985,12 +999,14 @@
|
|
| 985 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 986 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 987 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 988 |
-
"
|
|
|
|
| 989 |
" ld['cv'] = l_cv\n",
|
| 990 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 991 |
" ld['spread'] = l_spread\n",
|
| 992 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 993 |
-
"
|
|
|
|
| 994 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 995 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 996 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1327,9 +1343,11 @@
|
|
| 1327 |
"\n",
|
| 1328 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1329 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1330 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1331 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1332 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1333 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1334 |
" return train_loader, val_loader\n",
|
| 1335 |
"\n",
|
|
@@ -1460,7 +1478,7 @@
|
|
| 1460 |
" optimizer.zero_grad()\n",
|
| 1461 |
"\n",
|
| 1462 |
" output = model.forward_paired(v1, v2)\n",
|
| 1463 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1464 |
"\n",
|
| 1465 |
" loss.backward()\n",
|
| 1466 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 80 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 81 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 82 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 83 |
+
"\n",
|
| 84 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 85 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 86 |
+
"if device.type == \"cuda\":\n",
|
| 87 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 88 |
"\n",
|
| 89 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 90 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 454 |
" return torch.stack(vols)\n",
|
| 455 |
"\n",
|
| 456 |
"\n",
|
| 457 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 458 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 459 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 460 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 461 |
" if emb.shape[0] < n_points:\n",
|
| 462 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 463 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 670 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 671 |
" ld['spread'] = l_spread\n",
|
| 672 |
"\n",
|
| 673 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 674 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 675 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 676 |
"\n",
|
| 677 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 678 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 965 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 966 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 967 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 968 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 969 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 970 |
+
" Args:\n",
|
| 971 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 972 |
+
" Set True during validation or every N steps.\n",
|
| 973 |
" Returns:\n",
|
| 974 |
" total_loss, loss_dict\n",
|
| 975 |
" \"\"\"\n",
|
| 976 |
" ld = {}\n",
|
| 977 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 978 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 979 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 980 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 981 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 982 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 983 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 984 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 985 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 986 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 988 |
" output['bridge1'], output['bridge2'],\n",
|
| 989 |
" output['assign1'], output['assign2'])\n",
|
| 990 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 991 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 992 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 993 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 994 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 999 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1000 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1001 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1002 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1003 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1004 |
" ld['cv'] = l_cv\n",
|
| 1005 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1006 |
" ld['spread'] = l_spread\n",
|
| 1007 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1008 |
+
" if compute_knn:\n",
|
| 1009 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1010 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1011 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1012 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1343 |
"\n",
|
| 1344 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1345 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1346 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1347 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1348 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1349 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1350 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1351 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1352 |
" return train_loader, val_loader\n",
|
| 1353 |
"\n",
|
|
|
|
| 1478 |
" optimizer.zero_grad()\n",
|
| 1479 |
"\n",
|
| 1480 |
" output = model.forward_paired(v1, v2)\n",
|
| 1481 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1482 |
"\n",
|
| 1483 |
" loss.backward()\n",
|
| 1484 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_5_matrix_decompositions.ipynb
CHANGED
|
@@ -75,8 +75,16 @@
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 79 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 80 |
"\n",
|
| 81 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 82 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -446,8 +454,10 @@
|
|
| 446 |
" return torch.stack(vols)\n",
|
| 447 |
"\n",
|
| 448 |
"\n",
|
| 449 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 450 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 451 |
" if emb.shape[0] < n_points:\n",
|
| 452 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 453 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -660,8 +670,9 @@
|
|
| 660 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 661 |
" ld['spread'] = l_spread\n",
|
| 662 |
"\n",
|
| 663 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 664 |
-
"
|
|
|
|
| 665 |
"\n",
|
| 666 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 667 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -954,19 +965,22 @@
|
|
| 954 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 955 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 956 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 957 |
-
" cv_batched=True):\n",
|
| 958 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 959 |
" Returns:\n",
|
| 960 |
" total_loss, loss_dict\n",
|
| 961 |
" \"\"\"\n",
|
| 962 |
" ld = {}\n",
|
| 963 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 964 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 965 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 966 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 967 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 968 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 969 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 970 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 971 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 972 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -974,7 +988,7 @@
|
|
| 974 |
" output['bridge1'], output['bridge2'],\n",
|
| 975 |
" output['assign1'], output['assign2'])\n",
|
| 976 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 977 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 978 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 979 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 980 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -985,12 +999,14 @@
|
|
| 985 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 986 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 987 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 988 |
-
"
|
|
|
|
| 989 |
" ld['cv'] = l_cv\n",
|
| 990 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 991 |
" ld['spread'] = l_spread\n",
|
| 992 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 993 |
-
"
|
|
|
|
| 994 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 995 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 996 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1327,9 +1343,11 @@
|
|
| 1327 |
"\n",
|
| 1328 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1329 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1330 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1331 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1332 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1333 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1334 |
" return train_loader, val_loader\n",
|
| 1335 |
"\n",
|
|
@@ -1460,7 +1478,7 @@
|
|
| 1460 |
" optimizer.zero_grad()\n",
|
| 1461 |
"\n",
|
| 1462 |
" output = model.forward_paired(v1, v2)\n",
|
| 1463 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1464 |
"\n",
|
| 1465 |
" loss.backward()\n",
|
| 1466 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 80 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 81 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 82 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 83 |
+
"\n",
|
| 84 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 85 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 86 |
+
"if device.type == \"cuda\":\n",
|
| 87 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 88 |
"\n",
|
| 89 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 90 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 454 |
" return torch.stack(vols)\n",
|
| 455 |
"\n",
|
| 456 |
"\n",
|
| 457 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 458 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 459 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 460 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 461 |
" if emb.shape[0] < n_points:\n",
|
| 462 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 463 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 670 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 671 |
" ld['spread'] = l_spread\n",
|
| 672 |
"\n",
|
| 673 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 674 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 675 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 676 |
"\n",
|
| 677 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 678 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 965 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 966 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 967 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 968 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 969 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 970 |
+
" Args:\n",
|
| 971 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 972 |
+
" Set True during validation or every N steps.\n",
|
| 973 |
" Returns:\n",
|
| 974 |
" total_loss, loss_dict\n",
|
| 975 |
" \"\"\"\n",
|
| 976 |
" ld = {}\n",
|
| 977 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 978 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 979 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 980 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 981 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 982 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 983 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 984 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 985 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 986 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 988 |
" output['bridge1'], output['bridge2'],\n",
|
| 989 |
" output['assign1'], output['assign2'])\n",
|
| 990 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 991 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 992 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 993 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 994 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 999 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1000 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1001 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1002 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1003 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1004 |
" ld['cv'] = l_cv\n",
|
| 1005 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1006 |
" ld['spread'] = l_spread\n",
|
| 1007 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1008 |
+
" if compute_knn:\n",
|
| 1009 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1010 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1011 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1012 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1343 |
"\n",
|
| 1344 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1345 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1346 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1347 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1348 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1349 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1350 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1351 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1352 |
" return train_loader, val_loader\n",
|
| 1353 |
"\n",
|
|
|
|
| 1478 |
" optimizer.zero_grad()\n",
|
| 1479 |
"\n",
|
| 1480 |
" output = model.forward_paired(v1, v2)\n",
|
| 1481 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1482 |
"\n",
|
| 1483 |
" loss.backward()\n",
|
| 1484 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_6_losses_and_anchors.ipynb
CHANGED
|
@@ -77,8 +77,16 @@
|
|
| 77 |
"from collections import defaultdict\n",
|
| 78 |
"\n",
|
| 79 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 81 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 82 |
"\n",
|
| 83 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 84 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -448,8 +456,10 @@
|
|
| 448 |
" return torch.stack(vols)\n",
|
| 449 |
"\n",
|
| 450 |
"\n",
|
| 451 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 452 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 453 |
" if emb.shape[0] < n_points:\n",
|
| 454 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 455 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -662,8 +672,9 @@
|
|
| 662 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 663 |
" ld['spread'] = l_spread\n",
|
| 664 |
"\n",
|
| 665 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 666 |
-
"
|
|
|
|
| 667 |
"\n",
|
| 668 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 669 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -956,19 +967,22 @@
|
|
| 956 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 957 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 958 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 959 |
-
" cv_batched=True):\n",
|
| 960 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 961 |
" Returns:\n",
|
| 962 |
" total_loss, loss_dict\n",
|
| 963 |
" \"\"\"\n",
|
| 964 |
" ld = {}\n",
|
| 965 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 966 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 967 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 968 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 969 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 970 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 971 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 972 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 973 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 974 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -976,7 +990,7 @@
|
|
| 976 |
" output['bridge1'], output['bridge2'],\n",
|
| 977 |
" output['assign1'], output['assign2'])\n",
|
| 978 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 979 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 980 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 981 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 982 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -987,12 +1001,14 @@
|
|
| 987 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 988 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 989 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 990 |
-
"
|
|
|
|
| 991 |
" ld['cv'] = l_cv\n",
|
| 992 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 993 |
" ld['spread'] = l_spread\n",
|
| 994 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 995 |
-
"
|
|
|
|
| 996 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 997 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 998 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1329,9 +1345,11 @@
|
|
| 1329 |
"\n",
|
| 1330 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1331 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1332 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1333 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1334 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1335 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1336 |
" return train_loader, val_loader\n",
|
| 1337 |
"\n",
|
|
@@ -1462,7 +1480,7 @@
|
|
| 1462 |
" optimizer.zero_grad()\n",
|
| 1463 |
"\n",
|
| 1464 |
" output = model.forward_paired(v1, v2)\n",
|
| 1465 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1466 |
"\n",
|
| 1467 |
" loss.backward()\n",
|
| 1468 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 77 |
"from collections import defaultdict\n",
|
| 78 |
"\n",
|
| 79 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 82 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 83 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 84 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 85 |
+
"\n",
|
| 86 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 87 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 88 |
+
"if device.type == \"cuda\":\n",
|
| 89 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 90 |
"\n",
|
| 91 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 92 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 456 |
" return torch.stack(vols)\n",
|
| 457 |
"\n",
|
| 458 |
"\n",
|
| 459 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 460 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 461 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 462 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 463 |
" if emb.shape[0] < n_points:\n",
|
| 464 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 465 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 672 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 673 |
" ld['spread'] = l_spread\n",
|
| 674 |
"\n",
|
| 675 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 676 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 677 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 678 |
"\n",
|
| 679 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 680 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 967 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 968 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 969 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 970 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 971 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 972 |
+
" Args:\n",
|
| 973 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 974 |
+
" Set True during validation or every N steps.\n",
|
| 975 |
" Returns:\n",
|
| 976 |
" total_loss, loss_dict\n",
|
| 977 |
" \"\"\"\n",
|
| 978 |
" ld = {}\n",
|
| 979 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 980 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 981 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 982 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 983 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 984 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 985 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 986 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 987 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 988 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 990 |
" output['bridge1'], output['bridge2'],\n",
|
| 991 |
" output['assign1'], output['assign2'])\n",
|
| 992 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 993 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 994 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 995 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 996 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 1001 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1002 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1003 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1004 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1005 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1006 |
" ld['cv'] = l_cv\n",
|
| 1007 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1008 |
" ld['spread'] = l_spread\n",
|
| 1009 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1010 |
+
" if compute_knn:\n",
|
| 1011 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1012 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1013 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1014 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1345 |
"\n",
|
| 1346 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1347 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1348 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1349 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1350 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1351 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1352 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1353 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1354 |
" return train_loader, val_loader\n",
|
| 1355 |
"\n",
|
|
|
|
| 1480 |
" optimizer.zero_grad()\n",
|
| 1481 |
"\n",
|
| 1482 |
" output = model.forward_paired(v1, v2)\n",
|
| 1483 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1484 |
"\n",
|
| 1485 |
" loss.backward()\n",
|
| 1486 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
spectral/notebooks/experiment_7_composite_pipelines.ipynb
CHANGED
|
@@ -75,8 +75,16 @@
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 79 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
|
|
|
|
|
|
| 80 |
"\n",
|
| 81 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 82 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
@@ -446,8 +454,10 @@
|
|
| 446 |
" return torch.stack(vols)\n",
|
| 447 |
"\n",
|
| 448 |
"\n",
|
| 449 |
-
"def cv_loss(emb, target=0.22, n_samples=
|
| 450 |
-
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\
|
|
|
|
|
|
|
| 451 |
" if emb.shape[0] < n_points:\n",
|
| 452 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 453 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
@@ -660,8 +670,9 @@
|
|
| 660 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 661 |
" ld['spread'] = l_spread\n",
|
| 662 |
"\n",
|
| 663 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 664 |
-
"
|
|
|
|
| 665 |
"\n",
|
| 666 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 667 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
@@ -954,19 +965,22 @@
|
|
| 954 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 955 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 956 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 957 |
-
" cv_batched=True):\n",
|
| 958 |
-
" \"\"\"Three-domain cooperative loss.\n",
|
|
|
|
|
|
|
|
|
|
| 959 |
" Returns:\n",
|
| 960 |
" total_loss, loss_dict\n",
|
| 961 |
" \"\"\"\n",
|
| 962 |
" ld = {}\n",
|
| 963 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 964 |
-
" # \u2500\u2500 EXTERNAL \u2500\u2500\n",
|
| 965 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 966 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 967 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 968 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 969 |
-
" # \u2500\u2500 GEOMETRIC \u2500\u2500\n",
|
| 970 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 971 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 972 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
@@ -974,7 +988,7 @@
|
|
| 974 |
" output['bridge1'], output['bridge2'],\n",
|
| 975 |
" output['assign1'], output['assign2'])\n",
|
| 976 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 977 |
-
" # \u2500\u2500 INTERNAL \u2500\u2500\n",
|
| 978 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 979 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 980 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
@@ -985,12 +999,14 @@
|
|
| 985 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 986 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 987 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 988 |
-
"
|
|
|
|
| 989 |
" ld['cv'] = l_cv\n",
|
| 990 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 991 |
" ld['spread'] = l_spread\n",
|
| 992 |
-
" # \u2500\u2500 kNN \u2500\u2500\n",
|
| 993 |
-
"
|
|
|
|
| 994 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 995 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 996 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
@@ -1327,9 +1343,11 @@
|
|
| 1327 |
"\n",
|
| 1328 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1329 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1330 |
-
" drop_last=True, collate_fn=paired_collate
|
|
|
|
| 1331 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1332 |
-
" num_workers=num_workers, pin_memory=True
|
|
|
|
| 1333 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1334 |
" return train_loader, val_loader\n",
|
| 1335 |
"\n",
|
|
@@ -1460,7 +1478,7 @@
|
|
| 1460 |
" optimizer.zero_grad()\n",
|
| 1461 |
"\n",
|
| 1462 |
" output = model.forward_paired(v1, v2)\n",
|
| 1463 |
-
" loss, ld = model.compute_loss(output, labels, **lw)\n",
|
| 1464 |
"\n",
|
| 1465 |
" loss.backward()\n",
|
| 1466 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|
|
|
|
| 75 |
"from collections import defaultdict\n",
|
| 76 |
"\n",
|
| 77 |
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# \u2500\u2500 Performance: TF32 + cudnn benchmark \u2500\u2500\n",
|
| 80 |
+
"torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 81 |
+
"torch.backends.cudnn.allow_tf32 = True\n",
|
| 82 |
+
"torch.backends.cudnn.benchmark = True\n",
|
| 83 |
+
"\n",
|
| 84 |
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
| 85 |
"print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n",
|
| 86 |
+
"if device.type == \"cuda\":\n",
|
| 87 |
+
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}\")\n",
|
| 88 |
"\n",
|
| 89 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 90 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
|
|
|
| 454 |
" return torch.stack(vols)\n",
|
| 455 |
"\n",
|
| 456 |
"\n",
|
| 457 |
+
"def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n",
|
| 458 |
+
" \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2.\n",
|
| 459 |
+
" Default n_samples=32 for training speed (141x faster than sequential).\n",
|
| 460 |
+
" Use n_samples=200 for monitoring/metrics only.\"\"\"\n",
|
| 461 |
" if emb.shape[0] < n_points:\n",
|
| 462 |
" return torch.tensor(0.0, device=emb.device, requires_grad=True)\n",
|
| 463 |
" vols = _batch_pentachoron_volumes(emb, n_samples, n_points) if batched else _sequential_pentachoron_volumes(emb, n_samples, n_points)\n",
|
|
|
|
| 670 |
" l_spread = spread_loss(constellation.anchors)\n",
|
| 671 |
" ld['spread'] = l_spread\n",
|
| 672 |
"\n",
|
| 673 |
+
" # \u2500\u2500 kNN (skip during training for speed \u2014 only compute when explicitly needed) \u2500\u2500\n",
|
| 674 |
+
" if targets is not None and emb1.shape[0] <= 512:\n",
|
| 675 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 676 |
"\n",
|
| 677 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 678 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
|
|
|
| 965 |
" w_assign=0.5, w_assign_nce=0.25,\n",
|
| 966 |
" w_nce_tri=0.5, w_attract=0.25,\n",
|
| 967 |
" w_cv=0.01, w_spread=0.01,\n",
|
| 968 |
+
" cv_batched=True, compute_knn=False):\n",
|
| 969 |
+
" \"\"\"Three-domain cooperative loss \u2014 fully batched, zero Python loops.\n",
|
| 970 |
+
" Args:\n",
|
| 971 |
+
" compute_knn: if False (default), skip kNN during training for speed.\n",
|
| 972 |
+
" Set True during validation or every N steps.\n",
|
| 973 |
" Returns:\n",
|
| 974 |
" total_loss, loss_dict\n",
|
| 975 |
" \"\"\"\n",
|
| 976 |
" ld = {}\n",
|
| 977 |
" emb1, emb2 = output['embedding'], output['embedding_aug']\n",
|
| 978 |
+
" # \u2500\u2500 EXTERNAL (batched matmul) \u2500\u2500\n",
|
| 979 |
" l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n",
|
| 980 |
" ld['ce'], ld['acc'] = l_ce, acc\n",
|
| 981 |
" l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)\n",
|
| 982 |
" ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n",
|
| 983 |
+
" # \u2500\u2500 GEOMETRIC (batched matmul) \u2500\u2500\n",
|
| 984 |
" l_nce_pw, nce_pw_acc = nce_loss(\n",
|
| 985 |
" output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)\n",
|
| 986 |
" ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n",
|
|
|
|
| 988 |
" output['bridge1'], output['bridge2'],\n",
|
| 989 |
" output['assign1'], output['assign2'])\n",
|
| 990 |
" ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n",
|
| 991 |
+
" # \u2500\u2500 INTERNAL (batched \u2014 no Python loops) \u2500\u2500\n",
|
| 992 |
" l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n",
|
| 993 |
" ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n",
|
| 994 |
" l_assign_nce, assign_nce_acc = assign_nce_loss(\n",
|
|
|
|
| 999 |
" ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n",
|
| 1000 |
" l_attract, nearest_cos = attraction_loss(output['cos1'])\n",
|
| 1001 |
" ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n",
|
| 1002 |
+
" # CV: batched Cayley-Menger, n_samples=32 for training speed\n",
|
| 1003 |
+
" l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)\n",
|
| 1004 |
" ld['cv'] = l_cv\n",
|
| 1005 |
" l_spread = spread_loss(self.constellation.anchors)\n",
|
| 1006 |
" ld['spread'] = l_spread\n",
|
| 1007 |
+
" # \u2500\u2500 kNN (SKIP during training \u2014 B\u00d7B matmul is expensive every batch) \u2500\u2500\n",
|
| 1008 |
+
" if compute_knn:\n",
|
| 1009 |
+
" ld['knn_acc'] = knn_accuracy(emb1, targets)\n",
|
| 1010 |
" # \u2500\u2500 TOTAL \u2500\u2500\n",
|
| 1011 |
" loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n",
|
| 1012 |
" loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n",
|
|
|
|
| 1343 |
"\n",
|
| 1344 |
" train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,\n",
|
| 1345 |
" num_workers=num_workers, pin_memory=True,\n",
|
| 1346 |
+
" drop_last=True, collate_fn=paired_collate,\n",
|
| 1347 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1348 |
" val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,\n",
|
| 1349 |
+
" num_workers=num_workers, pin_memory=True,\n",
|
| 1350 |
+
" persistent_workers=(num_workers > 0))\n",
|
| 1351 |
" print(f\"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}\")\n",
|
| 1352 |
" return train_loader, val_loader\n",
|
| 1353 |
"\n",
|
|
|
|
| 1478 |
" optimizer.zero_grad()\n",
|
| 1479 |
"\n",
|
| 1480 |
" output = model.forward_paired(v1, v2)\n",
|
| 1481 |
+
" loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)\n",
|
| 1482 |
"\n",
|
| 1483 |
" loss.backward()\n",
|
| 1484 |
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
|