AbstractPhil committed on
Commit
12d62d9
·
verified ·
1 Parent(s): bfa18f0

Create vit_zana_prediffusion.py

Browse files
Files changed (1) hide show
  1. vit_zana_prediffusion.py +316 -0
vit_zana_prediffusion.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Vision Transformer with Frozen Pentachora Embeddings
3
+ Clean architecture with geometric semantic anchors
4
+ Assumes PentachoronStabilizer is loaded externally
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from einops import rearrange
12
+ import math
13
+ from typing import Optional, Tuple, Dict, Any
14
+
15
+
16
class PentachoraEmbedding(nn.Module):
    """
    A single frozen pentachora embedding (5 vertices in geometric space).

    Accepts pre-computed vertices only — no random initialization. The
    vertices are stored as a non-trainable buffer, and unit-normalized
    vertices plus the (raw and normalized) centroid are precomputed once
    at construction.
    """

    def __init__(self, vertices: torch.Tensor):
        """
        Args:
            vertices: [5, D] tensor of pentachoron vertices. Any embedding
                dimension D is accepted (previously hard-coded to D=128);
                exactly 5 vertices are still required.
        """
        super().__init__()
        # Generalized: require exactly 5 vertices, allow any embedding dim.
        assert vertices.dim() == 2 and vertices.shape[0] == 5, \
            f"Expected shape (5, D), got {tuple(vertices.shape)}"

        self.embed_dim = vertices.shape[-1]

        # detach() freezes the geometry even if the caller passes a tensor
        # that tracks gradients. (The original set .requires_grad = False on
        # the buffer, which raises a RuntimeError for non-leaf tensors;
        # buffers are excluded from optimization anyway.)
        self.register_buffer('vertices', vertices.detach())

        # Precompute normalized versions and centroid once.
        with torch.no_grad():
            self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
            self.register_buffer('centroid', self.vertices.mean(dim=0))
            self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))

    def get_vertices(self) -> torch.Tensor:
        """Return all 5 vertices, shape [5, D]."""
        return self.vertices

    def get_centroid(self) -> torch.Tensor:
        """Return the centroid of the pentachora, shape [D]."""
        return self.centroid

    def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute the Rose similarity score against this pentachora.

        Delegates to ``PentachoronStabilizer.rose_score_magnitude``, which
        must be loaded externally (see module docstring); calling this
        without it raises NameError.
        """
        verts = self.vertices.unsqueeze(0)  # [1, 5, D]
        if features.dim() == 1:
            features = features.unsqueeze(0)

        # expand() broadcasts to the batch size as a view — no copy is made.
        B = features.shape[0]
        if B > 1:
            verts = verts.expand(B, -1, -1)

        return PentachoronStabilizer.rose_score_magnitude(features, verts)

    def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
        """
        Compute similarity between features and this pentachora.

        Args:
            features: [batch, dim] or [batch, seq, dim]
            mode: 'centroid' (cosine vs. the centroid), 'max' (max cosine
                over the 5 vertices), or 'rose' (Rose score via the external
                PentachoronStabilizer)

        Returns:
            similarities: [batch] or [batch, seq]
        """
        if mode == 'rose':
            return self.compute_rose_score(features)

        features_norm = F.normalize(features, dim=-1)

        if mode == 'centroid':
            # Cosine similarity with the unit-normalized centroid.
            return torch.matmul(features_norm, self.centroid_norm)
        # mode == 'max': best cosine similarity across the 5 vertices.
        sims = torch.matmul(features_norm, self.vertices_norm.T)
        return sims.max(dim=-1)[0]
86
+
87
+
88
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an MLP.

    Both sub-layers use residual connections with LayerNorm applied before
    (not after) each sub-layer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0
    ):
        super().__init__()

        # Attention sub-layer (batch_first so inputs are [B, S, D]).
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=attn_dropout, batch_first=True
        )

        # Feed-forward sub-layer: expand by mlp_ratio, GELU, project back.
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP sub-layers with residual connections."""
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]
        return x + self.mlp(self.norm2(x))
