AbstractPhil commited on
Commit
3678161
·
verified ·
1 Parent(s): c3d8b53

Create vit_zana_v4_l1.py

Browse files
Files changed (1) hide show
  1. vit_zana_v4_l1.py +368 -0
vit_zana_v4_l1.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Vision Transformer with Frozen Pentachora Embeddings
3
+ Adapted for L1-normalized pentachora vertices
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from einops import rearrange
11
+ import math
12
+ from typing import Optional, Tuple, Dict, Any
13
+
14
+
15
+ class PentachoraEmbedding(nn.Module):
16
+ """
17
+ A single frozen pentachora embedding (5 vertices in geometric space).
18
+ Supports both L1 and L2 normalized vertices.
19
+ """
20
+
21
+ def __init__(self, vertices: torch.Tensor, norm_type: str = 'l1'):
22
+ super().__init__()
23
+
24
+ self.embed_dim = vertices.shape[-1]
25
+ self.norm_type = norm_type
26
+
27
+ # Store provided vertices as frozen buffer
28
+ self.register_buffer('vertices', vertices)
29
+ self.vertices.requires_grad = False
30
+
31
+ # Precompute normalized versions and centroid
32
+ with torch.no_grad():
33
+ # For L1-normalized data, use L1 norm for consistency
34
+ if norm_type == 'l1':
35
+ # L1 normalize (sum of abs values = 1)
36
+ self.register_buffer('vertices_norm',
37
+ vertices / (vertices.abs().sum(dim=-1, keepdim=True) + 1e-8))
38
+ else:
39
+ # L2 normalize (euclidean norm = 1)
40
+ self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
41
+
42
+ self.register_buffer('centroid', self.vertices.mean(dim=0))
43
+
44
+ # Centroid normalization matches vertex normalization
45
+ if norm_type == 'l1':
46
+ self.register_buffer('centroid_norm',
47
+ self.centroid / (self.centroid.abs().sum() + 1e-8))
48
+ else:
49
+ self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
50
+
51
+ def get_vertices(self) -> torch.Tensor:
52
+ """Get all 5 vertices."""
53
+ return self.vertices
54
+
55
+ def get_centroid(self) -> torch.Tensor:
56
+ """Get the centroid of the pentachora."""
57
+ return self.centroid
58
+
59
+ def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
60
+ """
61
+ Compute Rose similarity score with this pentachora.
62
+ Scaled appropriately for L1 norm.
63
+ """
64
+ verts = self.vertices.unsqueeze(0) # [1, 5, D]
65
+ if features.dim() == 1:
66
+ features = features.unsqueeze(0)
67
+
68
+ B = features.shape[0]
69
+ if B > 1:
70
+ verts = verts.expand(B, -1, -1)
71
+
72
+ # For L1 norm, scale the rose score appropriately
73
+ score = PentachoronStabilizer.rose_score_magnitude(features, verts)
74
+ if self.norm_type == 'l1':
75
+ # L1 norm produces smaller values, so amplify the signal
76
+ score = score * 10.0
77
+ return score
78
+
79
+ def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
80
+ """
81
+ Compute similarity between features and this pentachora.
82
+ """
83
+ if mode == 'rose':
84
+ return self.compute_rose_score(features)
85
+
86
+ # Normalize features according to norm type
87
+ if self.norm_type == 'l1':
88
+ features_norm = features / (features.abs().sum(dim=-1, keepdim=True) + 1e-8)
89
+ else:
90
+ features_norm = F.normalize(features, dim=-1)
91
+
92
+ if mode == 'centroid':
93
+ # Dot product with centroid
94
+ sim = torch.sum(features_norm * self.centroid_norm, dim=-1)
95
+ # Scale up L1 similarities to be comparable to L2
96
+ if self.norm_type == 'l1':
97
+ sim = sim * 10.0
98
+ return sim
99
+ else: # mode == 'max'
100
+ # Max similarity across vertices
101
+ sims = torch.matmul(features_norm, self.vertices_norm.T)
102
+ if self.norm_type == 'l1':
103
+ sims = sims * 10.0
104
+ return sims.max(dim=-1)[0]
105
+
106
+
107
+ class TransformerBlock(nn.Module):
108
+ """Standard transformer block with multi-head attention and MLP."""
109
+
110
+ def __init__(
111
+ self,
112
+ dim: int,
113
+ num_heads: int = 8,
114
+ mlp_ratio: float = 4.0,
115
+ dropout: float = 0.0,
116
+ attn_dropout: float = 0.0
117
+ ):
118
+ super().__init__()
119
+
120
+ self.norm1 = nn.LayerNorm(dim)
121
+ self.attn = nn.MultiheadAttention(
122
+ dim,
123
+ num_heads,
124
+ dropout=attn_dropout,
125
+ batch_first=True
126
+ )
127
+
128
+ self.norm2 = nn.LayerNorm(dim)
129
+ mlp_hidden_dim = int(dim * mlp_ratio)
130
+ self.mlp = nn.Sequential(
131
+ nn.Linear(dim, mlp_hidden_dim),
132
+ nn.GELU(),
133
+ nn.Dropout(dropout),
134
+ nn.Linear(mlp_hidden_dim, dim),
135
+ nn.Dropout(dropout)
136
+ )
137
+
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ # Self-attention
140
+ x_norm = self.norm1(x)
141
+ attn_out, _ = self.attn(x_norm, x_norm, x_norm)
142
+ x = x + attn_out
143
+
144
+ # MLP
145
+ x = x + self.mlp(self.norm2(x))
146
+
147
+ return x
148
+
149
+
150
+ class BaselineViT(nn.Module):
151
+ """
152
+ Vision Transformer with frozen pentachora embeddings.
153
+ Supports L1-normalized pentachora.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ pentachora_list: list, # List of torch.Tensor, each [5, vocab_dim]
159
+ vocab_dim: int = 256,
160
+ img_size: int = 32,
161
+ patch_size: int = 4,
162
+ embed_dim: int = 512,
163
+ depth: int = 12,
164
+ num_heads: int = 8,
165
+ mlp_ratio: float = 4.0,
166
+ dropout: float = 0.0,
167
+ attn_dropout: float = 0.0,
168
+ similarity_mode: str = 'rose', # 'centroid', 'max', or 'rose'
169
+ norm_type: str = 'l1' # 'l1' or 'l2' normalization
170
+ ):
171
+ super().__init__()
172
+
173
+ # Validate pentachora list
174
+ assert isinstance(pentachora_list, list), f"Expected list, got {type(pentachora_list)}"
175
+ assert len(pentachora_list) > 0, "Empty pentachora list"
176
+
177
+ for i, penta in enumerate(pentachora_list):
178
+ assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
179
+
180
+ self.num_classes = len(pentachora_list)
181
+ self.embed_dim = embed_dim
182
+ self.num_patches = (img_size // patch_size) ** 2
183
+ self.similarity_mode = similarity_mode
184
+ self.pentachora_dim = vocab_dim
185
+ self.norm_type = norm_type
186
+
187
+ # Create individual pentachora embeddings from list
188
+ self.class_pentachora = nn.ModuleList([
189
+ PentachoraEmbedding(vertices=penta, norm_type=norm_type)
190
+ for penta in pentachora_list
191
+ ])
192
+
193
+ # Patch embedding
194
+ self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
195
+
196
+ # CLS token - learnable
197
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
198
+
199
+ # Position embeddings
200
+ self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
201
+ self.pos_drop = nn.Dropout(dropout)
202
+
203
+ # Transformer blocks
204
+ self.blocks = nn.ModuleList([
205
+ TransformerBlock(
206
+ dim=embed_dim,
207
+ num_heads=num_heads,
208
+ mlp_ratio=mlp_ratio,
209
+ dropout=dropout,
210
+ attn_dropout=attn_dropout
211
+ )
212
+ for i in range(depth)
213
+ ])
214
+
215
+ # Final norm
216
+ self.norm = nn.LayerNorm(embed_dim)
217
+
218
+ # Project to pentachora dimension if needed
219
+ if self.pentachora_dim != embed_dim:
220
+ self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
221
+ else:
222
+ self.to_pentachora_dim = nn.Identity()
223
+
224
+ # Temperature for similarity-based classification
225
+ # For L1 norm, start with lower temperature since similarities are scaled
226
+ if norm_type == 'l1':
227
+ self.temperature = nn.Parameter(torch.zeros(1)) # exp(0) = 1
228
+ else:
229
+ self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
230
+
231
+ # Precompute all centroids for efficiency
232
+ self.register_buffer(
233
+ 'all_centroids',
234
+ torch.stack([penta.centroid for penta in self.class_pentachora])
235
+ )
236
+
237
+ # Normalize centroids according to norm type
238
+ if norm_type == 'l1':
239
+ centroids_normalized = self.all_centroids / (
240
+ self.all_centroids.abs().sum(dim=-1, keepdim=True) + 1e-8)
241
+ else:
242
+ centroids_normalized = F.normalize(self.all_centroids, dim=-1)
243
+
244
+ self.register_buffer('all_centroids_norm', centroids_normalized)
245
+
246
+ # Initialize weights
247
+ self.init_weights()
248
+
249
+ def init_weights(self):
250
+ """Initialize model weights."""
251
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
252
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
253
+
254
+ for m in self.modules():
255
+ if isinstance(m, nn.Linear):
256
+ nn.init.trunc_normal_(m.weight, std=0.02)
257
+ if m.bias is not None:
258
+ nn.init.zeros_(m.bias)
259
+ elif isinstance(m, nn.LayerNorm):
260
+ nn.init.ones_(m.weight)
261
+ nn.init.zeros_(m.bias)
262
+
263
+ def get_class_centroids(self) -> torch.Tensor:
264
+ return self.all_centroids_norm
265
+
266
+ def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
267
+ """
268
+ Compute similarities between features and all class pentachora.
269
+ Properly scaled for L1 or L2 norm.
270
+ """
271
+ if self.similarity_mode == 'rose':
272
+ # Stack all vertices into single tensor for batch Rose scoring
273
+ all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
274
+ features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
275
+ scores = PentachoronStabilizer.rose_score_magnitude(
276
+ features_exp.reshape(-1, self.pentachora_dim),
277
+ all_vertices.repeat(features.shape[0], 1, 1)
278
+ ).reshape(features.shape[0], -1)
279
+
280
+ # Scale for L1 norm
281
+ if self.norm_type == 'l1':
282
+ scores = scores * 10.0
283
+ return scores
284
+ else:
285
+ # Normalize features according to norm type
286
+ if self.norm_type == 'l1':
287
+ features_norm = features / (features.abs().sum(dim=-1, keepdim=True) + 1e-8)
288
+ else:
289
+ features_norm = F.normalize(features, dim=-1)
290
+
291
+ centroids = self.get_class_centroids()
292
+ sims = torch.matmul(features_norm, centroids.T)
293
+
294
+ # Scale for L1 norm
295
+ if self.norm_type == 'l1':
296
+ sims = sims * 10.0
297
+ return sims
298
+
299
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
300
+ """Extract features from images."""
301
+ B = x.shape[0]
302
+
303
+ # Patch embedding
304
+ x = self.patch_embed(x) # [B, embed_dim, H', W']
305
+ x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
306
+
307
+ # Add CLS token
308
+ cls_tokens = self.cls_token.expand(B, -1, -1)
309
+ x = torch.cat([cls_tokens, x], dim=1)
310
+
311
+ # Add position embeddings
312
+ x = x + self.pos_embed
313
+ x = self.pos_drop(x)
314
+
315
+ # Apply transformer blocks
316
+ for block in self.blocks:
317
+ x = block(x)
318
+
319
+ # Final norm
320
+ x = self.norm(x)
321
+
322
+ # Return CLS token
323
+ return x[:, 0]
324
+
325
+ def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
326
+ """
327
+ Forward pass.
328
+
329
+ Returns dict with:
330
+ - logits: classification logits
331
+ - features: CLS features (if return_features=True)
332
+ - features_proj: projected features in pentachora space
333
+ - similarities: raw similarities to pentachora
334
+ """
335
+ features = self.forward_features(x)
336
+
337
+ output = {}
338
+
339
+ # Project to pentachora dimension
340
+ features_proj = self.to_pentachora_dim(features)
341
+
342
+ # Apply appropriate normalization for projected features
343
+ if self.norm_type == 'l1':
344
+ # L1 normalize the projected features
345
+ features_proj = features_proj / (features_proj.abs().sum(dim=-1, keepdim=True) + 1e-8)
346
+
347
+ # Compute similarities
348
+ similarities = self.compute_pentachora_similarities(features_proj)
349
+
350
+ # Scale by temperature
351
+ logits = similarities * self.temperature.exp()
352
+
353
+ output['logits'] = logits
354
+ output['similarities'] = similarities
355
+
356
+ if return_features:
357
+ output['features'] = features # Original transformer features
358
+ output['features_proj'] = features_proj # Projected features
359
+
360
+ return output
361
+
362
+
363
+ # Test - requires external setup
364
+ if __name__ == "__main__":
365
+ print("BaselineViT requires:")
366
+ print(" 1. PentachoronStabilizer loaded externally")
367
+ print(" 2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
368
+ print("\nNo random initialization. No fallbacks.")