AbstractPhil committed on
Commit
bc2875f
·
verified ·
1 Parent(s): 8774226

Upload 4 files

Browse files
spectral/notebooks/experiment_4_invertible_transforms.ipynb CHANGED
@@ -2098,45 +2098,48 @@
2098
  "source": [
2099
  "# @title Experiment 4.5 \u2014 Procrustes Alignment\n",
2100
  "class ProcrustessFrontEnd(nn.Module):\n",
2101
- " \"\"\"Align image patches to reference templates, use alignment residuals as features.\"\"\"\n",
 
 
 
2102
  " def __init__(self, n_templates=8, patch_size=4, input_size=32):\n",
2103
  " super().__init__()\n",
2104
  " self.patch_size = patch_size\n",
2105
  " self.n_templates = n_templates\n",
2106
  " n_patches = (input_size // patch_size) ** 2\n",
2107
  " self.n_patches = n_patches\n",
2108
- " patch_dim = 3 * patch_size * patch_size\n",
2109
- " # Fixed reference templates (random orthogonal patches)\n",
2110
- " templates = F.normalize(torch.randn(n_templates, n_patches, patch_dim), dim=-1)\n",
 
2111
  " self.register_buffer('templates', templates)\n",
2112
- " # Features per template: alignment quality (1) + top singular values (min(3,D)) + rotation trace (1)\n",
2113
- " self.output_dim = n_templates * (2 + min(3, patch_dim))\n",
2114
- " print(f\"[PROCRUSTES] {n_templates} templates, dim={self.output_dim}\")\n",
 
2115
  "\n",
2116
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
2117
  " def forward(self, x):\n",
2118
  " B, C, H, W = x.shape\n",
 
2119
  " ps = self.patch_size\n",
2120
- " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
2121
- " patches = patches.contiguous().reshape(B, self.n_patches, -1)\n",
 
 
2122
  " patches_n = F.normalize(patches, dim=-1)\n",
2123
- "\n",
2124
- " results = []\n",
2125
- " for t in range(self.n_templates):\n",
2126
- " template = self.templates[t] # (n_patches, patch_dim)\n",
2127
- " # Cross-covariance M: (B, D, D) where D = patch_dim\n",
2128
- " M = torch.bmm(patches_n.transpose(1, 2),\n",
2129
- " template.unsqueeze(0).expand(B, -1, -1))\n",
2130
- " # Direct SVD of cross-covariance: M = U S Vh\n",
2131
- " U, S, Vh = torch.linalg.svd(M, full_matrices=False)\n",
2132
- " # Optimal Procrustes rotation R = U Vh (= U V^T)\n",
2133
- " R_opt = torch.bmm(U, Vh) # (B, D, D)\n",
2134
- " # Features: alignment quality + top singular values + rotation trace\n",
2135
- " align_quality = S.sum(dim=-1, keepdim=True)\n",
2136
- " top_s = S[:, :min(3, S.shape[1])]\n",
2137
- " rot_trace = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True) # det proxy\n",
2138
- " results.append(torch.cat([align_quality, top_s, rot_trace], dim=-1))\n",
2139
- " return torch.cat(results, dim=-1)\n",
2140
  "\n",
2141
  "front = ProcrustessFrontEnd(n_templates=8, patch_size=4).to(device)\n",
2142
  "model_4_5 = SpectralGeoLIPEncoder(\n",
 
2098
  "source": [
2099
  "# @title Experiment 4.5 \u2014 Procrustes Alignment\n",
2100
  "class ProcrustessFrontEnd(nn.Module):\n",
2101
+ " \"\"\"Procrustes alignment in color space \u2014 one 3\u00d73 orthogonal map per (image, template) pair.\n",
2102
+ " Cross-covariance of patch pixel colors vs template pixel colors gives\n",
2103
+ " a 3\u00d73 matrix. SVD of this 3\u00d73 via fused Triton kernel. R = U Vh is\n",
2104
+ " the optimal color rotation. All templates batched in one kernel call.\"\"\"\n",
2105
  " def __init__(self, n_templates=8, patch_size=4, input_size=32):\n",
2106
  " super().__init__()\n",
2107
  " self.patch_size = patch_size\n",
2108
  " self.n_templates = n_templates\n",
2109
  " n_patches = (input_size // patch_size) ** 2\n",
2110
  " self.n_patches = n_patches\n",
2111
+ " ps2 = patch_size * patch_size\n",
2112
+ " N = n_patches * ps2 # total pixels per image\n",
2113
+ " # Templates: (T, N, 3) \u2014 unit-norm color directions per pixel\n",
2114
+ " templates = F.normalize(torch.randn(n_templates, N, 3), dim=-1)\n",
2115
  " self.register_buffer('templates', templates)\n",
2116
+ " # Per template: align_quality(1) + S(3) + rot_trace(1) = 5\n",
2117
+ " self.output_dim = n_templates * 5\n",
2118
+ " backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n",
2119
+ " print(f\"[PROCRUSTES] {n_templates} templates, color-space SO(3), dim={self.output_dim} ({backend})\")\n",
2120
  "\n",
2121
  " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
2122
  " def forward(self, x):\n",
2123
  " B, C, H, W = x.shape\n",
2124
+ " T = self.n_templates\n",
2125
  " ps = self.patch_size\n",
2126
+ " # Reshape to (B, N, 3) \u2014 pixels with 3 color channels\n",
2127
+ " patches = x.unfold(2, ps, ps).unfold(3, ps, ps) # (B, C, nh, nw, ps, ps)\n",
2128
+ " patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous() # (B, nh, nw, ps, ps, C)\n",
2129
+ " patches = patches.reshape(B, -1, C) # (B, N, 3)\n",
2130
  " patches_n = F.normalize(patches, dim=-1)\n",
2131
+ " # Expand for all templates \u2014 one batched call, no Python loop\n",
2132
+ " patches_exp = patches_n.unsqueeze(1).expand(B, T, -1, -1).reshape(B * T, -1, 3)\n",
2133
+ " templates_exp = self.templates.unsqueeze(0).expand(B, -1, -1, -1).reshape(B * T, -1, 3)\n",
2134
+ " # Cross-covariance in color space: (B*T, 3, N) @ (B*T, N, 3) = (B*T, 3, 3)\n",
2135
+ " M = torch.bmm(patches_exp.transpose(1, 2), templates_exp)\n",
2136
+ " # SVD of 3\u00d73 via Triton kernel (or fallback)\n",
2137
+ " U, S, Vh = batched_svd3(M) # M is (B*T, 3, 3)\n",
2138
+ " R_opt = torch.bmm(U, Vh) # optimal orthogonal color map (O(3)); det may be -1, so not always a proper SO(3) rotation\n",
2139
+ " align_quality = S.sum(dim=-1, keepdim=True)\n",
2140
+ " rot_trace = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1, keepdim=True)\n",
2141
+ " feats = torch.cat([align_quality, S, rot_trace], dim=-1) # (B*T, 5)\n",
2142
+ " return feats.reshape(B, T * 5)\n",
 
 
 
 
 
2143
  "\n",
2144
  "front = ProcrustessFrontEnd(n_templates=8, patch_size=4).to(device)\n",
2145
  "model_4_5 = SpectralGeoLIPEncoder(\n",