Upload 4 files
Browse files
spectral/notebooks/experiment_4_invertible_transforms.ipynb
CHANGED
|
@@ -2098,45 +2098,48 @@
|
|
| 2098 |
"source": [
|
| 2099 |
"# @title Experiment 4.5 \u2014 Procrustes Alignment\n",
|
| 2100 |
"class ProcrustessFrontEnd(nn.Module):\n",
|
| 2101 |
-
" \"\"\"
|
|
|
|
|
|
|
|
|
|
| 2102 |
" def __init__(self, n_templates=8, patch_size=4, input_size=32):\n",
|
| 2103 |
" super().__init__()\n",
|
| 2104 |
" self.patch_size = patch_size\n",
|
| 2105 |
" self.n_templates = n_templates\n",
|
| 2106 |
" n_patches = (input_size // patch_size) ** 2\n",
|
| 2107 |
" self.n_patches = n_patches\n",
|
| 2108 |
-
"
|
| 2109 |
-
"
|
| 2110 |
-
"
|
|
|
|
| 2111 |
" self.register_buffer('templates', templates)\n",
|
| 2112 |
-
" #
|
| 2113 |
-
" self.output_dim = n_templates *
|
| 2114 |
-
"
|
|
|
|
| 2115 |
"\n",
|
| 2116 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 2117 |
" def forward(self, x):\n",
|
| 2118 |
" B, C, H, W = x.shape\n",
|
|
|
|
| 2119 |
" ps = self.patch_size\n",
|
| 2120 |
-
"
|
| 2121 |
-
" patches =
|
|
|
|
|
|
|
| 2122 |
" patches_n = F.normalize(patches, dim=-1)\n",
|
| 2123 |
-
"\n",
|
| 2124 |
-
"
|
| 2125 |
-
"
|
| 2126 |
-
"
|
| 2127 |
-
"
|
| 2128 |
-
"
|
| 2129 |
-
"
|
| 2130 |
-
"
|
| 2131 |
-
"
|
| 2132 |
-
"
|
| 2133 |
-
"
|
| 2134 |
-
"
|
| 2135 |
-
" align_quality = S.sum(dim=-1, keepdim=True)\n",
|
| 2136 |
-
" top_s = S[:, :min(3, S.shape[1])]\n",
|
| 2137 |
-
" rot_trace = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True) # det proxy\n",
|
| 2138 |
-
" results.append(torch.cat([align_quality, top_s, rot_trace], dim=-1))\n",
|
| 2139 |
-
" return torch.cat(results, dim=-1)\n",
|
| 2140 |
"\n",
|
| 2141 |
"front = ProcrustessFrontEnd(n_templates=8, patch_size=4).to(device)\n",
|
| 2142 |
"model_4_5 = SpectralGeoLIPEncoder(\n",
|
|
|
|
| 2098 |
"source": [
|
| 2099 |
"# @title Experiment 4.5 \u2014 Procrustes Alignment\n",
|
| 2100 |
"class ProcrustessFrontEnd(nn.Module):\n",
|
| 2101 |
+
" \"\"\"Per-patch Procrustes alignment in color space \u2014 SO(3) rotations.\n",
|
| 2102 |
+
" Cross-covariance of patch pixel colors vs template pixel colors gives\n",
|
| 2103 |
+
" a 3\u00d73 matrix. SVD of this 3\u00d73 via fused Triton kernel. R = U Vh is\n",
|
| 2104 |
+
" the optimal color rotation. All templates batched in one kernel call.\"\"\"\n",
|
| 2105 |
" def __init__(self, n_templates=8, patch_size=4, input_size=32):\n",
|
| 2106 |
" super().__init__()\n",
|
| 2107 |
" self.patch_size = patch_size\n",
|
| 2108 |
" self.n_templates = n_templates\n",
|
| 2109 |
" n_patches = (input_size // patch_size) ** 2\n",
|
| 2110 |
" self.n_patches = n_patches\n",
|
| 2111 |
+
" ps2 = patch_size * patch_size\n",
|
| 2112 |
+
" N = n_patches * ps2 # total pixels per image\n",
|
| 2113 |
+
" # Templates: (T, N, 3) \u2014 unit-norm color directions per pixel\n",
|
| 2114 |
+
" templates = F.normalize(torch.randn(n_templates, N, 3), dim=-1)\n",
|
| 2115 |
" self.register_buffer('templates', templates)\n",
|
| 2116 |
+
" # Per template: align_quality(1) + S(3) + rot_trace(1) = 5\n",
|
| 2117 |
+
" self.output_dim = n_templates * 5\n",
|
| 2118 |
+
" backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
|
| 2119 |
+
" print(f\"[PROCRUSTES] {n_templates} templates, color-space SO(3), dim={self.output_dim} ({backend})\")\n",
|
| 2120 |
"\n",
|
| 2121 |
" @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
|
| 2122 |
" def forward(self, x):\n",
|
| 2123 |
" B, C, H, W = x.shape\n",
|
| 2124 |
+
" T = self.n_templates\n",
|
| 2125 |
" ps = self.patch_size\n",
|
| 2126 |
+
" # Reshape to (B, N, 3) \u2014 pixels with 3 color channels\n",
|
| 2127 |
+
" patches = x.unfold(2, ps, ps).unfold(3, ps, ps) # (B, C, nh, nw, ps, ps)\n",
|
| 2128 |
+
" patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous() # (B, nh, nw, ps, ps, C)\n",
|
| 2129 |
+
" patches = patches.reshape(B, -1, C) # (B, N, 3)\n",
|
| 2130 |
" patches_n = F.normalize(patches, dim=-1)\n",
|
| 2131 |
+
" # Expand for all templates \u2014 one batched call, no Python loop\n",
|
| 2132 |
+
" patches_exp = patches_n.unsqueeze(1).expand(B, T, -1, -1).reshape(B * T, -1, 3)\n",
|
| 2133 |
+
" templates_exp = self.templates.unsqueeze(0).expand(B, -1, -1, -1).reshape(B * T, -1, 3)\n",
|
| 2134 |
+
" # Cross-covariance in color space: (B*T, 3, N) @ (B*T, N, 3) = (B*T, 3, 3)\n",
|
| 2135 |
+
" M = torch.bmm(patches_exp.transpose(1, 2), templates_exp)\n",
|
| 2136 |
+
" # SVD of 3\u00d73 via Triton kernel (or fallback)\n",
|
| 2137 |
+
" U, S, Vh = batched_svd3(M) # M is (B*T, 3, 3)\n",
|
| 2138 |
+
" R_opt = torch.bmm(U, Vh) # optimal SO(3) color rotation\n",
|
| 2139 |
+
" align_quality = S.sum(dim=-1, keepdim=True)\n",
|
| 2140 |
+
" rot_trace = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True)\n",
|
| 2141 |
+
" feats = torch.cat([align_quality, S, rot_trace], dim=-1) # (B*T, 5)\n",
|
| 2142 |
+
" return feats.reshape(B, T * 5)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2143 |
"\n",
|
| 2144 |
"front = ProcrustessFrontEnd(n_templates=8, patch_size=4).to(device)\n",
|
| 2145 |
"model_4_5 = SpectralGeoLIPEncoder(\n",
|