AbstractPhil committed on
Commit
aaec850
·
verified ·
1 Parent(s): 6182b2b

Upload 6 files

Browse files
spectral/notebooks/experiment_2_manifold_structures.ipynb CHANGED
@@ -1390,10 +1390,12 @@
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1393
  " if use_compile and hasattr(torch, 'compile'):\n",
1394
  " try:\n",
1395
- " model = torch.compile(model, mode='reduce-overhead')\n",
1396
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1397
  " except Exception as e:\n",
1398
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1399
  "\n",
@@ -1770,6 +1772,7 @@
1770
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1771
  " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1772
  "\n",
 
1773
  " def forward(self, x):\n",
1774
  " B, C, H, W = x.shape\n",
1775
  " ps = self.patch_size\n",
@@ -1778,6 +1781,7 @@
1778
  " n_p = patches.shape[1]\n",
1779
  " # Reshape to (B*n_patches, ps*ps, C) for batched SVD\n",
1780
  " patches_flat = patches.reshape(B * n_p, C, ps * ps).permute(0, 2, 1) # (B*n_p, ps*ps, C)\n",
 
1781
  " U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
1782
  " # Top-k singular values\n",
1783
  " sk = S[:, :self.k] # (B*n_p, k)\n",
@@ -1836,17 +1840,18 @@
1836
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1837
  " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
1838
  "\n",
 
1839
  " def forward(self, x):\n",
1840
  " B, C, H, W = x.shape\n",
1841
  " ps = self.patch_size\n",
1842
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1843
  " n_p = patches.shape[2] * patches.shape[3]\n",
1844
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1845
- " # SVD \u2014 batched: (B*n_p, C, ps*ps)\n",
1846
  " if C <= ps * ps:\n",
1847
- " mat = patches # (B*n_p, C, ps*ps)\n",
1848
  " else:\n",
1849
- " mat = patches.transpose(1, 2) # (B*n_p, ps*ps, C)\n",
1850
  " U, S, Vh = torch.linalg.svd(mat, full_matrices=False)\n",
1851
  " # Collect features at each flag level\n",
1852
  " max_k = max(self.levels)\n",
 
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1393
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1394
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1395
  " if use_compile and hasattr(torch, 'compile'):\n",
1396
  " try:\n",
1397
+ " model = torch.compile(model, mode='default')\n",
1398
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1399
  " except Exception as e:\n",
1400
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1401
  "\n",
 
1772
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1773
  " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1774
  "\n",
1775
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1776
  " def forward(self, x):\n",
1777
  " B, C, H, W = x.shape\n",
1778
  " ps = self.patch_size\n",
 
1781
  " n_p = patches.shape[1]\n",
1782
  " # Reshape to (B*n_patches, ps*ps, C) for batched SVD\n",
1783
  " patches_flat = patches.reshape(B * n_p, C, ps * ps).permute(0, 2, 1) # (B*n_p, ps*ps, C)\n",
1784
+ " # SVD requires FP32 \u2014 cusolver does not support BF16/FP16\n",
1785
  " U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
1786
  " # Top-k singular values\n",
1787
  " sk = S[:, :self.k] # (B*n_p, k)\n",
 
1840
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1841
  " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
1842
  "\n",
1843
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1844
  " def forward(self, x):\n",
1845
  " B, C, H, W = x.shape\n",
1846
  " ps = self.patch_size\n",
1847
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1848
  " n_p = patches.shape[2] * patches.shape[3]\n",
1849
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1850
+ " # SVD \u2014 batched, FP32 (cusolver requirement)\n",
1851
  " if C <= ps * ps:\n",
1852
+ " mat = patches\n",
1853
  " else:\n",
1854
+ " mat = patches.transpose(1, 2)\n",
1855
  " U, S, Vh = torch.linalg.svd(mat, full_matrices=False)\n",
1856
  " # Collect features at each flag level\n",
1857
  " max_k = max(self.levels)\n",
spectral/notebooks/experiment_3_compact_representations.ipynb CHANGED
@@ -1389,10 +1389,12 @@
1389
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1390
  "\n",
1391
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1392
  " if use_compile and hasattr(torch, 'compile'):\n",
1393
  " try:\n",
1394
- " model = torch.compile(model, mode='reduce-overhead')\n",
1395
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1396
  " except Exception as e:\n",
1397
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1398
  "\n",
 
1389
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1390
  "\n",
1391
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1392
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1393
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1394
  " if use_compile and hasattr(torch, 'compile'):\n",
1395
  " try:\n",
1396
+ " model = torch.compile(model, mode='default')\n",
1397
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1398
  " except Exception as e:\n",
1399
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1400
  "\n",
spectral/notebooks/experiment_4_invertible_transforms.ipynb CHANGED
@@ -1390,10 +1390,12 @@
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1393
  " if use_compile and hasattr(torch, 'compile'):\n",
1394
  " try:\n",
1395
- " model = torch.compile(model, mode='reduce-overhead')\n",
1396
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1397
  " except Exception as e:\n",
1398
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1399
  "\n",
 
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1393
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1394
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1395
  " if use_compile and hasattr(torch, 'compile'):\n",
1396
  " try:\n",
1397
+ " model = torch.compile(model, mode='default')\n",
1398
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1399
  " except Exception as e:\n",
1400
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1401
  "\n",
spectral/notebooks/experiment_5_matrix_decompositions.ipynb CHANGED
@@ -1390,10 +1390,12 @@
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1393
  " if use_compile and hasattr(torch, 'compile'):\n",
1394
  " try:\n",
1395
- " model = torch.compile(model, mode='reduce-overhead')\n",
1396
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1397
  " except Exception as e:\n",
1398
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1399
  "\n",
