AbstractPhil committed on
Commit
9a74626
·
verified ·
1 Parent(s): 9414786

Upload 6 files

Browse files
spectral/notebooks/experiment_2_manifold_structures.ipynb CHANGED
@@ -1766,47 +1766,41 @@
1766
  "source": [
1767
  "# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
1768
  "class GrassmannianFrontEnd(nn.Module):\n",
1769
- " \"\"\"Grassmannian subspace features via SVD computed through 3\u00d73 eigh path.\n",
1770
- " X = U S V^T. Since X^T X = V S\u00b2 V^T, we get S,V from eigh(X^T X),\n",
1771
- " then U = X V S\u207b\u00b9. Full SVD result, but via 3\u00d73 eigendecomposition\n",
1772
- " instead of iterative cusolver on (64,3). Instant on GPU.\"\"\"\n",
 
 
1773
  " def __init__(self, patch_size=8, k=3, input_size=32):\n",
1774
  " super().__init__()\n",
1775
  " self.patch_size = patch_size\n",
1776
  " self.k = k\n",
 
1777
  " n_patches_h = input_size // patch_size\n",
1778
  " self.n_patches = n_patches_h * n_patches_h\n",
1779
- " # k singular values + k log-ratios + k*(k+1)/2 gram upper-tri of U\n",
1780
- " self.features_per_patch = k + k + k * (k + 1) // 2\n",
1781
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1782
- " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} (SVD via 3\u00d73 eigh)\")\n",
1783
  "\n",
 
1784
  " def forward(self, x):\n",
1785
  " B, C, H, W = x.shape\n",
1786
  " ps = self.patch_size\n",
1787
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1788
  " n_p = patches.shape[2] * patches.shape[3]\n",
1789
  " # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
1790
- " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1).float()\n",
1791
- " # X^T X: (B*n_p, C, C) = (B*n_p, 3, 3)\n",
1792
- " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1793
- " # eigh on 3\u00d73 symmetric PSD \u2192 eigenvalues \u03bb (ascending), eigenvectors V\n",
1794
- " eigvals, V = torch.linalg.eigh(XtX) # (B*n_p, C), (B*n_p, C, C)\n",
1795
- " # Flip to descending order\n",
1796
- " eigvals = eigvals.flip(-1)\n",
1797
- " V = V.flip(-1)\n",
1798
- " # S = sqrt(\u03bb), these ARE the singular values\n",
1799
- " S = torch.sqrt(eigvals.clamp(min=1e-10))[:, :self.k]\n",
1800
- " # U = X V S\u207b\u00b9 (recover left singular vectors)\n",
1801
- " Vk = V[:, :, :self.k] # (B*n_p, C, k)\n",
1802
- " XV = torch.bmm(X, Vk) # (B*n_p, ps*ps, k)\n",
1803
- " U = XV / (S.unsqueeze(1) + 1e-10) # (B*n_p, ps*ps, k)\n",
1804
- " # Features: singular values + log-ratios + gram of U\n",
1805
  " sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
1806
- " gram = torch.bmm(U.transpose(1, 2), U) # (B*n_p, k, k)\n",
1807
- " triu_idx = torch.triu_indices(self.k, self.k)\n",
1808
- " gram_feat = gram[:, triu_idx[0], triu_idx[1]]\n",
1809
- " feats = torch.cat([S, sv_ratios, gram_feat], dim=-1)\n",
1810
  " return feats.reshape(B, -1)\n",
1811
  "\n",
1812
  "front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
@@ -1840,9 +1834,10 @@
1840
  "source": [
1841
  "# @title Experiment 2.3 \u2014 Flag Manifold\n",
1842
  "class FlagManifoldFrontEnd(nn.Module):\n",
1843
- " \"\"\"Cascading SVD at multiple truncation levels via 3\u00d73 eigh path.\n",
1844
  " Nested subspace features: singular values + projection norms at each flag level.\n",
1845
- " SVD computed through eigh(X^T X) \u2014 no iterative cusolver.\"\"\"\n",
 
1846
  " def __init__(self, patch_size=8, levels=(1, 2, 3), input_size=32):\n",
1847
  " super().__init__()\n",
1848
  " self.patch_size = patch_size\n",
@@ -1852,23 +1847,18 @@
1852
  " max_sv = min(3, patch_size * patch_size)\n",
1853
  " self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
1854
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1855
- " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} (SVD via 3\u00d73 eigh)\")\n",
1856
  "\n",
 
1857
  " def forward(self, x):\n",
1858
  " B, C, H, W = x.shape\n",
1859
  " ps = self.patch_size\n",
1860
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1861
  " n_p = patches.shape[2] * patches.shape[3]\n",
1862
  " # X: (B*n_p, ps*ps, C)\n",
1863
- " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1).float()\n",
1864
- " # X^T X: (B*n_p, 3, 3)\n",
1865
- " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1866
- " eigvals, V = torch.linalg.eigh(XtX)\n",
1867
- " eigvals = eigvals.flip(-1)\n",
1868
- " V = V.flip(-1)\n",
1869
- " S = torch.sqrt(eigvals.clamp(min=1e-10)) # (B*n_p, C)\n",
1870
- " # U = X V S\u207b\u00b9\n",
1871
- " U = torch.bmm(X, V) / (S.unsqueeze(1) + 1e-10) # (B*n_p, ps*ps, C)\n",
1872
  " # Features at each flag level\n",
1873
  " feats = []\n",
1874
  " for k in self.levels:\n",
 
1766
  "source": [
1767
  "# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
1768
  "class GrassmannianFrontEnd(nn.Module):\n",
1769
+ " \"\"\"Grassmannian subspace features via direct SVD.\n",
1770
+ " X = U S Vh. Features: singular values (spectral profile),\n",
1771
+ " log singular value ratios (relative spectrum), and right singular\n",
1772
+ " vectors V (subspace orientation in channel space \u2014 the actual\n",
1773
+ " Grassmannian coordinate). V varies per patch and encodes which\n",
1774
+ " linear combinations of RGB correspond to principal directions.\"\"\"\n",
1775
  " def __init__(self, patch_size=8, k=3, input_size=32):\n",
1776
  " super().__init__()\n",
1777
  " self.patch_size = patch_size\n",
1778
  " self.k = k\n",
1779
+ " self.C = 3 # RGB\n",
1780
  " n_patches_h = input_size // patch_size\n",
1781
  " self.n_patches = n_patches_h * n_patches_h\n",
1782
+ " # k singular values + k log-ratios + k*C right singular vector entries\n",
1783
+ " self.features_per_patch = k + k + k * self.C\n",
1784
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1785
+ " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} (direct SVD)\")\n",
1786
  "\n",
1787
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1788
  " def forward(self, x):\n",
1789
  " B, C, H, W = x.shape\n",
1790
  " ps = self.patch_size\n",
1791
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1792
  " n_p = patches.shape[2] * patches.shape[3]\n",
1793
  " # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
1794
+ " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1795
+ " # Direct thin SVD: X = U S Vh, U:(N,64,3) S:(N,3) Vh:(N,3,3)\n",
1796
+ " U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
1797
+ " S = S[:, :self.k]\n",
1798
+ " # Log singular value ratios (scale-invariant spectrum)\n",
 
 
 
 
 
 
 
 
 
 
1799
  " sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
1800
+ " # Right singular vectors Vh[:k]: subspace orientation in channel space\n",
1801
+ " # This IS the Grassmannian coordinate \u2014 varies meaningfully per patch\n",
1802
+ " V_feat = Vh[:, :self.k, :].reshape(-1, self.k * C)\n",
1803
+ " feats = torch.cat([S, sv_ratios, V_feat], dim=-1)\n",
1804
  " return feats.reshape(B, -1)\n",
1805
  "\n",
1806
  "front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
 
1834
  "source": [
1835
  "# @title Experiment 2.3 \u2014 Flag Manifold\n",
1836
  "class FlagManifoldFrontEnd(nn.Module):\n",
1837
+ " \"\"\"Cascading SVD at multiple truncation levels via direct SVD.\n",
1838
  " Nested subspace features: singular values + projection norms at each flag level.\n",
1839
+ " The flag structure captures how information distributes across\n",
1840
+ " nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
1841
  " def __init__(self, patch_size=8, levels=(1, 2, 3), input_size=32):\n",
1842
  " super().__init__()\n",
1843
  " self.patch_size = patch_size\n",
 
1847
  " max_sv = min(3, patch_size * patch_size)\n",
1848
  " self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
1849
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1850
+ " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} (direct SVD)\")\n",
1851
  "\n",
1852
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1853
  " def forward(self, x):\n",
1854
  " B, C, H, W = x.shape\n",
1855
  " ps = self.patch_size\n",
1856
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1857
  " n_p = patches.shape[2] * patches.shape[3]\n",
1858
  " # X: (B*n_p, ps*ps, C)\n",
1859
+ " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1860
+ " # Direct thin SVD\n",
1861
+ " U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
 
 
 
 
 
 
1862
  " # Features at each flag level\n",
1863
  " feats = []\n",
1864
  " for k in self.levels:\n",
spectral/notebooks/experiment_4_invertible_transforms.ipynb CHANGED
@@ -1977,12 +1977,13 @@
1977
  " self.output_dim = n_templates * (2 + min(3, patch_dim))\n",
1978
  " print(f\"[PROCRUSTES] {n_templates} templates, dim={self.output_dim}\")\n",
1979
  "\n",
 
1980
  " def forward(self, x):\n",
1981
  " B, C, H, W = x.shape\n",
1982
  " ps = self.patch_size\n",
1983
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1984
  " patches = patches.contiguous().reshape(B, self.n_patches, -1)\n",
1985
- " patches_n = F.normalize(patches.float(), dim=-1)\n",
1986
  "\n",
1987
  " results = []\n",
1988
  " for t in range(self.n_templates):\n",
@@ -1990,16 +1991,10 @@
1990
  " # Cross-covariance M: (B, D, D) where D = patch_dim\n",
1991
  " M = torch.bmm(patches_n.transpose(1, 2),\n",
1992
  " template.unsqueeze(0).expand(B, -1, -1))\n",
1993
- " # Full SVD via eigh path: M = U S V^T\n",
1994
- " # M^T M = V S\u00b2 V^T \u2192 eigh gives S, V\n",
1995
- " MtM = torch.bmm(M.transpose(1, 2), M) # (B, D, D) symmetric PSD\n",
1996
- " eigvals, V = torch.linalg.eigh(MtM)\n",
1997
- " eigvals = eigvals.flip(-1); V = V.flip(-1)\n",
1998
- " S = torch.sqrt(eigvals.clamp(min=1e-10))\n",
1999
- " # Recover U = M V S\u207b\u00b9\n",
2000
- " U = torch.bmm(M, V) / (S.unsqueeze(1) + 1e-10)\n",
2001
- " # Optimal Procrustes rotation R = U V^T\n",
2002
- " R_opt = torch.bmm(U, V.transpose(1, 2)) # (B, D, D)\n",
2003
  " # Features: alignment quality + top singular values + rotation trace\n",
2004
  " align_quality = S.sum(dim=-1, keepdim=True)\n",
2005
  " top_s = S[:, :min(3, S.shape[1])]\n",
 
1977
  " self.output_dim = n_templates * (2 + min(3, patch_dim))\n",
1978
  " print(f\"[PROCRUSTES] {n_templates} templates, dim={self.output_dim}\")\n",
1979
  "\n",
1980
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1981
  " def forward(self, x):\n",
1982
  " B, C, H, W = x.shape\n",
1983
  " ps = self.patch_size\n",
1984
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1985
  " patches = patches.contiguous().reshape(B, self.n_patches, -1)\n",
1986
+ " patches_n = F.normalize(patches, dim=-1)\n",
1987
  "\n",
1988
  " results = []\n",
1989
  " for t in range(self.n_templates):\n",
 
1991
  " # Cross-covariance M: (B, D, D) where D = patch_dim\n",
1992
  " M = torch.bmm(patches_n.transpose(1, 2),\n",
1993
  " template.unsqueeze(0).expand(B, -1, -1))\n",
1994
+ " # Direct SVD of cross-covariance: M = U S Vh\n",
1995
+ " U, S, Vh = torch.linalg.svd(M, full_matrices=False)\n",
1996
+ " # Optimal Procrustes rotation R = U Vh (= U V^T)\n",
1997
+ " R_opt = torch.bmm(U, Vh) # (B, D, D)\n",
 
 
 
 
 
 
1998
  " # Features: alignment quality + top singular values + rotation trace\n",
1999
  " align_quality = S.sum(dim=-1, keepdim=True)\n",
2000
  " top_s = S[:, :min(3, S.shape[1])]\n",
spectral/notebooks/experiment_5_matrix_decompositions.ipynb CHANGED
@@ -1716,13 +1716,14 @@
1716
  " self.output_dim = n_patches * n_upper\n",
1717
  " print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim} (via 3\u00d73 Cholesky)\")\n",
1718
  "\n",
 
1719
  " def forward(self, x):\n",
1720
  " B, C, H, W = x.shape\n",
1721
  " ps = self.patch_size\n",
1722
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1723
  " n_p = patches.shape[2] * patches.shape[3]\n",
1724
  " # X: (B*n_p, ps*ps, C)\n",
1725
- " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1).float()\n",
1726
  " # X^T X: (B*n_p, 3, 3)\n",
1727
  " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1728
  " # Add small diagonal for numerical stability\n",
 
1716
  " self.output_dim = n_patches * n_upper\n",
1717
  " print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim} (via 3\u00d73 Cholesky)\")\n",
1718
  "\n",
1719
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
1720
  " def forward(self, x):\n",
1721
  " B, C, H, W = x.shape\n",
1722
  " ps = self.patch_size\n",
1723
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
1724
  " n_p = patches.shape[2] * patches.shape[3]\n",
1725
  " # X: (B*n_p, ps*ps, C)\n",
1726
+ " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1727
  " # X^T X: (B*n_p, 3, 3)\n",
1728
  " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1729
  " # Add small diagonal for numerical stability\n",