AbstractPhil commited on
Commit
8774226
·
verified ·
1 Parent(s): 7d543e2

Upload 5 files

Browse files
spectral/notebooks/experiment_3_compact_representations.ipynb CHANGED
@@ -2007,8 +2007,6 @@
2007
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
2008
  " patches = patches.contiguous().reshape(B, C, self.n_patches, ps * ps)\n",
2009
  " # Project each patch onto basis: (B, C, n_patches, n_basis)\n",
2010
- " coeffs = torch.einsum('bcnp,bp->bcn', patches, self.basis.T.contiguous())\n",
2011
- " # Wait, wrong einsum. Let me fix:\n",
2012
  " # patches: (B, C, n_patches, ps*ps), basis: (n_basis, ps*ps)\n",
2013
  " coeffs = torch.einsum('bcnp,kp->bcnk', patches, self.basis)\n",
2014
  " return coeffs.reshape(B, -1)\n",
 
2007
  " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
2008
  " patches = patches.contiguous().reshape(B, C, self.n_patches, ps * ps)\n",
2009
  " # Project each patch onto basis: (B, C, n_patches, n_basis)\n",
 
 
2010
  " # patches: (B, C, n_patches, ps*ps), basis: (n_basis, ps*ps)\n",
2011
  " coeffs = torch.einsum('bcnp,kp->bcnk', patches, self.basis)\n",
2012
  " return coeffs.reshape(B, -1)\n",
spectral/notebooks/experiment_5_matrix_decompositions.ipynb CHANGED
@@ -1862,8 +1862,11 @@
1862
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1863
  " # X^T X: (B*n_p, 3, 3)\n",
1864
  " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1865
- " # Add small diagonal for numerical stability\n",
1866
- " XtX = XtX + 1e-6 * torch.eye(self.k, device=x.device, dtype=XtX.dtype).unsqueeze(0)\n",
 
 
 
1867
  " # Cholesky: R^T R = X^T X \u2192 R is the upper Cholesky factor\n",
1868
  " R = torch.linalg.cholesky(XtX).transpose(1, 2) # upper triangular (B*n_p, k, k)\n",
1869
  " k = self.k\n",
 
1862
  " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
1863
  " # X^T X: (B*n_p, 3, 3)\n",
1864
  " XtX = torch.bmm(X.transpose(1, 2), X)\n",
1865
+ " # Regularize relative to matrix scale: eps * tr(XtX)/k * I\n",
1866
+ " # Constant-color patches have near-zero eigenvalues; 1e-6 isn't enough\n",
1867
+ " trace = XtX.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True).unsqueeze(-1) # (N,1,1)\n",
1868
+ " reg = (1e-4 * trace / self.k + 1e-6) # scale-adaptive floor + absolute floor\n",
1869
+ " XtX = XtX + reg * torch.eye(self.k, device=x.device, dtype=XtX.dtype).unsqueeze(0)\n",
1870
  " # Cholesky: R^T R = X^T X \u2192 R is the upper Cholesky factor\n",
1871
  " R = torch.linalg.cholesky(XtX).transpose(1, 2) # upper triangular (B*n_p, k, k)\n",
1872
  " k = self.k\n",