@@ -1707,15 +1709,15 @@
1707
  " self.output_dim = n_patches * n_upper\n",
1708
  " print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1709
  "\n",
 
1710
  " def forward(self, x):\n",
1711
  " B, C, H, W = x.shape\n",
1712
  " ps = self.patch_size\n",
1713
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1714
  " n_p = patches.shape[2] * patches.shape[3]\n",
1715
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1716
- " # Batched QR: (B*n_p, ps*ps, C)\n",
1717
  " Q, R = torch.linalg.qr(patches.transpose(1, 2))\n",
1718
- " # Upper triangle of R: R is (B*n_p, C, C), take upper tri\n",
1719
  " k = self.k\n",
1720
  " triu_idx = torch.triu_indices(k, k)\n",
1721
  " upper = R[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
@@ -1762,15 +1764,15 @@
1762
  " self.output_dim = n_patches * n_eig * 2 # eigenvalues + phases\n",
1763
  " print(f\"[SCHUR] {n_patches} patches, n_eig={n_eig}, dim={self.output_dim}\")\n",
1764
  "\n",
 
1765
  " def forward(self, x):\n",
1766
  " B, C, H, W = x.shape\n",
1767
  " ps = self.patch_size\n",
1768
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1769
  " n_p = patches.shape[2] * patches.shape[3]\n",
1770
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1771
- " # Batched covariance: (B*n_p, C, C)\n",
1772
  " cov = torch.bmm(patches, patches.transpose(1, 2)) / (ps * ps)\n",
1773
- " # Batched eigenvalues: (B*n_p, C)\n",
1774
  " eigvals = torch.linalg.eigvalsh(cov)\n",
1775
  " # Top n_eig eigenvalues (eigvalsh returns ascending, so take last n_eig)\n",
1776
  " top = eigvals[:, -self.n_eig:] # (B*n_p, n_eig)\n",
 
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1393
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1394
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1395
  " if use_compile and hasattr(torch, 'compile'):\n",
1396
  " try:\n",
1397
+ " model = torch.compile(model, mode='default')\n",
1398
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1399
  " except Exception as e:\n",
1400
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1401
  "\n",
 
1709
  " self.output_dim = n_patches * n_upper\n",
1710
  " print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1711
  "\n",
1712
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1713
  " def forward(self, x):\n",
1714
  " B, C, H, W = x.shape\n",
1715
  " ps = self.patch_size\n",
1716
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1717
  " n_p = patches.shape[2] * patches.shape[3]\n",
1718
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1719
+ " # Batched QR in FP32 (cusolver requirement)\n",
1720
  " Q, R = torch.linalg.qr(patches.transpose(1, 2))\n",
 
1721
  " k = self.k\n",
1722
  " triu_idx = torch.triu_indices(k, k)\n",
1723
  " upper = R[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
 
1764
  " self.output_dim = n_patches * n_eig * 2 # eigenvalues + phases\n",
1765
  " print(f\"[SCHUR] {n_patches} patches, n_eig={n_eig}, dim={self.output_dim}\")\n",
1766
  "\n",
1767
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1768
  " def forward(self, x):\n",
1769
  " B, C, H, W = x.shape\n",
1770
  " ps = self.patch_size\n",
1771
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1772
  " n_p = patches.shape[2] * patches.shape[3]\n",
1773
  " patches = patches.contiguous().reshape(B * n_p, C, ps * ps)\n",
1774
+ " # Batched covariance in FP32 (cusolver requirement)\n",
1775
  " cov = torch.bmm(patches, patches.transpose(1, 2)) / (ps * ps)\n",
 
1776
  " eigvals = torch.linalg.eigvalsh(cov)\n",
1777
  " # Top n_eig eigenvalues (eigvalsh returns ascending, so take last n_eig)\n",
1778
  " top = eigvals[:, -self.n_eig:] # (B*n_p, n_eig)\n",
spectral/notebooks/experiment_6_losses_and_anchors.ipynb CHANGED
@@ -1392,10 +1392,12 @@
1392
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1393
  "\n",
1394
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1395
  " if use_compile and hasattr(torch, 'compile'):\n",
1396
  " try:\n",
1397
- " model = torch.compile(model, mode='reduce-overhead')\n",
1398
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1399
  " except Exception as e:\n",
1400
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1401
  "\n",
 
1392
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1393
  "\n",
1394
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1395
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1396
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1397
  " if use_compile and hasattr(torch, 'compile'):\n",
1398
  " try:\n",
1399
+ " model = torch.compile(model, mode='default')\n",
1400
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1401
  " except Exception as e:\n",
1402
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1403
  "\n",
spectral/notebooks/experiment_7_composite_pipelines.ipynb CHANGED
@@ -1390,10 +1390,12 @@
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
 
 
1393
  " if use_compile and hasattr(torch, 'compile'):\n",
1394
  " try:\n",
1395
- " model = torch.compile(model, mode='reduce-overhead')\n",
1396
- " print(\"[PERF] torch.compile enabled (reduce-overhead)\")\n",
1397
  " except Exception as e:\n",
1398
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1399
  "\n",
 
1390
  " print(f\"[PERF] AMP enabled \u2014 dtype={amp_dtype}\")\n",
1391
  "\n",
1392
  " # \u2500\u2500 torch.compile \u2500\u2500\n",
1393
+ " # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)\n",
1394
+ " # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh\n",
1395
  " if use_compile and hasattr(torch, 'compile'):\n",
1396
  " try:\n",
1397
+ " model = torch.compile(model, mode='default')\n",
1398
+ " print(\"[PERF] torch.compile enabled (default \u2014 dynamo tracing)\")\n",
1399
  " except Exception as e:\n",
1400
  " print(f\"[PERF] torch.compile skipped: {e}\")\n",
1401
  "\n",