AbstractPhil commited on
Commit
6182b2b
·
verified ·
1 Parent(s): e1713d6

Upload 6 files

Browse files
spectral/notebooks/experiment_2_manifold_structures.ipynb CHANGED
@@ -1765,8 +1765,8 @@
1765
  " self.k = k\n",
1766
  " n_patches_h = input_size // patch_size\n",
1767
  " self.n_patches = n_patches_h * n_patches_h\n",
1768
- " # k singular values + k*(k-1)/2 principal angles = features per patch\n",
1769
- " self.features_per_patch = k + k * (k - 1) // 2 + k * 3\n",
1770
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1771
  " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1772
  "\n",
@@ -1781,13 +1781,14 @@
1781
  " U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
1782
  " # Top-k singular values\n",
1783
  " sk = S[:, :self.k] # (B*n_p, k)\n",
 
 
1784
  " # Principal angles via gram matrix of top-k vectors\n",
1785
  " Uk = U[:, :, :self.k] # (B*n_p, ps*ps, k)\n",
1786
  " gram = torch.bmm(Uk.transpose(1, 2), Uk) # (B*n_p, k, k)\n",
1787
- " # Features: singular values + flattened gram upper triangle\n",
1788
  " triu_idx = torch.triu_indices(self.k, self.k)\n",
1789
  " gram_feat = gram[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
1790
- " feats = torch.cat([sk, gram_feat], dim=-1) # (B*n_p, feat_per_patch)\n",
1791
  " return feats.reshape(B, -1) # (B, n_patches * feat_per_patch)\n",
1792
  "\n",
1793
  "front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
@@ -1828,8 +1829,10 @@
1828
  " self.levels = levels\n",
1829
  " n_patches_h = input_size // patch_size\n",
1830
  " self.n_patches = n_patches_h * n_patches_h\n",
1831
- " # For each level k: k singular values + projection norms\n",
1832
- " self.features_per_patch = sum(k + k for k in levels)\n",
 
 
1833
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1834
  " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
1835
  "\n",
 
1765
  " self.k = k\n",
1766
  " n_patches_h = input_size // patch_size\n",
1767
  " self.n_patches = n_patches_h * n_patches_h\n",
1768
+ " # k singular values + k log-ratios + k*(k+1)/2 gram upper-tri\n",
1769
+ " self.features_per_patch = k + k + k * (k + 1) // 2\n",
1770
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1771
  " print(f\"[GRASS] {self.n_patches} patches, k={k}, dim={self.output_dim}\")\n",
1772
  "\n",
 
1781
  " U, S, Vh = torch.linalg.svd(patches_flat, full_matrices=False)\n",
1782
  " # Top-k singular values\n",
1783
  " sk = S[:, :self.k] # (B*n_p, k)\n",
1784
+ " # Log singular value ratios (numerically informative)\n",
1785
+ " sv_ratios = torch.log(sk / (sk[:, -1:] + 1e-8) + 1e-8) # (B*n_p, k)\n",
1786
  " # Principal angles via gram matrix of top-k vectors\n",
1787
  " Uk = U[:, :, :self.k] # (B*n_p, ps*ps, k)\n",
1788
  " gram = torch.bmm(Uk.transpose(1, 2), Uk) # (B*n_p, k, k)\n",
 
1789
  " triu_idx = torch.triu_indices(self.k, self.k)\n",
1790
  " gram_feat = gram[:, triu_idx[0], triu_idx[1]] # (B*n_p, k*(k+1)/2)\n",
1791
+ " feats = torch.cat([sk, sv_ratios, gram_feat], dim=-1) # (B*n_p, feat_per_patch)\n",
1792
  " return feats.reshape(B, -1) # (B, n_patches * feat_per_patch)\n",
1793
  "\n",
1794
  "front = GrassmannianFrontEnd(patch_size=8, k=3).to(device)\n",
 
1829
  " self.levels = levels\n",
1830
  " n_patches_h = input_size // patch_size\n",
1831
  " self.n_patches = n_patches_h * n_patches_h\n",
1832
+ " # For each level k: min(k, C) singular values + min(k, C) projection norms\n",
1833
+ " # C=3 for RGB, so levels > 3 are clamped\n",
1834
+ " max_sv = min(3, patch_size * patch_size) # C vs ps*ps\n",
1835
+ " self.features_per_patch = sum(min(k, max_sv) * 2 for k in levels)\n",
1836
  " self.output_dim = self.n_patches * self.features_per_patch\n",
1837
  " print(f\"[FLAG] {self.n_patches} patches, levels={levels}, dim={self.output_dim}\")\n",
1838
  "\n",