AbstractPhil commited on
Commit
a9685fb
·
verified ·
1 Parent(s): 9a74626

Upload 6 files

Browse files
spectral/notebooks/experiment_2_manifold_structures.ipynb CHANGED
@@ -39,7 +39,7 @@
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -92,6 +92,142 @@
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
96
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
@@ -1766,12 +1902,10 @@
1766
  "source": [
1767
  "# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
1768
  "class GrassmannianFrontEnd(nn.Module):\n",
1769
- " \"\"\"Grassmannian subspace features via direct SVD.\n",
1770
- " X = U S Vh. Features: singular values (spectral profile),\n",
1771
- " log singular value ratios (relative spectrum), and right singular\n",
1772
- " vectors V (subspace orientation in channel space \u2014 the actual\n",
1773
- " Grassmannian coordinate). V varies per patch and encodes which\n",
1774
- " linear combinations of RGB correspond to principal directions.\"\"\"\n",
1775
  " def __init__(self, patch_size=8, k=3, input_size=32):\n",
1776
  " super().__init__()\n",
1777
  " self.patch_size = patch_size\n",
@@ -1782,7 +1916,8 @@
1782
  " # k singular values + k log-ratios + k*C right singular vector entries\n",
1783
  " self.features_per_patch = k + k + k * self.C\n",
1784
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1785
- " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} (direct SVD)\")\n",
 
1786
  "\n",
1787
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1788
  " def forward(self, x):\n",
@@ -1792,8 +1927,8 @@
1792
  " n_p = patches.shape[2] * patches.shape[3]\n",
1793
  " # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
1794
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1795
- " # Direct thin SVD: X = U S Vh, U:(N,64,3) S:(N,3) Vh:(N,3,3)\n",
1796
- " U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
1797
  " S = S[:, :self.k]\n",
1798
  " # Log singular value ratios (scale-invariant spectrum)\n",
1799
  " sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
@@ -1834,7 +1969,8 @@
1834
  "source": [
1835
  "# @title Experiment 2.3 \u2014 Flag Manifold\n",
1836
  "class FlagManifoldFrontEnd(nn.Module):\n",
1837
- " \"\"\"Cascading SVD at multiple truncation levels via direct SVD.\n",
 
1838
  " Nested subspace features: singular values + projection norms at each flag level.\n",
1839
  " The flag structure captures how information distributes across\n",
1840
  " nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
@@ -1847,7 +1983,8 @@
1847
  " max_sv = min(3, patch_size * patch_size)\n",
1848
  " self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
1849
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1850
- " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} (direct SVD)\")\n",
 
1851
  "\n",
1852
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1853
  " def forward(self, x):\n",
@@ -1857,8 +1994,8 @@
1857
  " n_p = patches.shape[2] * patches.shape[3]\n",
1858
  " # X: (B*n_p, ps*ps, C)\n",
1859
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1860
- " # Direct thin SVD\n",
1861
- " U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
1862
  " # Features at each flag level\n",
1863
  " feats = []\n",
1864
  " for k in self.levels:\n",
 
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
95
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
96
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
97
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
98
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
99
+ "_HAS_TRITON_SVD3 = False\n",
100
+ "try:\n",
101
+ " import triton\n",
102
+ " import triton.language as tl\n",
103
+ "\n",
104
+ " @triton.jit\n",
105
+ " def _svd3_kernel(\n",
106
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
107
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
108
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
109
+ " ):\n",
110
+ " bid = tl.program_id(0)\n",
111
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
112
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
113
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
114
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
117
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
118
+ " base = bid * M * 3\n",
119
+ " for block_start in range(0, M, BLOCK_M):\n",
120
+ " offs = tl.arange(0, BLOCK_M)\n",
121
+ " row_idx = block_start + offs\n",
122
+ " mask = row_idx < M\n",
123
+ " ptr0 = base + row_idx * 3 + 0\n",
124
+ " ptr1 = base + row_idx * 3 + 1\n",
125
+ " ptr2 = base + row_idx * 3 + 2\n",
126
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
127
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
128
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
129
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
130
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
131
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
132
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
133
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
134
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
135
+ " for _sweep in range(JACOBI_ITERS):\n",
136
+ " # pair (0,1)\n",
137
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
138
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
139
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
140
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
141
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
142
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
143
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
144
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
145
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
146
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
147
+ " # pair (0,2)\n",
148
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
149
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
150
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
151
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
152
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
153
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
154
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
155
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
156
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
157
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
158
+ " # pair (1,2)\n",
159
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
160
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
161
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
162
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
163
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
164
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
165
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
166
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
167
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
168
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
169
+ " # Sort eigenvalues descending + permute V columns\n",
170
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
171
+ " do_swap = s0 < s1\n",
172
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
173
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
174
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
175
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
176
+ " do_swap = s0 < s2\n",
177
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
178
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
179
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
180
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
181
+ " do_swap = s1 < s2\n",
182
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
183
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
184
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
185
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
186
+ " # Write S\n",
187
+ " s_base = bid * 3\n",
188
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
189
+ " # Write Vh = V^T\n",
190
+ " vh_base = bid * 9\n",
191
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
192
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
193
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
194
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
195
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
196
+ " for block_start in range(0, M, BLOCK_M):\n",
197
+ " offs = tl.arange(0, BLOCK_M)\n",
198
+ " row_idx = block_start + offs\n",
199
+ " mask = row_idx < M\n",
200
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
201
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
202
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
203
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
204
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
205
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
206
+ " u_base = bid * M * 3\n",
207
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
208
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
209
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
210
+ "\n",
211
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
212
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
213
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
214
+ " B, M, _ = A.shape\n",
215
+ " A_f32 = A.contiguous().float()\n",
216
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
217
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
218
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
219
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
220
+ " return U, S, Vh\n",
221
+ "\n",
222
+ " _HAS_TRITON_SVD3 = True\n",
223
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
224
+ "except ImportError:\n",
225
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
226
+ "\n",
227
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
228
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
229
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
230
+ "\n",
231
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
232
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
233
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
1902
  "source": [
1903
  "# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
1904
  "class GrassmannianFrontEnd(nn.Module):\n",
1905
+ " \"\"\"Grassmannian subspace features via SVD.\n",
1906
+ " Uses fused Triton SVD3 kernel (3\u00d73 Jacobi in registers) when available,\n",
1907
+ " falls back to torch.linalg.svd otherwise. Same mathematical decomposition.\n",
1908
+ " Features: singular values, log ratios, right singular vectors V.\"\"\"\n",
 
 
1909
  " def __init__(self, patch_size=8, k=3, input_size=32):\n",
1910
  " super().__init__()\n",
1911
  " self.patch_size = patch_size\n",
 
1916
  " # k singular values + k log-ratios + k*C right singular vector entries\n",
1917
  " self.features_per_patch = k + k + k * self.C\n",
1918
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1919
+ " backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
1920
+ " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} ({backend})\")\n",
1921
  "\n",
1922
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1923
  " def forward(self, x):\n",
 
1927
  " n_p = patches.shape[2] * patches.shape[3]\n",
1928
  " # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
1929
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1930
+ " # SVD via fused Triton kernel (or torch fallback)\n",
1931
+ " U, S, Vh = batched_svd3(X)\n",
1932
  " S = S[:, :self.k]\n",
1933
  " # Log singular value ratios (scale-invariant spectrum)\n",
1934
  " sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
 
1969
  "source": [
1970
  "# @title Experiment 2.3 \u2014 Flag Manifold\n",
1971
  "class FlagManifoldFrontEnd(nn.Module):\n",
1972
+ " \"\"\"Cascading SVD at multiple truncation levels.\n",
1973
+ " Uses fused Triton SVD3 kernel when available.\n",
1974
  " Nested subspace features: singular values + projection norms at each flag level.\n",
1975
  " The flag structure captures how information distributes across\n",
1976
  " nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
 
1983
  " max_sv = min(3, patch_size * patch_size)\n",
1984
  " self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
1985
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1986
+ " backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
1987
+ " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} ({backend})\")\n",
1988
  "\n",
1989
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1990
  " def forward(self, x):\n",
 
1994
  " n_p = patches.shape[2] * patches.shape[3]\n",
1995
  " # X: (B*n_p, ps*ps, C)\n",
1996
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1997
+ " # SVD via fused Triton kernel (or torch fallback)\n",
1998
+ " U, S, Vh = batched_svd3(X)\n",
1999
  " # Features at each flag level\n",
2000
  " feats = []\n",
2001
  " for k in self.levels:\n",
spectral/notebooks/experiment_3_compact_representations.ipynb CHANGED
@@ -38,7 +38,7 @@
38
  "metadata": {},
39
  "source": [
40
  "# @title Install Dependencies\n",
41
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
42
  "%load_ext tensorboard\n",
43
  "import torch\n",
44
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -91,6 +91,142 @@
91
  "if device.type == \"cuda\":\n",
92
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
93
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
95
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
96
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
38
  "metadata": {},
39
  "source": [
40
  "# @title Install Dependencies\n",
41
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
42
  "%load_ext tensorboard\n",
43
  "import torch\n",
44
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
91
  "if device.type == \"cuda\":\n",
92
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
93
  "\n",
94
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
95
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
96
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
97
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
98
+ "_HAS_TRITON_SVD3 = False\n",
99
+ "try:\n",
100
+ " import triton\n",
101
+ " import triton.language as tl\n",
102
+ "\n",
103
+ " @triton.jit\n",
104
+ " def _svd3_kernel(\n",
105
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
106
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
107
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
108
+ " ):\n",
109
+ " bid = tl.program_id(0)\n",
110
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
111
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
112
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
113
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
114
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
117
+ " base = bid * M * 3\n",
118
+ " for block_start in range(0, M, BLOCK_M):\n",
119
+ " offs = tl.arange(0, BLOCK_M)\n",
120
+ " row_idx = block_start + offs\n",
121
+ " mask = row_idx < M\n",
122
+ " ptr0 = base + row_idx * 3 + 0\n",
123
+ " ptr1 = base + row_idx * 3 + 1\n",
124
+ " ptr2 = base + row_idx * 3 + 2\n",
125
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
126
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
127
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
128
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
129
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
130
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
131
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
132
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
133
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
134
+ " for _sweep in range(JACOBI_ITERS):\n",
135
+ " # pair (0,1)\n",
136
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
137
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
138
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
139
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
140
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
141
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
142
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
143
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
144
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
145
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
146
+ " # pair (0,2)\n",
147
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
148
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
149
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
150
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
151
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
152
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
153
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
154
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
155
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
156
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
157
+ " # pair (1,2)\n",
158
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
159
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
160
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
161
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
162
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
163
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
164
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
165
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
166
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
167
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
168
+ " # Sort eigenvalues descending + permute V columns\n",
169
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
170
+ " do_swap = s0 < s1\n",
171
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
172
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
173
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
174
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
175
+ " do_swap = s0 < s2\n",
176
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
177
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
178
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
179
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
180
+ " do_swap = s1 < s2\n",
181
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
182
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
183
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
184
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
185
+ " # Write S\n",
186
+ " s_base = bid * 3\n",
187
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
188
+ " # Write Vh = V^T\n",
189
+ " vh_base = bid * 9\n",
190
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
191
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
192
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
193
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
194
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
195
+ " for block_start in range(0, M, BLOCK_M):\n",
196
+ " offs = tl.arange(0, BLOCK_M)\n",
197
+ " row_idx = block_start + offs\n",
198
+ " mask = row_idx < M\n",
199
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
200
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
201
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
202
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
203
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
204
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
205
+ " u_base = bid * M * 3\n",
206
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
207
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
208
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
209
+ "\n",
210
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
211
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
212
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
213
+ " B, M, _ = A.shape\n",
214
+ " A_f32 = A.contiguous().float()\n",
215
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
216
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
217
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
218
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
219
+ " return U, S, Vh\n",
220
+ "\n",
221
+ " _HAS_TRITON_SVD3 = True\n",
222
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
223
+ "except ImportError:\n",
224
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
225
+ "\n",
226
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
227
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
228
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
229
+ "\n",
230
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
231
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
232
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
spectral/notebooks/experiment_4_invertible_transforms.ipynb CHANGED
@@ -39,7 +39,7 @@
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -92,6 +92,142 @@
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
96
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
95
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
96
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
97
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
98
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
99
+ "_HAS_TRITON_SVD3 = False\n",
100
+ "try:\n",
101
+ " import triton\n",
102
+ " import triton.language as tl\n",
103
+ "\n",
104
+ " @triton.jit\n",
105
+ " def _svd3_kernel(\n",
106
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
107
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
108
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
109
+ " ):\n",
110
+ " bid = tl.program_id(0)\n",
111
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
112
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
113
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
114
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
117
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
118
+ " base = bid * M * 3\n",
119
+ " for block_start in range(0, M, BLOCK_M):\n",
120
+ " offs = tl.arange(0, BLOCK_M)\n",
121
+ " row_idx = block_start + offs\n",
122
+ " mask = row_idx < M\n",
123
+ " ptr0 = base + row_idx * 3 + 0\n",
124
+ " ptr1 = base + row_idx * 3 + 1\n",
125
+ " ptr2 = base + row_idx * 3 + 2\n",
126
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
127
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
128
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
129
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
130
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
131
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
132
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
133
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
134
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
135
+ " for _sweep in range(JACOBI_ITERS):\n",
136
+ " # pair (0,1)\n",
137
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
138
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
139
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
140
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
141
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
142
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
143
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
144
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
145
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
146
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
147
+ " # pair (0,2)\n",
148
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
149
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
150
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
151
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
152
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
153
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
154
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
155
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
156
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
157
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
158
+ " # pair (1,2)\n",
159
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
160
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
161
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
162
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
163
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
164
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
165
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
166
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
167
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
168
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
169
+ " # Sort eigenvalues descending + permute V columns\n",
170
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
171
+ " do_swap = s0 < s1\n",
172
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
173
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
174
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
175
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
176
+ " do_swap = s0 < s2\n",
177
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
178
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
179
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
180
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
181
+ " do_swap = s1 < s2\n",
182
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
183
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
184
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
185
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
186
+ " # Write S\n",
187
+ " s_base = bid * 3\n",
188
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
189
+ " # Write Vh = V^T\n",
190
+ " vh_base = bid * 9\n",
191
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
192
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
193
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
194
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
195
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
196
+ " for block_start in range(0, M, BLOCK_M):\n",
197
+ " offs = tl.arange(0, BLOCK_M)\n",
198
+ " row_idx = block_start + offs\n",
199
+ " mask = row_idx < M\n",
200
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
201
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
202
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
203
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
204
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
205
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
206
+ " u_base = bid * M * 3\n",
207
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
208
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
209
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
210
+ "\n",
211
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
212
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
213
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
214
+ " B, M, _ = A.shape\n",
215
+ " A_f32 = A.contiguous().float()\n",
216
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
217
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
218
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
219
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
220
+ " return U, S, Vh\n",
221
+ "\n",
222
+ " _HAS_TRITON_SVD3 = True\n",
223
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
224
+ "except ImportError:\n",
225
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
226
+ "\n",
227
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
228
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
229
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
230
+ "\n",
231
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
232
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
233
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
spectral/notebooks/experiment_5_matrix_decompositions.ipynb CHANGED
@@ -39,7 +39,7 @@
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -92,6 +92,142 @@
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
96
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
95
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
96
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
97
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
98
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
99
+ "_HAS_TRITON_SVD3 = False\n",
100
+ "try:\n",
101
+ " import triton\n",
102
+ " import triton.language as tl\n",
103
+ "\n",
104
+ " @triton.jit\n",
105
+ " def _svd3_kernel(\n",
106
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
107
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
108
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
109
+ " ):\n",
110
+ " bid = tl.program_id(0)\n",
111
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
112
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
113
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
114
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
117
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
118
+ " base = bid * M * 3\n",
119
+ " for block_start in range(0, M, BLOCK_M):\n",
120
+ " offs = tl.arange(0, BLOCK_M)\n",
121
+ " row_idx = block_start + offs\n",
122
+ " mask = row_idx < M\n",
123
+ " ptr0 = base + row_idx * 3 + 0\n",
124
+ " ptr1 = base + row_idx * 3 + 1\n",
125
+ " ptr2 = base + row_idx * 3 + 2\n",
126
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
127
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
128
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
129
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
130
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
131
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
132
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
133
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
134
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
135
+ " for _sweep in range(JACOBI_ITERS):\n",
136
+ " # pair (0,1)\n",
137
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
138
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
139
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
140
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
141
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
142
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
143
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
144
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
145
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
146
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
147
+ " # pair (0,2)\n",
148
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
149
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
150
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
151
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
152
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
153
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
154
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
155
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
156
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
157
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
158
+ " # pair (1,2)\n",
159
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
160
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
161
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
162
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
163
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
164
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
165
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
166
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
167
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
168
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
169
+ " # Sort eigenvalues descending + permute V columns\n",
170
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
171
+ " do_swap = s0 < s1\n",
172
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
173
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
174
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
175
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
176
+ " do_swap = s0 < s2\n",
177
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
178
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
179
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
180
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
181
+ " do_swap = s1 < s2\n",
182
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
183
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
184
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
185
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
186
+ " # Write S\n",
187
+ " s_base = bid * 3\n",
188
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
189
+ " # Write Vh = V^T\n",
190
+ " vh_base = bid * 9\n",
191
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
192
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
193
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
194
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
195
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
196
+ " for block_start in range(0, M, BLOCK_M):\n",
197
+ " offs = tl.arange(0, BLOCK_M)\n",
198
+ " row_idx = block_start + offs\n",
199
+ " mask = row_idx < M\n",
200
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
201
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
202
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
203
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
204
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
205
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
206
+ " u_base = bid * M * 3\n",
207
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
208
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
209
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
210
+ "\n",
211
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
212
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
213
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
214
+ " B, M, _ = A.shape\n",
215
+ " A_f32 = A.contiguous().float()\n",
216
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
217
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
218
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
219
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
220
+ " return U, S, Vh\n",
221
+ "\n",
222
+ " _HAS_TRITON_SVD3 = True\n",
223
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
224
+ "except ImportError:\n",
225
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
226
+ "\n",
227
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
228
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
229
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
230
+ "\n",
231
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
232
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
233
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
spectral/notebooks/experiment_6_losses_and_anchors.ipynb CHANGED
@@ -41,7 +41,7 @@
41
  "metadata": {},
42
  "source": [
43
  "# @title Install Dependencies\n",
44
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
45
  "%load_ext tensorboard\n",
46
  "import torch\n",
47
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -94,6 +94,142 @@
94
  "if device.type == \"cuda\":\n",
95
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
96
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
98
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
99
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
41
  "metadata": {},
42
  "source": [
43
  "# @title Install Dependencies\n",
44
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
45
  "%load_ext tensorboard\n",
46
  "import torch\n",
47
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
94
  "if device.type == \"cuda\":\n",
95
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
96
  "\n",
97
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
98
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
99
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
100
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
101
+ "_HAS_TRITON_SVD3 = False\n",
102
+ "try:\n",
103
+ " import triton\n",
104
+ " import triton.language as tl\n",
105
+ "\n",
106
+ " @triton.jit\n",
107
+ " def _svd3_kernel(\n",
108
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
109
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
110
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
111
+ " ):\n",
112
+ " bid = tl.program_id(0)\n",
113
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
114
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
117
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
118
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
119
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
120
+ " base = bid * M * 3\n",
121
+ " for block_start in range(0, M, BLOCK_M):\n",
122
+ " offs = tl.arange(0, BLOCK_M)\n",
123
+ " row_idx = block_start + offs\n",
124
+ " mask = row_idx < M\n",
125
+ " ptr0 = base + row_idx * 3 + 0\n",
126
+ " ptr1 = base + row_idx * 3 + 1\n",
127
+ " ptr2 = base + row_idx * 3 + 2\n",
128
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
129
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
130
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
131
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
132
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
133
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
134
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
135
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
136
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
137
+ " for _sweep in range(JACOBI_ITERS):\n",
138
+ " # pair (0,1)\n",
139
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
140
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
141
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
142
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
143
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
144
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
145
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
146
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
147
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
148
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
149
+ " # pair (0,2)\n",
150
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
151
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
152
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
153
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
154
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
155
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
156
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
157
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
158
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
159
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
160
+ " # pair (1,2)\n",
161
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
162
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
163
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
164
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
165
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
166
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
167
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
168
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
169
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
170
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
171
+ " # Sort eigenvalues descending + permute V columns\n",
172
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
173
+ " do_swap = s0 < s1\n",
174
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
175
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
176
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
177
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
178
+ " do_swap = s0 < s2\n",
179
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
180
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
181
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
182
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
183
+ " do_swap = s1 < s2\n",
184
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
185
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
186
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
187
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
188
+ " # Write S\n",
189
+ " s_base = bid * 3\n",
190
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
191
+ " # Write Vh = V^T\n",
192
+ " vh_base = bid * 9\n",
193
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
194
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
195
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
196
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
197
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
198
+ " for block_start in range(0, M, BLOCK_M):\n",
199
+ " offs = tl.arange(0, BLOCK_M)\n",
200
+ " row_idx = block_start + offs\n",
201
+ " mask = row_idx < M\n",
202
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
203
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
204
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
205
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
206
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
207
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
208
+ " u_base = bid * M * 3\n",
209
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
210
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
211
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
212
+ "\n",
213
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
214
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
215
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
216
+ " B, M, _ = A.shape\n",
217
+ " A_f32 = A.contiguous().float()\n",
218
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
219
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
220
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
221
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
222
+ " return U, S, Vh\n",
223
+ "\n",
224
+ " _HAS_TRITON_SVD3 = True\n",
225
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
226
+ "except ImportError:\n",
227
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
228
+ "\n",
229
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
230
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
231
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
232
+ "\n",
233
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
234
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
235
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
spectral/notebooks/experiment_7_composite_pipelines.ipynb CHANGED
@@ -39,7 +39,7 @@
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
- "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
@@ -92,6 +92,142 @@
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
96
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
97
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
 
39
  "metadata": {},
40
  "source": [
41
  "# @title Install Dependencies\n",
42
+ "!pip install -q kymatio torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n",
43
  "%load_ext tensorboard\n",
44
  "import torch\n",
45
  "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n",
 
92
  "if device.type == \"cuda\":\n",
93
  " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n",
94
  "\n",
95
+ "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n",
96
+ "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n",
97
+ "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n",
98
+ "# and U recovery into a single kernel launch. ~10,000x faster than cuSOLVER.\n",
99
+ "_HAS_TRITON_SVD3 = False\n",
100
+ "try:\n",
101
+ " import triton\n",
102
+ " import triton.language as tl\n",
103
+ "\n",
104
+ " @triton.jit\n",
105
+ " def _svd3_kernel(\n",
106
+ " A_ptr, U_ptr, S_ptr, Vh_ptr,\n",
107
+ " M: tl.constexpr, BLOCK_M: tl.constexpr,\n",
108
+ " JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,\n",
109
+ " ):\n",
110
+ " bid = tl.program_id(0)\n",
111
+ " # Stage 1: G = A^T A (6 accumulators, symmetric)\n",
112
+ " g00 = tl.zeros([], dtype=tl.float32)\n",
113
+ " g01 = tl.zeros([], dtype=tl.float32)\n",
114
+ " g02 = tl.zeros([], dtype=tl.float32)\n",
115
+ " g11 = tl.zeros([], dtype=tl.float32)\n",
116
+ " g12 = tl.zeros([], dtype=tl.float32)\n",
117
+ " g22 = tl.zeros([], dtype=tl.float32)\n",
118
+ " base = bid * M * 3\n",
119
+ " for block_start in range(0, M, BLOCK_M):\n",
120
+ " offs = tl.arange(0, BLOCK_M)\n",
121
+ " row_idx = block_start + offs\n",
122
+ " mask = row_idx < M\n",
123
+ " ptr0 = base + row_idx * 3 + 0\n",
124
+ " ptr1 = base + row_idx * 3 + 1\n",
125
+ " ptr2 = base + row_idx * 3 + 2\n",
126
+ " a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)\n",
127
+ " a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)\n",
128
+ " a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)\n",
129
+ " g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)\n",
130
+ " g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)\n",
131
+ " # Stage 2: 3\u00d73 Jacobi eigensolver (all in scalar registers)\n",
132
+ " v00 = 1.0; v01 = 0.0; v02 = 0.0\n",
133
+ " v10 = 0.0; v11 = 1.0; v12 = 0.0\n",
134
+ " v20 = 0.0; v21 = 0.0; v22 = 1.0\n",
135
+ " for _sweep in range(JACOBI_ITERS):\n",
136
+ " # pair (0,1)\n",
137
+ " off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)\n",
138
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
139
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
140
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
141
+ " ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11\n",
142
+ " ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12\n",
143
+ " g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12\n",
144
+ " nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11\n",
145
+ " nv20 = c*v20-s*v21; nv21 = s*v20+c*v21\n",
146
+ " v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21\n",
147
+ " # pair (0,2)\n",
148
+ " off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)\n",
149
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
150
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
151
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
152
+ " ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22\n",
153
+ " ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12\n",
154
+ " g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b\n",
155
+ " nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12\n",
156
+ " nv20 = c*v20-s*v22; nv22 = s*v20+c*v22\n",
157
+ " v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22\n",
158
+ " # pair (1,2)\n",
159
+ " off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)\n",
160
+ " tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)\n",
161
+ " t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)\n",
162
+ " c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c\n",
163
+ " ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22\n",
164
+ " ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02\n",
165
+ " g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b\n",
166
+ " nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12\n",
167
+ " nv21 = c*v21-s*v22; nv22 = s*v21+c*v22\n",
168
+ " v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22\n",
169
+ " # Sort eigenvalues descending + permute V columns\n",
170
+ " s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))\n",
171
+ " do_swap = s0 < s1\n",
172
+ " s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)\n",
173
+ " tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)\n",
174
+ " tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)\n",
175
+ " tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)\n",
176
+ " do_swap = s0 < s2\n",
177
+ " s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)\n",
178
+ " tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)\n",
179
+ " tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)\n",
180
+ " tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)\n",
181
+ " do_swap = s1 < s2\n",
182
+ " s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)\n",
183
+ " tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)\n",
184
+ " tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)\n",
185
+ " tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)\n",
186
+ " # Write S\n",
187
+ " s_base = bid * 3\n",
188
+ " tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)\n",
189
+ " # Write Vh = V^T\n",
190
+ " vh_base = bid * 9\n",
191
+ " tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)\n",
192
+ " tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)\n",
193
+ " tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)\n",
194
+ " # Stage 3: U = A @ V @ diag(1/S)\n",
195
+ " inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)\n",
196
+ " for block_start in range(0, M, BLOCK_M):\n",
197
+ " offs = tl.arange(0, BLOCK_M)\n",
198
+ " row_idx = block_start + offs\n",
199
+ " mask = row_idx < M\n",
200
+ " a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)\n",
201
+ " a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)\n",
202
+ " a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)\n",
203
+ " u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0\n",
204
+ " u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1\n",
205
+ " u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2\n",
206
+ " u_base = bid * M * 3\n",
207
+ " tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)\n",
208
+ " tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)\n",
209
+ " tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)\n",
210
+ "\n",
211
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
212
+ " \"\"\"Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.\"\"\"\n",
213
+ " assert A.ndim == 3 and A.shape[2] == 3\n",
214
+ " B, M, _ = A.shape\n",
215
+ " A_f32 = A.contiguous().float()\n",
216
+ " U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)\n",
217
+ " S = torch.empty((B, 3), dtype=torch.float32, device=A.device)\n",
218
+ " Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)\n",
219
+ " _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)\n",
220
+ " return U, S, Vh\n",
221
+ "\n",
222
+ " _HAS_TRITON_SVD3 = True\n",
223
+ " print(\"[PERF] Triton SVD3 kernel loaded \u2014 fused 3\u00d73 Jacobi eigensolver\")\n",
224
+ "except ImportError:\n",
225
+ " print(\"[PERF] Triton not available \u2014 falling back to torch.linalg.svd\")\n",
226
+ "\n",
227
+ " def batched_svd3(A, block_m=128, jacobi_iters=6):\n",
228
+ " \"\"\"Fallback: torch.linalg.svd for (B, M, 3) tensors.\"\"\"\n",
229
+ " return torch.linalg.svd(A.float(), full_matrices=False)\n",
230
+ "\n",
231
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
232
  "# GEOLIP CORE \u2014 Geometric Building Blocks\n",
233
  "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",