129
+
130
+
131
class BaselineViT(nn.Module):
    """
    Clean baseline Vision Transformer with frozen pentachora embeddings.

    Classifies by comparing the projected CLS feature against one frozen
    pentachora per class (centroid cosine, max-vertex cosine, or Rose score)
    and scaling the similarities by a learned temperature.
    """

    def __init__(
        self,
        pentachora_list: list,  # List of torch.Tensor, each [5, 128]
        img_size: int = 32,
        patch_size: int = 4,
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        similarity_mode: str = 'rose'  # 'centroid', 'max', or 'rose'
    ):
        super().__init__()

        # Validate pentachora list
        assert isinstance(pentachora_list, list), f"Expected list, got {type(pentachora_list)}"
        assert len(pentachora_list) > 0, "Empty pentachora list"
        # Fail fast on an unknown mode instead of silently misclassifying later.
        assert similarity_mode in ('centroid', 'max', 'rose'), \
            f"Unknown similarity_mode: {similarity_mode!r}"

        # Validate each pentachora (vocab vertices are fixed at [5, 128]).
        for i, penta in enumerate(pentachora_list):
            assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
            assert penta.shape == (5, 128), f"Item {i} has shape {penta.shape}, expected (5, 128)"

        self.num_classes = len(pentachora_list)
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.similarity_mode = similarity_mode
        self.pentachora_dim = 128  # Always 128 from vocab

        # One frozen PentachoraEmbedding per class.
        self.class_pentachora = nn.ModuleList([
            PentachoraEmbedding(vertices=penta)
            for penta in pentachora_list
        ])

        # Patch embedding: non-overlapping conv, one token per patch.
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        # CLS token - learnable
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings (CLS token + patches).
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout
            )
            for _ in range(depth)
        ])

        # Final norm
        self.norm = nn.LayerNorm(embed_dim)

        # Project to pentachora dimension if needed
        if self.pentachora_dim != embed_dim:
            self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
        else:
            self.to_pentachora_dim = nn.Identity()

        # Learned log-temperature for similarity-based classification
        # (CLIP-style init at log(1/0.07)).
        self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))

        # Initialize weights
        self.init_weights()

    def init_weights(self):
        """Initialize model weights (truncated-normal linears, unit LayerNorm)."""
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def get_class_centroids(self) -> torch.Tensor:
        """Return unit-normalized class centroids, shape [num_classes, 128]."""
        centroids = torch.stack([
            penta.get_centroid() for penta in self.class_pentachora
        ])
        return F.normalize(centroids, dim=-1)

    def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute similarities between features and all class pentachora,
        using the model's configured similarity_mode per class.

        Args:
            features: [batch, dim] features to compare

        Returns:
            similarities: [batch, num_classes]
        """
        similarities = []
        for penta in self.class_pentachora:
            sim = penta.compute_similarity(features, mode=self.similarity_mode)
            similarities.append(sim)

        return torch.stack(similarities, dim=-1)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Extract CLS features from images [B, 3, H, W] -> [B, embed_dim]."""
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # [B, embed_dim, H', W']
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]

        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final norm
        x = self.norm(x)

        # Return CLS token
        return x[:, 0]

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Returns dict with:
            - logits: temperature-scaled logits [B, num_classes]
            - similarities: raw similarities to pentachora [B, num_classes]
            - features: CLS features [B, embed_dim] (if return_features=True)
        """
        features = self.forward_features(x)

        output = {}

        # Project to pentachora dimension
        features_proj = self.to_pentachora_dim(features)

        if self.similarity_mode == 'centroid':
            # Fast batched path: one matmul against all class centroids.
            features_norm = F.normalize(features_proj, dim=-1)
            centroids = self.get_class_centroids()
            similarities = torch.matmul(features_norm, centroids.T)
        else:
            # BUGFIX: 'max' previously fell through to the centroid matmul
            # and was silently ignored. Both 'max' and 'rose' require the
            # per-pentachora computation.
            similarities = self.compute_pentachora_similarities(features_proj)

        # Scale by temperature
        logits = similarities * self.temperature.exp()

        output['logits'] = logits
        output['similarities'] = similarities

        if return_features:
            output['features'] = features

        return output
309
+
310
+
311
+ # Test - requires external setup
312
+ if __name__ == "__main__":
313
+ print("BaselineViT requires:")
314
+ print(" 1. PentachoronStabilizer loaded externally")
315
+ print(" 2. pentachora_batch tensor [num_classes, 5, 128]")
316
+ print("\nNo random initialization. No fallbacks.")