Upload 6 files
Browse files
- spectral/notebooks/experiment_2_manifold_structures.ipynb +10 -5
- spectral/notebooks/experiment_3_compact_representations.ipynb +4 -2
- spectral/notebooks/experiment_4_invertible_transforms.ipynb +4 -2
- spectral/notebooks/experiment_5_matrix_decompositions.ipynb +8 -6
- spectral/notebooks/experiment_6_losses_and_anchors.ipynb +4 -2
- spectral/notebooks/experiment_7_composite_pipelines.ipynb +4 -2
spectral/notebooks/experiment_2_manifold_structures.ipynb
CHANGED
|
@@ -1390,10 +1390,12 @@
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1393 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1394 |
" try:\n",
|
| 1395 |
-
" model = torch.compile(model, mode='
|
| 1396 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1397 |
" except Exception as e:\n",
|
| 1398 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1399 |
"\n",
|
|
@@ -1770,6 +1772,7 @@
|
|
| 1770 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1771 |
" print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
|
| 1772 |
"\n",
|
|
|
|
| 1773 |
" def forward(self, x):\n",
|
| 1774 |
" B, C, H, W = x.shape\n",
|
| 1775 |
" ps = self.patch_size\n",
|
|
@@ -1778,6 +1781,7 @@
|
|
| 1778 |
" n_p = patches.shape[1]\n",
|
| 1779 |
" # Reshape to (B*n_patches, ps*ps, C) for batched SVD\n",
|
| 1780 |
" patches_flat = patches.reshape(B * n_p, C, ps * ps).permute(0, 2, 1) # (B*n_p, ps*ps, C)\n",
|
|
|
|
| 1781 |
" U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
|
| 1782 |
" # Top-k singular values\n",
|
| 1783 |
" sk = S[:, :self.k] # (B*n_p, k)\n",
|
|
@@ -1836,17 +1840,18 @@
|
|
| 1836 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1837 |
" print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
|
| 1838 |
"\n",
|
|
|
|
| 1839 |
" def forward(self, x):\n",
|
| 1840 |
" B, C, H, W = x.shape\n",
|
| 1841 |
" ps = self.patch_size\n",
|
| 1842 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1843 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1844 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1845 |
-
" # SVD \u2014 batched
|
| 1846 |
" if C <= ps * ps:\n",
|
| 1847 |
-
" mat = patches
|
| 1848 |
" else:\n",
|
| 1849 |
-
" mat = patches.transpose(1, 2)
|
| 1850 |
" U, S, Vh = torch.linalg.svd(mat, full_matrices=False)\n",
|
| 1851 |
" # Collect features at each flag level\n",
|
| 1852 |
" max_k = max(self.levels)\n",
|
|
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1393 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1394 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1395 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1396 |
" try:\n",
|
| 1397 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1398 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1399 |
" except Exception as e:\n",
|
| 1400 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1401 |
"\n",
|
|
|
|
| 1772 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1773 |
" print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
|
| 1774 |
"\n",
|
| 1775 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1776 |
" def forward(self, x):\n",
|
| 1777 |
" B, C, H, W = x.shape\n",
|
| 1778 |
" ps = self.patch_size\n",
|
|
|
|
| 1781 |
" n_p = patches.shape[1]\n",
|
| 1782 |
" # Reshape to (B*n_patches, ps*ps, C) for batched SVD\n",
|
| 1783 |
" patches_flat = patches.reshape(B * n_p, C, ps * ps).permute(0, 2, 1) # (B*n_p, ps*ps, C)\n",
|
| 1784 |
+
" # SVD requires FP32 \u2014 cusolver does not support BF16/FP16\n",
|
| 1785 |
" U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
|
| 1786 |
" # Top-k singular values\n",
|
| 1787 |
" sk = S[:, :self.k] # (B*n_p, k)\n",
|
|
|
|
| 1840 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1841 |
" print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
|
| 1842 |
"\n",
|
| 1843 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1844 |
" def forward(self, x):\n",
|
| 1845 |
" B, C, H, W = x.shape\n",
|
| 1846 |
" ps = self.patch_size\n",
|
| 1847 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1848 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1849 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1850 |
+
" # SVD \u2014 batched, FP32 (cusolver requirement)\n",
|
| 1851 |
" if C <= ps * ps:\n",
|
| 1852 |
+
" mat = patches\n",
|
| 1853 |
" else:\n",
|
| 1854 |
+
" mat = patches.transpose(1, 2)\n",
|
| 1855 |
" U, S, Vh = torch.linalg.svd(mat, full_matrices=False)\n",
|
| 1856 |
" # Collect features at each flag level\n",
|
| 1857 |
" max_k = max(self.levels)\n",
|
spectral/notebooks/experiment_3_compact_representations.ipynb
CHANGED
|
@@ -1389,10 +1389,12 @@
|
|
| 1389 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1390 |
"\n",
|
| 1391 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1392 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1393 |
" try:\n",
|
| 1394 |
-
" model = torch.compile(model, mode='
|
| 1395 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1396 |
" except Exception as e:\n",
|
| 1397 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1398 |
"\n",
|
|
|
|
| 1389 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1390 |
"\n",
|
| 1391 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1392 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1393 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1394 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1395 |
" try:\n",
|
| 1396 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1397 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1398 |
" except Exception as e:\n",
|
| 1399 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1400 |
"\n",
|
spectral/notebooks/experiment_4_invertible_transforms.ipynb
CHANGED
|
@@ -1390,10 +1390,12 @@
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1393 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1394 |
" try:\n",
|
| 1395 |
-
" model = torch.compile(model, mode='
|
| 1396 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1397 |
" except Exception as e:\n",
|
| 1398 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1399 |
"\n",
|
|
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1393 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1394 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1395 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1396 |
" try:\n",
|
| 1397 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1398 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1399 |
" except Exception as e:\n",
|
| 1400 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1401 |
"\n",
|
spectral/notebooks/experiment_5_matrix_decompositions.ipynb
CHANGED
|
@@ -1390,10 +1390,12 @@
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1393 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1394 |
" try:\n",
|
| 1395 |
-
" model = torch.compile(model, mode='
|
| 1396 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1397 |
" except Exception as e:\n",
|
| 1398 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1399 |
"\n",
|
|
@@ -1707,15 +1709,15 @@
|
|
| 1707 |
" self.output_dim = n_patches * n_upper\n",
|
| 1708 |
" print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim}\")\n",
|
| 1709 |
"\n",
|
|
|
|
| 1710 |
" def forward(self, x):\n",
|
| 1711 |
" B, C, H, W = x.shape\n",
|
| 1712 |
" ps = self.patch_size\n",
|
| 1713 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1714 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1715 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1716 |
-
" # Batched QR
|
| 1717 |
" Q, R = torch.linalg.qr(patches.transpose(1, 2))\n",
|
| 1718 |
-
" # Upper triangle of R: R is (B*n_p, C, C), take upper tri\n",
|
| 1719 |
" k = self.k\n",
|
| 1720 |
" triu_idx = torch.triu_indices(k, k)\n",
|
| 1721 |
" upper = R[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
|
|
@@ -1762,15 +1764,15 @@
|
|
| 1762 |
" self.output_dim = n_patches * n_eig * 2 # eigenvalues + phases\n",
|
| 1763 |
" print(f\"[SCHUR] {n_patches} patches, n_eig={n_eig}, dim={self.output_dim}\")\n",
|
| 1764 |
"\n",
|
|
|
|
| 1765 |
" def forward(self, x):\n",
|
| 1766 |
" B, C, H, W = x.shape\n",
|
| 1767 |
" ps = self.patch_size\n",
|
| 1768 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1769 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1770 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1771 |
-
" # Batched covariance
|
| 1772 |
" cov = torch.bmm(patches, patches.transpose(1, 2)) / (ps * ps)\n",
|
| 1773 |
-
" # Batched eigenvalues: (B*n_p, C)\n",
|
| 1774 |
" eigvals = torch.linalg.eigvalsh(cov)\n",
|
| 1775 |
" # Top n_eig eigenvalues (eigvalsh returns ascending, so take last n_eig)\n",
|
| 1776 |
" top = eigvals[:, -self.n_eig:] # (B*n_p, n_eig)\n",
|
|
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1393 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1394 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1395 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1396 |
" try:\n",
|
| 1397 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1398 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1399 |
" except Exception as e:\n",
|
| 1400 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1401 |
"\n",
|
|
|
|
| 1709 |
" self.output_dim = n_patches * n_upper\n",
|
| 1710 |
" print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim}\")\n",
|
| 1711 |
"\n",
|
| 1712 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1713 |
" def forward(self, x):\n",
|
| 1714 |
" B, C, H, W = x.shape\n",
|
| 1715 |
" ps = self.patch_size\n",
|
| 1716 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1717 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1718 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1719 |
+
" # Batched QR in FP32 (cusolver requirement)\n",
|
| 1720 |
" Q, R = torch.linalg.qr(patches.transpose(1, 2))\n",
|
|
|
|
| 1721 |
" k = self.k\n",
|
| 1722 |
" triu_idx = torch.triu_indices(k, k)\n",
|
| 1723 |
" upper = R[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
|
|
|
|
| 1764 |
" self.output_dim = n_patches * n_eig * 2 # eigenvalues + phases\n",
|
| 1765 |
" print(f\"[SCHUR] {n_patches} patches, n_eig={n_eig}, dim={self.output_dim}\")\n",
|
| 1766 |
"\n",
|
| 1767 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1768 |
" def forward(self, x):\n",
|
| 1769 |
" B, C, H, W = x.shape\n",
|
| 1770 |
" ps = self.patch_size\n",
|
| 1771 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1772 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1773 |
" patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
|
| 1774 |
+
" # Batched covariance in FP32 (cusolver requirement)\n",
|
| 1775 |
" cov = torch.bmm(patches, patches.transpose(1, 2)) / (ps * ps)\n",
|
|
|
|
| 1776 |
" eigvals = torch.linalg.eigvalsh(cov)\n",
|
| 1777 |
" # Top n_eig eigenvalues (eigvalsh returns ascending, so take last n_eig)\n",
|
| 1778 |
" top = eigvals[:, -self.n_eig:] # (B*n_p, n_eig)\n",
|
spectral/notebooks/experiment_6_losses_and_anchors.ipynb
CHANGED
|
@@ -1392,10 +1392,12 @@
|
|
| 1392 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1393 |
"\n",
|
| 1394 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1395 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1396 |
" try:\n",
|
| 1397 |
-
" model = torch.compile(model, mode='
|
| 1398 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1399 |
" except Exception as e:\n",
|
| 1400 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1401 |
"\n",
|
|
|
|
| 1392 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1393 |
"\n",
|
| 1394 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1395 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1396 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1397 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1398 |
" try:\n",
|
| 1399 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1400 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1401 |
" except Exception as e:\n",
|
| 1402 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1403 |
"\n",
|
spectral/notebooks/experiment_7_composite_pipelines.ipynb
CHANGED
|
@@ -1390,10 +1390,12 @@
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
|
|
|
|
|
|
| 1393 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1394 |
" try:\n",
|
| 1395 |
-
" model = torch.compile(model, mode='
|
| 1396 |
-
" print(\"[PERF] torch.compile enabled (
|
| 1397 |
" except Exception as e:\n",
|
| 1398 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1399 |
"\n",
|
|
|
|
| 1390 |
" print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
|
| 1391 |
"\n",
|
| 1392 |
" # \u2500\u2500 torch.compile \u2500\u2500\n",
|
| 1393 |
+
" # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
|
| 1394 |
+
" # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
|
| 1395 |
" if use_compile and hasattr(torch, 'compile'):\n",
|
| 1396 |
" try:\n",
|
| 1397 |
+
" model = torch.compile(model, mode='default')\n",
|
| 1398 |
+
" print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
|
| 1399 |
" except Exception as e:\n",
|
| 1400 |
" print(f\"[PERF] torch.compile skipped: {e}\")\n",
|
| 1401 |
"\n",
|