Upload 6 files
Browse files
- spectral/notebooks/experiment_2_manifold_structures.ipynb +151 -14
- spectral/notebooks/experiment_3_compact_representations.ipynb +137 -1
- spectral/notebooks/experiment_4_invertible_transforms.ipynb +137 -1
- spectral/notebooks/experiment_5_matrix_decompositions.ipynb +137 -1
- spectral/notebooks/experiment_6_losses_and_anchors.ipynb +137 -1
- spectral/notebooks/experiment_7_composite_pipelines.ipynb +137 -1
spectral/notebooks/experiment_2_manifold_structures.ipynb
CHANGED
|
@@ -39,7 +39,7 @@
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -92,6 +92,142 @@
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 96 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
@@ -1766,12 +1902,10 @@
|
|
| 1766 |
"source": [
|
| 1767 |
"# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
|
| 1768 |
"class GrassmannianFrontEnd(nn.Module):\n",
|
| 1769 |
-
" \"\"\"Grassmannian subspace features via
|
| 1770 |
-
"
|
| 1771 |
-
"
|
| 1772 |
-
"
|
| 1773 |
-
" Grassmannian coordinate). V varies per patch and encodes which\n",
|
| 1774 |
-
" linear combinations of RGB correspond to principal directions.\"\"\"\n",
|
| 1775 |
" def __init__(self, patch_size=8, k=3, input_size=32):\n",
|
| 1776 |
" super().__init__()\n",
|
| 1777 |
" self.patch_size = patch_size\n",
|
|
@@ -1782,7 +1916,8 @@
|
|
| 1782 |
" # k singular values + k log-ratios + k*C right singular vector entries\n",
|
| 1783 |
" self.features_per_patch = k + k + k * self.C\n",
|
| 1784 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1785 |
-
"
|
|
|
|
| 1786 |
"\n",
|
| 1787 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1788 |
" def forward(self, x):\n",
|
|
@@ -1792,8 +1927,8 @@
|
|
| 1792 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1793 |
" # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
|
| 1794 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1795 |
-
" #
|
| 1796 |
-
" U, S, Vh =
|
| 1797 |
" S = S[:, :self.k]\n",
|
| 1798 |
" # Log singular value ratios (scale-invariant spectrum)\n",
|
| 1799 |
" sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
|
|
@@ -1834,7 +1969,8 @@
|
|
| 1834 |
"source": [
|
| 1835 |
"# @title Experiment 2.3 \u2014 Flag Manifold\n",
|
| 1836 |
"class FlagManifoldFrontEnd(nn.Module):\n",
|
| 1837 |
-
" \"\"\"Cascading SVD at multiple truncation levels
|
|
|
|
| 1838 |
" Nested subspace features: singular values + projection norms at each flag level.\n",
|
| 1839 |
" The flag structure captures how information distributes across\n",
|
| 1840 |
" nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
|
|
@@ -1847,7 +1983,8 @@
|
|
| 1847 |
" max_sv = min(3, patch_size * patch_size)\n",
|
| 1848 |
" self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
|
| 1849 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1850 |
-
"
|
|
|
|
| 1851 |
"\n",
|
| 1852 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1853 |
" def forward(self, x):\n",
|
|
@@ -1857,8 +1994,8 @@
|
|
| 1857 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1858 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1859 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1860 |
-
" #
|
| 1861 |
-
" U, S, Vh =
|
| 1862 |
" # Features at each flag level\n",
|
| 1863 |
" feats = []\n",
|
| 1864 |
" for k in self.levels:\n",
|
|
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
| 95 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 96 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 97 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 98 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 99 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 100 |
+
"try:\n",
|
| 101 |
+
" import triton\n",
|
| 102 |
+
" import triton.language as tl\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" @triton.jit\n",
|
| 105 |
+
" def _svd3_kernel(\n",
|
| 106 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 107 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 108 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 109 |
+
" ):\n",
|
| 110 |
+
" bid = tl.program_id(0)\n",
|
| 111 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 112 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 113 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 114 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 118 |
+
" base = bid * M * 3\n",
|
| 119 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 120 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 121 |
+
" row_idx = block_start + offs\n",
|
| 122 |
+
" mask = row_idx < M\n",
|
| 123 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 124 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 125 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 126 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 127 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 128 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 129 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 130 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 131 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 132 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 133 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 134 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 135 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 136 |
+
" # pair (0,1)\n",
|
| 137 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 138 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 139 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 140 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 141 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 142 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 143 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 144 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 145 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 146 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 147 |
+
" # pair (0,2)\n",
|
| 148 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 149 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 150 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 151 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 152 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 153 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 154 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 155 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 156 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 157 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 158 |
+
" # pair (1,2)\n",
|
| 159 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 160 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 161 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 162 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 163 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 164 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 165 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 166 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 167 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 168 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 169 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 170 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 171 |
+
" do_swap = s0 < s1\n",
|
| 172 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 173 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 174 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 175 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 176 |
+
" do_swap = s0 < s2\n",
|
| 177 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 178 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 179 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 180 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 181 |
+
" do_swap = s1 < s2\n",
|
| 182 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 183 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 184 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 185 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 186 |
+
" # Write S\n",
|
| 187 |
+
" s_base = bid * 3\n",
|
| 188 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 189 |
+
" # Write Vh = V^T\n",
|
| 190 |
+
" vh_base = bid * 9\n",
|
| 191 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 192 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 193 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 194 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 195 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 196 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 197 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 198 |
+
" row_idx = block_start + offs\n",
|
| 199 |
+
" mask = row_idx < M\n",
|
| 200 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 201 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 202 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 203 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 204 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 205 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 206 |
+
" u_base = bid * M * 3\n",
|
| 207 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 208 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 209 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 212 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 213 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 214 |
+
" B, M, _ = A.shape\n",
|
| 215 |
+
" A_f32 = A.contiguous().float()\n",
|
| 216 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 217 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 218 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 219 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 220 |
+
" return U, S, Vh\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 223 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 224 |
+
"except ImportError:\n",
|
| 225 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 228 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 229 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 230 |
+
"\n",
|
| 231 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 232 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 233 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 1902 |
"source": [
|
| 1903 |
"# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
|
| 1904 |
"class GrassmannianFrontEnd(nn.Module):\n",
|
| 1905 |
+
" \"\"\"Grassmannian subspace features via SVD.\n",
|
| 1906 |
+
" Uses fused Triton SVD3 kernel (3\u00d73 Jacobi in registers) when available,\n",
|
| 1907 |
+
" falls back to torch.linalg.svd otherwise. Same mathematical decomposition.\n",
|
| 1908 |
+
" Features: singular values, log ratios, right singular vectors V.\"\"\"\n",
|
|
|
|
|
|
|
| 1909 |
" def __init__(self, patch_size=8, k=3, input_size=32):\n",
|
| 1910 |
" super().__init__()\n",
|
| 1911 |
" self.patch_size = patch_size\n",
|
|
|
|
| 1916 |
" # k singular values + k log-ratios + k*C right singular vector entries\n",
|
| 1917 |
" self.features_per_patch = k + k + k * self.C\n",
|
| 1918 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1919 |
+
" backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
|
| 1920 |
+
" print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} ({backend})\")\n",
|
| 1921 |
"\n",
|
| 1922 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1923 |
" def forward(self, x):\n",
|
|
|
|
| 1927 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1928 |
" # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
|
| 1929 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1930 |
+
" # SVD via fused Triton kernel (or torch fallback)\n",
|
| 1931 |
+
" U, S, Vh = batched_svd3(X)\n",
|
| 1932 |
" S = S[:, :self.k]\n",
|
| 1933 |
" # Log singular value ratios (scale-invariant spectrum)\n",
|
| 1934 |
" sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
|
|
|
|
| 1969 |
"source": [
|
| 1970 |
"# @title Experiment 2.3 \u2014 Flag Manifold\n",
|
| 1971 |
"class FlagManifoldFrontEnd(nn.Module):\n",
|
| 1972 |
+
" \"\"\"Cascading SVD at multiple truncation levels.\n",
|
| 1973 |
+
" Uses fused Triton SVD3 kernel when available.\n",
|
| 1974 |
" Nested subspace features: singular values + projection norms at each flag level.\n",
|
| 1975 |
" The flag structure captures how information distributes across\n",
|
| 1976 |
" nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
|
|
|
|
| 1983 |
" max_sv = min(3, patch_size * patch_size)\n",
|
| 1984 |
" self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
|
| 1985 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1986 |
+
" backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
|
| 1987 |
+
" print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} ({backend})\")\n",
|
| 1988 |
"\n",
|
| 1989 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1990 |
" def forward(self, x):\n",
|
|
|
|
| 1994 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1995 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1996 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1997 |
+
" # SVD via fused Triton kernel (or torch fallback)\n",
|
| 1998 |
+
" U, S, Vh = batched_svd3(X)\n",
|
| 1999 |
" # Features at each flag level\n",
|
| 2000 |
" feats = []\n",
|
| 2001 |
" for k in self.levels:\n",
|
spectral/notebooks/experiment_3_compact_representations.ipynb
CHANGED
|
@@ -38,7 +38,7 @@
|
|
| 38 |
"metadata": {},
|
| 39 |
"source": [
|
| 40 |
"# @title Install Dependencies\n",
|
| 41 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 42 |
"%load_ext tensorboard\n",
|
| 43 |
"import torch\n",
|
| 44 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -91,6 +91,142 @@
|
|
| 91 |
"if device.type == \"cuda\":\n",
|
| 92 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 93 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 95 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 96 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 38 |
"metadata": {},
|
| 39 |
"source": [
|
| 40 |
"# @title Install Dependencies\n",
|
| 41 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 42 |
"%load_ext tensorboard\n",
|
| 43 |
"import torch\n",
|
| 44 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 91 |
"if device.type == \"cuda\":\n",
|
| 92 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 93 |
"\n",
|
| 94 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 95 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 96 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 97 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 98 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 99 |
+
"try:\n",
|
| 100 |
+
" import triton\n",
|
| 101 |
+
" import triton.language as tl\n",
|
| 102 |
+
"\n",
|
| 103 |
+
" @triton.jit\n",
|
| 104 |
+
" def _svd3_kernel(\n",
|
| 105 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 106 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 107 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 108 |
+
" ):\n",
|
| 109 |
+
" bid = tl.program_id(0)\n",
|
| 110 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 111 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 112 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 113 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 114 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" base = bid * M * 3\n",
|
| 118 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 119 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 120 |
+
" row_idx = block_start + offs\n",
|
| 121 |
+
" mask = row_idx < M\n",
|
| 122 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 123 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 124 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 125 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 126 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 127 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 128 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 129 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 130 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 131 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 132 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 133 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 134 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 135 |
+
" # pair (0,1)\n",
|
| 136 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 137 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 138 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 139 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 140 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 141 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 142 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 143 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 144 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 145 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 146 |
+
" # pair (0,2)\n",
|
| 147 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 148 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 149 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 150 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 151 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 152 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 153 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 154 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 155 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 156 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 157 |
+
" # pair (1,2)\n",
|
| 158 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 159 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 160 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 161 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 162 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 163 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 164 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 165 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 166 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 167 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 168 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 169 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 170 |
+
" do_swap = s0 < s1\n",
|
| 171 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 172 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 173 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 174 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 175 |
+
" do_swap = s0 < s2\n",
|
| 176 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 177 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 178 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 179 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 180 |
+
" do_swap = s1 < s2\n",
|
| 181 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 182 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 183 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 184 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 185 |
+
" # Write S\n",
|
| 186 |
+
" s_base = bid * 3\n",
|
| 187 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 188 |
+
" # Write Vh = V^T\n",
|
| 189 |
+
" vh_base = bid * 9\n",
|
| 190 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 191 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 192 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 193 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 194 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 195 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 196 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 197 |
+
" row_idx = block_start + offs\n",
|
| 198 |
+
" mask = row_idx < M\n",
|
| 199 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 200 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 201 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 202 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 203 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 204 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 205 |
+
" u_base = bid * M * 3\n",
|
| 206 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 207 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 208 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 209 |
+
"\n",
|
| 210 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 211 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 212 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 213 |
+
" B, M, _ = A.shape\n",
|
| 214 |
+
" A_f32 = A.contiguous().float()\n",
|
| 215 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 216 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 217 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 218 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 219 |
+
" return U, S, Vh\n",
|
| 220 |
+
"\n",
|
| 221 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 222 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 223 |
+
"except ImportError:\n",
|
| 224 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 225 |
+
"\n",
|
| 226 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 227 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 228 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 229 |
+
"\n",
|
| 230 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 231 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 232 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
spectral/notebooks/experiment_4_invertible_transforms.ipynb
CHANGED
|
@@ -39,7 +39,7 @@
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -92,6 +92,142 @@
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 96 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
| 95 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 96 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 97 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 98 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 99 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 100 |
+
"try:\n",
|
| 101 |
+
" import triton\n",
|
| 102 |
+
" import triton.language as tl\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" @triton.jit\n",
|
| 105 |
+
" def _svd3_kernel(\n",
|
| 106 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 107 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 108 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 109 |
+
" ):\n",
|
| 110 |
+
" bid = tl.program_id(0)\n",
|
| 111 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 112 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 113 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 114 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 118 |
+
" base = bid * M * 3\n",
|
| 119 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 120 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 121 |
+
" row_idx = block_start + offs\n",
|
| 122 |
+
" mask = row_idx < M\n",
|
| 123 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 124 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 125 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 126 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 127 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 128 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 129 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 130 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 131 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 132 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 133 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 134 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 135 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 136 |
+
" # pair (0,1)\n",
|
| 137 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 138 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 139 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 140 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 141 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 142 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 143 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 144 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 145 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 146 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 147 |
+
" # pair (0,2)\n",
|
| 148 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 149 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 150 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 151 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 152 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 153 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 154 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 155 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 156 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 157 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 158 |
+
" # pair (1,2)\n",
|
| 159 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 160 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 161 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 162 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 163 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 164 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 165 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 166 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 167 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 168 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 169 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 170 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 171 |
+
" do_swap = s0 < s1\n",
|
| 172 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 173 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 174 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 175 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 176 |
+
" do_swap = s0 < s2\n",
|
| 177 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 178 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 179 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 180 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 181 |
+
" do_swap = s1 < s2\n",
|
| 182 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 183 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 184 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 185 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 186 |
+
" # Write S\n",
|
| 187 |
+
" s_base = bid * 3\n",
|
| 188 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 189 |
+
" # Write Vh = V^T\n",
|
| 190 |
+
" vh_base = bid * 9\n",
|
| 191 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 192 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 193 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 194 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 195 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 196 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 197 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 198 |
+
" row_idx = block_start + offs\n",
|
| 199 |
+
" mask = row_idx < M\n",
|
| 200 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 201 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 202 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 203 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 204 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 205 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 206 |
+
" u_base = bid * M * 3\n",
|
| 207 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 208 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 209 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 212 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 213 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 214 |
+
" B, M, _ = A.shape\n",
|
| 215 |
+
" A_f32 = A.contiguous().float()\n",
|
| 216 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 217 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 218 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 219 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 220 |
+
" return U, S, Vh\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 223 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 224 |
+
"except ImportError:\n",
|
| 225 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 228 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 229 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 230 |
+
"\n",
|
| 231 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 232 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 233 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
spectral/notebooks/experiment_5_matrix_decompositions.ipynb
CHANGED
|
@@ -39,7 +39,7 @@
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -92,6 +92,142 @@
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 96 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
| 95 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 96 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 97 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 98 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 99 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 100 |
+
"try:\n",
|
| 101 |
+
" import triton\n",
|
| 102 |
+
" import triton.language as tl\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" @triton.jit\n",
|
| 105 |
+
" def _svd3_kernel(\n",
|
| 106 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 107 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 108 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 109 |
+
" ):\n",
|
| 110 |
+
" bid = tl.program_id(0)\n",
|
| 111 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 112 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 113 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 114 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 118 |
+
" base = bid * M * 3\n",
|
| 119 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 120 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 121 |
+
" row_idx = block_start + offs\n",
|
| 122 |
+
" mask = row_idx < M\n",
|
| 123 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 124 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 125 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 126 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 127 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 128 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 129 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 130 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 131 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 132 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 133 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 134 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 135 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 136 |
+
" # pair (0,1)\n",
|
| 137 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 138 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 139 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 140 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 141 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 142 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 143 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 144 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 145 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 146 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 147 |
+
" # pair (0,2)\n",
|
| 148 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 149 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 150 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 151 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 152 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 153 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 154 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 155 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 156 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 157 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 158 |
+
" # pair (1,2)\n",
|
| 159 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 160 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 161 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 162 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 163 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 164 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 165 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 166 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 167 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 168 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 169 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 170 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 171 |
+
" do_swap = s0 < s1\n",
|
| 172 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 173 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 174 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 175 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 176 |
+
" do_swap = s0 < s2\n",
|
| 177 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 178 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 179 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 180 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 181 |
+
" do_swap = s1 < s2\n",
|
| 182 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 183 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 184 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 185 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 186 |
+
" # Write S\n",
|
| 187 |
+
" s_base = bid * 3\n",
|
| 188 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 189 |
+
" # Write Vh = V^T\n",
|
| 190 |
+
" vh_base = bid * 9\n",
|
| 191 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 192 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 193 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 194 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 195 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 196 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 197 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 198 |
+
" row_idx = block_start + offs\n",
|
| 199 |
+
" mask = row_idx < M\n",
|
| 200 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 201 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 202 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 203 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 204 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 205 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 206 |
+
" u_base = bid * M * 3\n",
|
| 207 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 208 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 209 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 212 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 213 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 214 |
+
" B, M, _ = A.shape\n",
|
| 215 |
+
" A_f32 = A.contiguous().float()\n",
|
| 216 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 217 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 218 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 219 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 220 |
+
" return U, S, Vh\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 223 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 224 |
+
"except ImportError:\n",
|
| 225 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 228 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 229 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 230 |
+
"\n",
|
| 231 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 232 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 233 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
spectral/notebooks/experiment_6_losses_and_anchors.ipynb
CHANGED
|
@@ -41,7 +41,7 @@
|
|
| 41 |
"metadata": {},
|
| 42 |
"source": [
|
| 43 |
"# @title Install Dependencies\n",
|
| 44 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 45 |
"%load_ext tensorboard\n",
|
| 46 |
"import torch\n",
|
| 47 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -94,6 +94,142 @@
|
|
| 94 |
"if device.type == \"cuda\":\n",
|
| 95 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 96 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 98 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 99 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 41 |
"metadata": {},
|
| 42 |
"source": [
|
| 43 |
"# @title Install Dependencies\n",
|
| 44 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 45 |
"%load_ext tensorboard\n",
|
| 46 |
"import torch\n",
|
| 47 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 94 |
"if device.type == \"cuda\":\n",
|
| 95 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 96 |
"\n",
|
| 97 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 98 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 99 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 100 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 101 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 102 |
+
"try:\n",
|
| 103 |
+
" import triton\n",
|
| 104 |
+
" import triton.language as tl\n",
|
| 105 |
+
"\n",
|
| 106 |
+
" @triton.jit\n",
|
| 107 |
+
" def _svd3_kernel(\n",
|
| 108 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 109 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 110 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 111 |
+
" ):\n",
|
| 112 |
+
" bid = tl.program_id(0)\n",
|
| 113 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 114 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 118 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 119 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 120 |
+
" base = bid * M * 3\n",
|
| 121 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 122 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 123 |
+
" row_idx = block_start + offs\n",
|
| 124 |
+
" mask = row_idx < M\n",
|
| 125 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 126 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 127 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 128 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 129 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 130 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 131 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 132 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 133 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 134 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 135 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 136 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 137 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 138 |
+
" # pair (0,1)\n",
|
| 139 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 140 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 141 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 142 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 143 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 144 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 145 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 146 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 147 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 148 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 149 |
+
" # pair (0,2)\n",
|
| 150 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 151 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 152 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 153 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 154 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 155 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 156 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 157 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 158 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 159 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 160 |
+
" # pair (1,2)\n",
|
| 161 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 162 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 163 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 164 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 165 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 166 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 167 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 168 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 169 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 170 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 171 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 172 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 173 |
+
" do_swap = s0 < s1\n",
|
| 174 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 175 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 176 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 177 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 178 |
+
" do_swap = s0 < s2\n",
|
| 179 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 180 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 181 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 182 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 183 |
+
" do_swap = s1 < s2\n",
|
| 184 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 185 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 186 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 187 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 188 |
+
" # Write S\n",
|
| 189 |
+
" s_base = bid * 3\n",
|
| 190 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 191 |
+
" # Write Vh = V^T\n",
|
| 192 |
+
" vh_base = bid * 9\n",
|
| 193 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 194 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 195 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 196 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 197 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 198 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 199 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 200 |
+
" row_idx = block_start + offs\n",
|
| 201 |
+
" mask = row_idx < M\n",
|
| 202 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 203 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 204 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 205 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 206 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 207 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 208 |
+
" u_base = bid * M * 3\n",
|
| 209 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 210 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 211 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 212 |
+
"\n",
|
| 213 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 214 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 215 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 216 |
+
" B, M, _ = A.shape\n",
|
| 217 |
+
" A_f32 = A.contiguous().float()\n",
|
| 218 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 219 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 220 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 221 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 222 |
+
" return U, S, Vh\n",
|
| 223 |
+
"\n",
|
| 224 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 225 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 226 |
+
"except ImportError:\n",
|
| 227 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 228 |
+
"\n",
|
| 229 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 230 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 231 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 232 |
+
"\n",
|
| 233 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 234 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 235 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
spectral/notebooks/experiment_7_composite_pipelines.ipynb
CHANGED
|
@@ -39,7 +39,7 @@
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
-
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
@@ -92,6 +92,142 @@
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 96 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 97 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
|
|
|
| 39 |
"metadata": {},
|
| 40 |
"source": [
|
| 41 |
"# @title Install Dependencies\n",
|
| 42 |
+
"!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
|
| 43 |
"%load_ext tensorboard\n",
|
| 44 |
"import torch\n",
|
| 45 |
"print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
|
|
|
|
| 92 |
"if device.type == \"cuda\":\n",
|
| 93 |
" print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
|
| 94 |
"\n",
|
| 95 |
+
"# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
|
| 96 |
+
"# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
|
| 97 |
+
"# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
|
| 98 |
+
"# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
|
| 99 |
+
"_HAS_TRITON_SVD3 = False\n",
|
| 100 |
+
"try:\n",
|
| 101 |
+
" import triton\n",
|
| 102 |
+
" import triton.language as tl\n",
|
| 103 |
+
"\n",
|
| 104 |
+
" @triton.jit\n",
|
| 105 |
+
" def _svd3_kernel(\n",
|
| 106 |
+
" A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
|
| 107 |
+
" M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
|
| 108 |
+
" JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
|
| 109 |
+
" ):\n",
|
| 110 |
+
" bid = tl.program_id(0)\n",
|
| 111 |
+
" # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
|
| 112 |
+
" g00 = tl.zeros([], dtype=tl.float32)\n",
|
| 113 |
+
" g01 = tl.zeros([], dtype=tl.float32)\n",
|
| 114 |
+
" g02 = tl.zeros([], dtype=tl.float32)\n",
|
| 115 |
+
" g11 = tl.zeros([], dtype=tl.float32)\n",
|
| 116 |
+
" g12 = tl.zeros([], dtype=tl.float32)\n",
|
| 117 |
+
" g22 = tl.zeros([], dtype=tl.float32)\n",
|
| 118 |
+
" base = bid * M * 3\n",
|
| 119 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 120 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 121 |
+
" row_idx = block_start + offs\n",
|
| 122 |
+
" mask = row_idx < M\n",
|
| 123 |
+
" ptr0 = base + row_idx * 3 + 0\n",
|
| 124 |
+
" ptr1 = base + row_idx * 3 + 1\n",
|
| 125 |
+
" ptr2 = base + row_idx * 3 + 2\n",
|
| 126 |
+
" a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 127 |
+
" a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 128 |
+
" a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 129 |
+
" g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
|
| 130 |
+
" g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
|
| 131 |
+
" # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
|
| 132 |
+
" v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
|
| 133 |
+
" v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
|
| 134 |
+
" v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
|
| 135 |
+
" for _sweep in range(JACOBI_ITERS):\n",
|
| 136 |
+
" # pair (0,1)\n",
|
| 137 |
+
" off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
|
| 138 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 139 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 140 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 141 |
+
" ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
|
| 142 |
+
" ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
|
| 143 |
+
" g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
|
| 144 |
+
" nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
|
| 145 |
+
" nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
|
| 146 |
+
" v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
|
| 147 |
+
" # pair (0,2)\n",
|
| 148 |
+
" off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
|
| 149 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 150 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 151 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 152 |
+
" ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
|
| 153 |
+
" ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
|
| 154 |
+
" g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
|
| 155 |
+
" nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
|
| 156 |
+
" nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
|
| 157 |
+
" v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
|
| 158 |
+
" # pair (1,2)\n",
|
| 159 |
+
" off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
|
| 160 |
+
" tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
|
| 161 |
+
" t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
|
| 162 |
+
" c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
|
| 163 |
+
" ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
|
| 164 |
+
" ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
|
| 165 |
+
" g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
|
| 166 |
+
" nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
|
| 167 |
+
" nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
|
| 168 |
+
" v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
|
| 169 |
+
" # Sort eigenvalues descending + permute V columns\n",
|
| 170 |
+
" s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
|
| 171 |
+
" do_swap = s0 < s1\n",
|
| 172 |
+
" s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
|
| 173 |
+
" tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
|
| 174 |
+
" tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
|
| 175 |
+
" tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
|
| 176 |
+
" do_swap = s0 < s2\n",
|
| 177 |
+
" s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
|
| 178 |
+
" tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
|
| 179 |
+
" tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
|
| 180 |
+
" tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
|
| 181 |
+
" do_swap = s1 < s2\n",
|
| 182 |
+
" s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
|
| 183 |
+
" tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
|
| 184 |
+
" tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
|
| 185 |
+
" tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
|
| 186 |
+
" # Write S\n",
|
| 187 |
+
" s_base = bid * 3\n",
|
| 188 |
+
" tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
|
| 189 |
+
" # Write Vh = V^T\n",
|
| 190 |
+
" vh_base = bid * 9\n",
|
| 191 |
+
" tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
|
| 192 |
+
" tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
|
| 193 |
+
" tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
|
| 194 |
+
" # Stage 3: U = A @ V @ diag(1/S)\n",
|
| 195 |
+
" inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
|
| 196 |
+
" for block_start in range(0, M, BLOCK_M):\n",
|
| 197 |
+
" offs = tl.arange(0, BLOCK_M)\n",
|
| 198 |
+
" row_idx = block_start + offs\n",
|
| 199 |
+
" mask = row_idx < M\n",
|
| 200 |
+
" a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
|
| 201 |
+
" a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
|
| 202 |
+
" a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
|
| 203 |
+
" u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
|
| 204 |
+
" u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
|
| 205 |
+
" u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
|
| 206 |
+
" u_base = bid * M * 3\n",
|
| 207 |
+
" tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
|
| 208 |
+
" tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
|
| 209 |
+
" tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
|
| 210 |
+
"\n",
|
| 211 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 212 |
+
" \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
|
| 213 |
+
" assert A.ndim == 3 and A.shape[2] == 3\n",
|
| 214 |
+
" B, M, _ = A.shape\n",
|
| 215 |
+
" A_f32 = A.contiguous().float()\n",
|
| 216 |
+
" U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
|
| 217 |
+
" S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
|
| 218 |
+
" Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
|
| 219 |
+
" _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
|
| 220 |
+
" return U, S, Vh\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" _HAS_TRITON_SVD3 = True\n",
|
| 223 |
+
" print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
|
| 224 |
+
"except ImportError:\n",
|
| 225 |
+
" print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
|
| 226 |
+
"\n",
|
| 227 |
+
" def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
|
| 228 |
+
" \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
|
| 229 |
+
" return torch.linalg.svd(A.float(), full_matrices=False)\n",
|
| 230 |
+
"\n",
|
| 231 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|
| 232 |
"# GEOLIP CORE \u2014 Geometric Building Blocks\n",
|
| 233 |
"# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
|