Upload 5 files
Browse files
spectral/notebooks/experiment_3_compact_representations.ipynb
CHANGED
|
@@ -2007,8 +2007,6 @@
|
|
| 2007 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 2008 |
" patches = patches.contiguous().reshape(B, C, self.n_patches, ps * ps)\n",
|
| 2009 |
" # Project each patch onto basis: (B, C, n_patches, n_basis)\n",
|
| 2010 |
-
" coeffs = torch.einsum('bcnp,bp->bcn', patches, self.basis.T.contiguous())\n",
|
| 2011 |
-
" # Wait, wrong einsum. Let me fix:\n",
|
| 2012 |
" # patches: (B, C, n_patches, ps*ps), basis: (n_basis, ps*ps)\n",
|
| 2013 |
" coeffs = torch.einsum('bcnp,kp->bcnk', patches, self.basis)\n",
|
| 2014 |
" return coeffs.reshape(B, -1)\n",
|
|
|
|
| 2007 |
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
|
| 2008 |
" patches = patches.contiguous().reshape(B, C, self.n_patches, ps * ps)\n",
|
| 2009 |
" # Project each patch onto basis: (B, C, n_patches, n_basis)\n",
|
|
|
|
|
|
|
| 2010 |
" # patches: (B, C, n_patches, ps*ps), basis: (n_basis, ps*ps)\n",
|
| 2011 |
" coeffs = torch.einsum('bcnp,kp->bcnk', patches, self.basis)\n",
|
| 2012 |
" return coeffs.reshape(B, -1)\n",
|
spectral/notebooks/experiment_5_matrix_decompositions.ipynb
CHANGED
|
@@ -1862,8 +1862,11 @@
|
|
| 1862 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1863 |
" # X^T X: (B*n_p, 3, 3)\n",
|
| 1864 |
" XtX = torch.bmm(X.transpose(1, 2), X)\n",
|
| 1865 |
-
" #
|
| 1866 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1867 |
" # Cholesky: R^T R = X^T X \u2192 R is the upper Cholesky factor\n",
|
| 1868 |
" R = torch.linalg.cholesky(XtX).transpose(1, 2) # upper triangular (B*n_p, k, k)\n",
|
| 1869 |
" k = self.k\n",
|
|
|
|
| 1862 |
" X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
|
| 1863 |
" # X^T X: (B*n_p, 3, 3)\n",
|
| 1864 |
" XtX = torch.bmm(X.transpose(1, 2), X)\n",
|
| 1865 |
+
" # Regularize relative to matrix scale: eps * tr(XtX)/k * I\n",
|
| 1866 |
+
" # Constant-color patches have near-zero eigenvalues; 1e-6 isn't enough\n",
|
| 1867 |
+
" trace = XtX.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True).unsqueeze(-1) # (N,1,1)\n",
|
| 1868 |
+
" reg = (1e-4 * trace / self.k + 1e-6) # scale-adaptive floor + absolute floor\n",
|
| 1869 |
+
" XtX = XtX + reg * torch.eye(self.k, device=x.device, dtype=XtX.dtype).unsqueeze(0)\n",
|
| 1870 |
" # Cholesky: R^T R = X^T X \u2192 R is the upper Cholesky factor\n",
|
| 1871 |
" R = torch.linalg.cholesky(XtX).transpose(1, 2) # upper triangular (B*n_p, k, k)\n",
|
| 1872 |
" k = self.k\n",
|