Upload 6 files
Browse files
spectral/notebooks/experiment_2_manifold_structures.ipynb
CHANGED
|
@@ -1766,47 +1766,41 @@
|
|
| 1766 |
"source": [
|
| 1767 |
"# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
|
| 1768 |
"class GrassmannianFrontEnd(nn.Module):\n",
|
| 1769 |
-
" \"\"\"Grassmannian subspace features via
|
| 1770 |
-
" X = U S
|
| 1771 |
-
"
|
| 1772 |
-
"
|
|
|
|
|
|
|
| 1773 |
" def __init__(self, patch_size=8, k=3, input_size=32):\n",
|
| 1774 |
" super().__init__()\n",
|
| 1775 |
" self.patch_size = patch_size\n",
|
| 1776 |
" self.k = k\n",
|
|
|
|
| 1777 |
" n_patches_h = input_size // patch_size\n",
|
| 1778 |
" self.n_patches = n_patches_h * n_patches_h\n",
|
| 1779 |
-
" # k singular values + k log-ratios + k*
|
| 1780 |
-
" self.features_per_patch = k + k + k *
|
| 1781 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1782 |
-
" print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} (
|
| 1783 |
"\n",
|
|
|
|
| 1784 |
" def forward(self, x):\n",
|
| 1785 |
" B, C, H, W = x.shape\n",
|
| 1786 |
" ps = self.patch_size\n",
|
| 1787 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1788 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1789 |
" # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
|
| 1790 |
-
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)
|
| 1791 |
-
" #
|
| 1792 |
-
"
|
| 1793 |
-
"
|
| 1794 |
-
"
|
| 1795 |
-
" # Flip to descending order\n",
|
| 1796 |
-
" eigvals = eigvals.flip(-1)\n",
|
| 1797 |
-
" V = V.flip(-1)\n",
|
| 1798 |
-
" # S = sqrt(\u03bb), these ARE the singular values\n",
|
| 1799 |
-
" S = torch.sqrt(eigvals.clamp(min=1e-10))[:, :self.k]\n",
|
| 1800 |
-
" # U = X V S\u207b\u00b9 (recover left singular vectors)\n",
|
| 1801 |
-
" Vk = V[:, :, :self.k] # (B*n_p, C, k)\n",
|
| 1802 |
-
" XV = torch.bmm(X, Vk) # (B*n_p, ps*ps, k)\n",
|
| 1803 |
-
" U = XV / (S.unsqueeze(1) + 1e-10) # (B*n_p, ps*ps, k)\n",
|
| 1804 |
-
" # Features: singular values + log-ratios + gram of U\n",
|
| 1805 |
" sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
|
| 1806 |
-
"
|
| 1807 |
-
"
|
| 1808 |
-
"
|
| 1809 |
-
" feats = torch.cat([S, sv_ratios,
|
| 1810 |
" return feats.reshape(B, -1)\n",
|
| 1811 |
"\n",
|
| 1812 |
"front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
|
|
@@ -1840,9 +1834,10 @@
|
|
| 1840 |
"source": [
|
| 1841 |
"# @title Experiment 2.3 \u2014 Flag Manifold\n",
|
| 1842 |
"class FlagManifoldFrontEnd(nn.Module):\n",
|
| 1843 |
-
" \"\"\"Cascading SVD at multiple truncation levels via
|
| 1844 |
" Nested subspace features: singular values + projection norms at each flag level.\n",
|
| 1845 |
-
"
|
|
|
|
| 1846 |
" def __init__(self, patch_size=8, levels=(1, 2, 3), input_size=32):\n",
|
| 1847 |
" super().__init__()\n",
|
| 1848 |
" self.patch_size = patch_size\n",
|
|
@@ -1852,23 +1847,18 @@
|
|
| 1852 |
" max_sv = min(3, patch_size * patch_size)\n",
|
| 1853 |
" self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
|
| 1854 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1855 |
-
" print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} (
|
| 1856 |
"\n",
|
|
|
|
| 1857 |
" def forward(self, x):\n",
|
| 1858 |
" B, C, H, W = x.shape\n",
|
| 1859 |
" ps = self.patch_size\n",
|
| 1860 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1861 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1862 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1863 |
-
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)
|
| 1864 |
-
" #
|
| 1865 |
-
"
|
| 1866 |
-
" eigvals, V = torch.linalg.eigh(XtX)\n",
|
| 1867 |
-
" eigvals = eigvals.flip(-1)\n",
|
| 1868 |
-
" V = V.flip(-1)\n",
|
| 1869 |
-
" S = torch.sqrt(eigvals.clamp(min=1e-10)) # (B*n_p, C)\n",
|
| 1870 |
-
" # U = X V S\u207b\u00b9\n",
|
| 1871 |
-
" U = torch.bmm(X, V) / (S.unsqueeze(1) + 1e-10) # (B*n_p, ps*ps, C)\n",
|
| 1872 |
" # Features at each flag level\n",
|
| 1873 |
" feats = []\n",
|
| 1874 |
" for k in self.levels:\n",
|
|
|
|
| 1766 |
"source": [
|
| 1767 |
"# @title Experiment 2.2 \u2014 Grassmannian Subspace Features\n",
|
| 1768 |
"class GrassmannianFrontEnd(nn.Module):\n",
|
| 1769 |
+
" \"\"\"Grassmannian subspace features via direct SVD.\n",
|
| 1770 |
+
" X = U S Vh. Features: singular values (spectral profile),\n",
|
| 1771 |
+
" log singular value ratios (relative spectrum), and right singular\n",
|
| 1772 |
+
" vectors V (subspace orientation in channel space \u2014 the actual\n",
|
| 1773 |
+
" Grassmannian coordinate). V varies per patch and encodes which\n",
|
| 1774 |
+
" linear combinations of RGB correspond to principal directions.\"\"\"\n",
|
| 1775 |
" def __init__(self, patch_size=8, k=3, input_size=32):\n",
|
| 1776 |
" super().__init__()\n",
|
| 1777 |
" self.patch_size = patch_size\n",
|
| 1778 |
" self.k = k\n",
|
| 1779 |
+
" self.C = 3 # RGB\n",
|
| 1780 |
" n_patches_h = input_size // patch_size\n",
|
| 1781 |
" self.n_patches = n_patches_h * n_patches_h\n",
|
| 1782 |
+
" # k singular values + k log-ratios + k*C right singular vector entries\n",
|
| 1783 |
+
" self.features_per_patch = k + k + k * self.C\n",
|
| 1784 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1785 |
+
" print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim} (direct SVD)\")\n",
|
| 1786 |
"\n",
|
| 1787 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1788 |
" def forward(self, x):\n",
|
| 1789 |
" B, C, H, W = x.shape\n",
|
| 1790 |
" ps = self.patch_size\n",
|
| 1791 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1792 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1793 |
" # X: (B*n_p, ps*ps, C) \u2014 each patch as a tall-skinny matrix\n",
|
| 1794 |
+
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1795 |
+
" # Direct thin SVD: X = U S Vh, U:(N,64,3) S:(N,3) Vh:(N,3,3)\n",
|
| 1796 |
+
" U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
|
| 1797 |
+
" S = S[:, :self.k]\n",
|
| 1798 |
+
" # Log singular value ratios (scale-invariant spectrum)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1799 |
" sv_ratios = torch.log(S / (S[:, -1:] + 1e-8) + 1e-8)\n",
|
| 1800 |
+
" # Right singular vectors Vh[:k]: subspace orientation in channel space\n",
|
| 1801 |
+
" # This IS the Grassmannian coordinate \u2014 varies meaningfully per patch\n",
|
| 1802 |
+
" V_feat = Vh[:, :self.k, :].reshape(-1, self.k * C)\n",
|
| 1803 |
+
" feats = torch.cat([S, sv_ratios, V_feat], dim=-1)\n",
|
| 1804 |
" return feats.reshape(B, -1)\n",
|
| 1805 |
"\n",
|
| 1806 |
"front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
|
|
|
|
| 1834 |
"source": [
|
| 1835 |
"# @title Experiment 2.3 \u2014 Flag Manifold\n",
|
| 1836 |
"class FlagManifoldFrontEnd(nn.Module):\n",
|
| 1837 |
+
" \"\"\"Cascading SVD at multiple truncation levels via direct SVD.\n",
|
| 1838 |
" Nested subspace features: singular values + projection norms at each flag level.\n",
|
| 1839 |
+
" The flag structure captures how information distributes across\n",
|
| 1840 |
+
" nested subspace hierarchies \u2014 a genuine flag manifold signature.\"\"\"\n",
|
| 1841 |
" def __init__(self, patch_size=8, levels=(1, 2, 3), input_size=32):\n",
|
| 1842 |
" super().__init__()\n",
|
| 1843 |
" self.patch_size = patch_size\n",
|
|
|
|
| 1847 |
" max_sv = min(3, patch_size * patch_size)\n",
|
| 1848 |
" self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
|
| 1849 |
" self.output_dim = self.n_patches * self.features_per_patch\n",
|
| 1850 |
+
" print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim} (direct SVD)\")\n",
|
| 1851 |
"\n",
|
| 1852 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1853 |
" def forward(self, x):\n",
|
| 1854 |
" B, C, H, W = x.shape\n",
|
| 1855 |
" ps = self.patch_size\n",
|
| 1856 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1857 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1858 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1859 |
+
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1860 |
+
" # Direct thin SVD\n",
|
| 1861 |
+
" U, S, Vh = torch.linalg.svd(X, full_matrices=False)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1862 |
" # Features at each flag level\n",
|
| 1863 |
" feats = []\n",
|
| 1864 |
" for k in self.levels:\n",
|
spectral/notebooks/experiment_4_invertible_transforms.ipynb
CHANGED
|
@@ -1977,12 +1977,13 @@
|
|
| 1977 |
" self.output_dim = n_templates * (2 + min(3, patch_dim))\n",
|
| 1978 |
" print(f\"[PROCRUSTES] {n_templates} templates, dim={self.output_dim}\")\n",
|
| 1979 |
"\n",
|
|
|
|
| 1980 |
" def forward(self, x):\n",
|
| 1981 |
" B, C, H, W = x.shape\n",
|
| 1982 |
" ps = self.patch_size\n",
|
| 1983 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1984 |
" patches = patches.contiguous().reshape(B, self.n_patches, -1)\n",
|
| 1985 |
-
" patches_n = F.normalize(patches
|
| 1986 |
"\n",
|
| 1987 |
" results = []\n",
|
| 1988 |
" for t in range(self.n_templates):\n",
|
|
@@ -1990,16 +1991,10 @@
|
|
| 1990 |
" # Cross-covariance M: (B, D, D) where D = patch_dim\n",
|
| 1991 |
" M = torch.bmm(patches_n.transpose(1, 2),\n",
|
| 1992 |
" template.unsqueeze(0).expand(B, -1, -1))\n",
|
| 1993 |
-
" #
|
| 1994 |
-
"
|
| 1995 |
-
"
|
| 1996 |
-
"
|
| 1997 |
-
" eigvals = eigvals.flip(-1); V = V.flip(-1)\n",
|
| 1998 |
-
" S = torch.sqrt(eigvals.clamp(min=1e-10))\n",
|
| 1999 |
-
" # Recover U = M V S\u207b\u00b9\n",
|
| 2000 |
-
" U = torch.bmm(M, V) / (S.unsqueeze(1) + 1e-10)\n",
|
| 2001 |
-
" # Optimal Procrustes rotation R = U V^T\n",
|
| 2002 |
-
" R_opt = torch.bmm(U, V.transpose(1, 2)) # (B, D, D)\n",
|
| 2003 |
" # Features: alignment quality + top singular values + rotation trace\n",
|
| 2004 |
" align_quality = S.sum(dim=-1, keepdim=True)\n",
|
| 2005 |
" top_s = S[:, :min(3, S.shape[1])]\n",
|
|
|
|
| 1977 |
" self.output_dim = n_templates * (2 + min(3, patch_dim))\n",
|
| 1978 |
" print(f\"[PROCRUSTES] {n_templates} templates, dim={self.output_dim}\")\n",
|
| 1979 |
"\n",
|
| 1980 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1981 |
" def forward(self, x):\n",
|
| 1982 |
" B, C, H, W = x.shape\n",
|
| 1983 |
" ps = self.patch_size\n",
|
| 1984 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1985 |
" patches = patches.contiguous().reshape(B, self.n_patches, -1)\n",
|
| 1986 |
+
" patches_n = F.normalize(patches, dim=-1)\n",
|
| 1987 |
"\n",
|
| 1988 |
" results = []\n",
|
| 1989 |
" for t in range(self.n_templates):\n",
|
|
|
|
| 1991 |
" # Cross-covariance M: (B, D, D) where D = patch_dim\n",
|
| 1992 |
" M = torch.bmm(patches_n.transpose(1, 2),\n",
|
| 1993 |
" template.unsqueeze(0).expand(B, -1, -1))\n",
|
| 1994 |
+
" # Direct SVD of cross-covariance: M = U S Vh\n",
|
| 1995 |
+
" U, S, Vh = torch.linalg.svd(M, full_matrices=False)\n",
|
| 1996 |
+
" # Optimal Procrustes rotation R = U Vh (= U V^T)\n",
|
| 1997 |
+
" R_opt = torch.bmm(U, Vh) # (B, D, D)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1998 |
" # Features: alignment quality + top singular values + rotation trace\n",
|
| 1999 |
" align_quality = S.sum(dim=-1, keepdim=True)\n",
|
| 2000 |
" top_s = S[:, :min(3, S.shape[1])]\n",
|
spectral/notebooks/experiment_5_matrix_decompositions.ipynb
CHANGED
|
@@ -1716,13 +1716,14 @@
|
|
| 1716 |
" self.output_dim = n_patches * n_upper\n",
|
| 1717 |
" print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim} (via 3\u00d73 Cholesky)\")\n",
|
| 1718 |
"\n",
|
|
|
|
| 1719 |
" def forward(self, x):\n",
|
| 1720 |
" B, C, H, W = x.shape\n",
|
| 1721 |
" ps = self.patch_size\n",
|
| 1722 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1723 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1724 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1725 |
-
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)
|
| 1726 |
" # X^T X: (B*n_p, 3, 3)\n",
|
| 1727 |
" XtX = torch.bmm(X.transpose(1, 2), X)\n",
|
| 1728 |
" # Add small diagonal for numerical stability\n",
|
|
|
|
| 1716 |
" self.output_dim = n_patches * n_upper\n",
|
| 1717 |
" print(f\"[QR] {n_patches} patches, k={k}, dim={self.output_dim} (via 3\u00d73 Cholesky)\")\n",
|
| 1718 |
"\n",
|
| 1719 |
+
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 1720 |
" def forward(self, x):\n",
|
| 1721 |
" B, C, H, W = x.shape\n",
|
| 1722 |
" ps = self.patch_size\n",
|
| 1723 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 1724 |
" n_p = patches.shape[2] * patches.shape[3]\n",
|
| 1725 |
" # X: (B*n_p, ps*ps, C)\n",
|
| 1726 |
+
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1727 |
" # X^T X: (B*n_p, 3, 3)\n",
|
| 1728 |
" XtX = torch.bmm(X.transpose(1, 2), X)\n",
|
| 1729 |
" # Add small diagonal for numerical stability\n",
|