AbstractPhil committed
Commit d9646ca · verified
1 Parent(s): 12d0dc7

Create vit_beatrix.py

Files changed (1)
  1. vit_beatrix.py +488 -0
vit_beatrix.py ADDED
@@ -0,0 +1,488 @@
+ """
+ Baseline Vision Transformer with Frozen Pentachora Embeddings
+ Adapted for L1-normalized pentachora vertices.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+ from typing import Optional, Dict
+
+ # NOTE: PentachoronStabilizer must be supplied by the surrounding runtime
+ # before any rose-score path is used (see the notice under __main__ at the
+ # bottom of this file); it is deliberately not imported here.
+
+
+ class PentachoraEmbedding(nn.Module):
+     """
+     A single frozen pentachoron embedding (5 vertices in geometric space).
+     Supports both L1- and L2-normalized vertices.
+     """
+
+     def __init__(self, vertices: torch.Tensor, norm_type: str = 'l1'):
+         super().__init__()
+
+         self.embed_dim = vertices.shape[-1]
+         self.norm_type = norm_type
+
+         # Store the provided vertices as a frozen buffer; detach so no
+         # gradients can flow back to the source tensor.
+         self.register_buffer('vertices', vertices.detach())
+
+         # Precompute normalized versions and the centroid
+         with torch.no_grad():
+             # For L1-normalized data, use the L1 norm for consistency
+             if norm_type == 'l1':
+                 # L1 normalize (sum of absolute values = 1)
+                 self.register_buffer(
+                     'vertices_norm',
+                     vertices / (vertices.abs().sum(dim=-1, keepdim=True) + 1e-8))
+             else:
+                 # L2 normalize (Euclidean norm = 1)
+                 self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
+
+             self.register_buffer('centroid', self.vertices.mean(dim=0))
+
+             # Centroid normalization matches the vertex normalization
+             if norm_type == 'l1':
+                 self.register_buffer(
+                     'centroid_norm',
+                     self.centroid / (self.centroid.abs().sum() + 1e-8))
+             else:
+                 self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
+
+     def get_vertices(self) -> torch.Tensor:
+         """Get all 5 vertices."""
+         return self.vertices
+
+     def get_centroid(self) -> torch.Tensor:
+         """Get the centroid of the pentachoron."""
+         return self.centroid
+
+     def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
+         """
+         Compute the Rose similarity score against this pentachoron,
+         scaled appropriately for the L1 norm.
+         """
+         verts = self.vertices.unsqueeze(0)  # [1, 5, D]
+         if features.dim() == 1:
+             features = features.unsqueeze(0)
+
+         B = features.shape[0]
+         if B > 1:
+             verts = verts.expand(B, -1, -1)
+
+         # Requires the externally provided PentachoronStabilizer
+         score = PentachoronStabilizer.rose_score_magnitude(features, verts)
+         if self.norm_type == 'l1':
+             # L1 normalization produces smaller values, so amplify the signal
+             score = score * 10.0
+         return score
+
+     def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
+         """
+         Compute the similarity between features and this pentachoron.
+         """
+         if mode == 'rose':
+             return self.compute_rose_score(features)
+
+         # Normalize features according to the norm type
+         if self.norm_type == 'l1':
+             features_norm = features / (features.abs().sum(dim=-1, keepdim=True) + 1e-8)
+         else:
+             features_norm = F.normalize(features, dim=-1)
+
+         if mode == 'centroid':
+             # Dot product with the centroid
+             sim = torch.sum(features_norm * self.centroid_norm, dim=-1)
+             # Scale up L1 similarities to be comparable to L2
+             if self.norm_type == 'l1':
+                 sim = sim * 10.0
+             return sim
+         else:  # mode == 'max'
+             # Max similarity across vertices
+             sims = torch.matmul(features_norm, self.vertices_norm.T)
+             if self.norm_type == 'l1':
+                 sims = sims * 10.0
+             return sims.max(dim=-1)[0]
+
+
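+ # Illustrative sketch (not part of the training setup): exercises the
+ # centroid- and max-mode similarities under the L1 law with random
+ # stand-in vertices; the 256-dim size and the function name are assumptions.
+ def _demo_pentachoron_similarity() -> None:
+     penta = PentachoraEmbedding(torch.randn(5, 256), norm_type='l1')
+     feats = torch.randn(8, 256)                              # [B, D]
+     sims = penta.compute_similarity(feats, mode='centroid')  # [B], x10-scaled
+     best = penta.compute_similarity(feats, mode='max')       # [B]
+     assert sims.shape == (8,) and best.shape == (8,)
+
+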
+ class TransformerBlock(nn.Module):
+     """Standard pre-norm transformer block with multi-head attention and MLP."""
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         mlp_ratio: float = 4.0,
+         dropout: float = 0.0,
+         attn_dropout: float = 0.0
+     ):
+         super().__init__()
+
+         self.norm1 = nn.LayerNorm(dim)
+         self.attn = nn.MultiheadAttention(
+             dim,
+             num_heads,
+             dropout=attn_dropout,
+             batch_first=True
+         )
+
+         self.norm2 = nn.LayerNorm(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = nn.Sequential(
+             nn.Linear(dim, mlp_hidden_dim),
+             nn.GELU(),
+             nn.Dropout(dropout),
+             nn.Linear(mlp_hidden_dim, dim),
+             nn.Dropout(dropout)
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Self-attention with residual
+         x_norm = self.norm1(x)
+         attn_out, _ = self.attn(x_norm, x_norm, x_norm)
+         x = x + attn_out
+
+         # MLP with residual
+         x = x + self.mlp(self.norm2(x))
+
+         return x
+
+
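+ # Illustrative sketch: the block is shape-preserving, [B, N, dim] in and
+ # out. Sizes and the function name are arbitrary assumptions.
+ def _demo_transformer_block() -> None:
+     block = TransformerBlock(dim=512, num_heads=8)
+     tokens = torch.randn(2, 65, 512)  # [B, 1 + num_patches, dim]
+     assert block(tokens).shape == tokens.shape
+
+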
+ class BaselineViT(nn.Module):
+     """
+     Vision Transformer with frozen pentachora embeddings.
+     - Preserves the L1 law for the pentachora geometry.
+     - Uses L2 angles for the RoseFace (ArcFace/CosFace/SphereFace) classification head.
+     """
+
+     def __init__(
+         self,
+         pentachora_list: list,             # list of torch.Tensor, each [5, vocab_dim]
+         vocab_dim: int = 256,
+         img_size: int = 32,
+         patch_size: int = 4,
+         embed_dim: int = 512,
+         depth: int = 12,
+         num_heads: int = 8,
+         mlp_ratio: float = 4.0,
+         dropout: float = 0.0,
+         attn_dropout: float = 0.0,
+         similarity_mode: str = 'rose',     # legacy similarity (kept for compatibility)
+         norm_type: str = 'l1',             # 'l1' or 'l2' normalization for the pentachora law
+         # --- New RoseFace config ---
+         head_type: str = 'roseface',       # 'roseface' | 'legacy'
+         prototype_mode: str = 'centroid',  # 'centroid' | 'rose5' | 'max_vertex'
+         margin_type: str = 'cosface',      # 'arcface' | 'cosface' | 'sphereface'
+         margin_m: float = 0.30,
+         scale_s: float = 30.0,
+         apply_margin_train_only: bool = False,
+     ):
+         super().__init__()
+
+         # Validate the pentachora list
+         assert isinstance(pentachora_list, list), f"Expected list, got {type(pentachora_list)}"
+         assert len(pentachora_list) > 0, "Empty pentachora list"
+         for i, penta in enumerate(pentachora_list):
+             assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
+
+         self.num_classes = len(pentachora_list)
+         self.embed_dim = embed_dim
+         self.num_patches = (img_size // patch_size) ** 2
+         self.similarity_mode = similarity_mode
+         self.pentachora_dim = vocab_dim
+         self.norm_type = norm_type
+
+         # --- RoseFace config ---
+         self.head_type = head_type
+         self.prototype_mode = prototype_mode
+         self.margin_type = margin_type
+         self.margin_m = float(margin_m)
+         self.scale_s = float(scale_s)
+         self.apply_margin_train_only = apply_margin_train_only
+
+         # Create one frozen pentachoron embedding per class
+         self.class_pentachora = nn.ModuleList([
+             PentachoraEmbedding(vertices=penta, norm_type=norm_type)
+             for penta in pentachora_list
+         ])
+
+         # Patch embedding
+         self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+         # Learnable CLS token
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+         # Position embeddings
+         self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
+         self.pos_drop = nn.Dropout(dropout)
+
+         # Transformer blocks
+         self.blocks = nn.ModuleList([
+             TransformerBlock(
+                 dim=embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 dropout=dropout,
+                 attn_dropout=attn_dropout
+             )
+             for _ in range(depth)
+         ])
+
+         # Final norm
+         self.norm = nn.LayerNorm(embed_dim)
+
+         # Project to the pentachora dimension if needed
+         if self.pentachora_dim != embed_dim:
+             self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
+         else:
+             self.to_pentachora_dim = nn.Identity()
+
+         # Legacy temperature (used only if head_type == 'legacy')
+         if norm_type == 'l1':
+             self.temperature = nn.Parameter(torch.zeros(1))  # exp(0) = 1
+         else:
+             self.temperature = nn.Parameter(torch.ones(1) * np.log(1 / 0.07))
+
+         # Precompute all centroids (buffers) for the legacy path
+         self.register_buffer(
+             'all_centroids',
+             torch.stack([penta.centroid for penta in self.class_pentachora])
+         )
+         if norm_type == 'l1':
+             centroids_normalized = self.all_centroids / (
+                 self.all_centroids.abs().sum(dim=-1, keepdim=True) + 1e-8)
+         else:
+             centroids_normalized = F.normalize(self.all_centroids, dim=-1)
+         self.register_buffer('all_centroids_norm', centroids_normalized)
+
+         # Face weights for rose5 prototypes: the 10 triangular faces of a
+         # pentachoron, each averaging a triad of vertices
+         face_triplets = torch.tensor([
+             [0, 1, 2], [0, 1, 3], [0, 1, 4],
+             [0, 2, 3], [0, 2, 4], [0, 3, 4],
+             [1, 2, 3], [1, 2, 4], [1, 3, 4],
+             [2, 3, 4]
+         ], dtype=torch.long)
+         face_weights = torch.zeros(10, 5, dtype=torch.float32)
+         for r, (i, j, k) in enumerate(face_triplets.tolist()):
+             face_weights[r, i] = face_weights[r, j] = face_weights[r, k] = 1.0 / 3.0
+         self.register_buffer('rose_face_weights', face_weights, persistent=False)
+
+         # Initialize weights
+         self.init_weights()
+
+         # Record config for checkpoint saving
+         self.config = getattr(self, 'config', {})
+         self.config.update({
+             'head_type': self.head_type,
+             'prototype_mode': self.prototype_mode,
+             'margin_type': self.margin_type,
+             'margin_m': self.margin_m,
+             'scale_s': self.scale_s,
+             'apply_margin_train_only': self.apply_margin_train_only,
+             'norm_type': self.norm_type,
+             'similarity_mode': self.similarity_mode,
+             'pentachora_dim': self.pentachora_dim,
+         })
+
+     def init_weights(self):
+         nn.init.trunc_normal_(self.cls_token, std=0.02)
+         nn.init.trunc_normal_(self.pos_embed, std=0.02)
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.trunc_normal_(m.weight, std=0.02)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+
+     # ---- Legacy helper (kept) ----
+     def get_class_centroids(self) -> torch.Tensor:
+         return self.all_centroids_norm
+
+     # ---- Legacy similarity (kept for compatibility & debugging) ----
+     def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
+         if self.similarity_mode == 'rose':
+             all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
+             features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
+             # [B*C, D] features against [B*C, 5, D] vertices -> [B, C] scores
+             scores = PentachoronStabilizer.rose_score_magnitude(
+                 features_exp.reshape(-1, self.pentachora_dim),
+                 all_vertices.repeat(features.shape[0], 1, 1)
+             ).reshape(features.shape[0], -1)
+             if self.norm_type == 'l1':
+                 scores = scores * 10.0
+             return scores
+         else:
+             if self.norm_type == 'l1':
+                 features_norm = features / (features.abs().sum(dim=-1, keepdim=True) + 1e-8)
+             else:
+                 features_norm = F.normalize(features, dim=-1)
+             centroids = self.get_class_centroids()
+             sims = torch.matmul(features_norm, centroids.T)
+             if self.norm_type == 'l1':
+                 sims = sims * 10.0
+             return sims
+
+     # ---- RoseFace utilities ----
+     @staticmethod
+     def _l2_norm(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+         return x / (x.norm(p=2, dim=-1, keepdim=True) + eps)
+
+     def _get_class_vertices_l2(self) -> torch.Tensor:
+         """[C, 5, D] L2-normalized vertices for all classes."""
+         V = torch.stack([p.vertices for p in self.class_pentachora], dim=0)
+         V = V.to(self.pos_embed.device, dtype=self.pos_embed.dtype)
+         return self._l2_norm(V)
+
+     def _get_prototypes(self, mode: Optional[str] = None) -> Optional[torch.Tensor]:
+         """
+         Prototypes [C, D] for 'centroid'/'rose5'; None for 'max_vertex',
+         which compares against every vertex instead of a single prototype.
+         """
+         mode = mode or self.prototype_mode
+         device = self.pos_embed.device
+         dtype = self.pos_embed.dtype
+
+         if mode == 'centroid':
+             C = torch.stack([p.centroid for p in self.class_pentachora], dim=0).to(device, dtype)
+             return self._l2_norm(C)
+
+         elif mode == 'rose5':
+             V_l2 = self._get_class_vertices_l2()                       # [C, 5, D]
+             W = self.rose_face_weights.to(device=device, dtype=dtype)  # [10, 5]
+             faces = torch.einsum('tf,cfd->ctd', W, V_l2)               # [C, 10, D]
+             verts_mean = V_l2.mean(dim=1)                              # [C, D]
+             faces_mean = faces.mean(dim=1)                             # [C, D]
+             alpha, beta = 1.0, 0.5
+             proto = alpha * verts_mean + beta * faces_mean
+             return self._l2_norm(proto)
+
+         elif mode == 'max_vertex':
+             return None
+
+         else:
+             raise ValueError(f"Unknown prototype_mode: {mode}")
+
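+     # Worked sketch of 'rose5': each class prototype blends the mean of its
+     # 5 L2-normalized vertices with the mean of its 10 triangular-face
+     # centers, then re-normalizes:
+     #     proto_c = l2_norm(1.0 * mean(V_c) + 0.5 * mean(faces_c))
+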
+     def _cosine_matrix(self, z_l2: torch.Tensor) -> torch.Tensor:
+         """
+         Pre-margin cosine similarities [B, C], according to prototype_mode.
+         """
+         if self.prototype_mode in ('centroid', 'rose5'):
+             P = self._get_prototypes(self.prototype_mode)     # [C, D]
+             return torch.matmul(z_l2, P.t())                  # [B, C]
+         elif self.prototype_mode == 'max_vertex':
+             V_l2 = self._get_class_vertices_l2()              # [C, 5, D]
+             cos_cv = torch.einsum('bd,cvd->bcv', z_l2, V_l2)  # [B, C, 5]
+             cos_max, _ = cos_cv.max(dim=2)                    # [B, C]
+             return cos_max
+         else:
+             raise ValueError(f"Unknown prototype_mode: {self.prototype_mode}")
+
+     @staticmethod
+     def _apply_margin(cosine: torch.Tensor, targets: torch.Tensor, m: float, kind: str = 'cosface') -> torch.Tensor:
+         """
+         Apply an angular margin to the target-class cosines.
+         Returns adjusted cosines [B, C].
+         """
+         eps = 1e-7
+         y = targets.view(-1, 1)  # [B, 1]
+
+         if kind == 'cosface':
+             # Additive cosine margin: cos(theta) - m on the target class
+             cos_m = cosine.clone()
+             cos_m.scatter_(1, y, cosine.gather(1, y) - m)
+             return cos_m
+
+         theta = torch.acos(torch.clamp(cosine.gather(1, y), -1.0 + eps, 1.0 - eps))  # [B, 1]
+         if kind == 'arcface':
+             # Additive angular margin: cos(theta + m)
+             cos_margin = torch.cos(theta + m)
+         elif kind == 'sphereface':
+             # Multiplicative angular margin: cos(m * theta)
+             cos_margin = torch.cos(m * theta)
+         else:
+             raise ValueError(f"Unknown margin type: {kind}")
+
+         cos_m = cosine.clone()
+         cos_m.scatter_(1, y, cos_margin)
+         return cos_m
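+
+     # Worked example: for a target cosine of 0.80 and m = 0.30,
+     #   cosface:    0.80 - 0.30                         = 0.50
+     #   arcface:    cos(acos(0.80) + 0.30) = cos(0.944) ~ 0.587
+     #   sphereface: cos(0.30 * acos(0.80)) = cos(0.193) ~ 0.981
+     # (SphereFace conventionally uses m > 1; m < 1 loosens the margin.)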
+
+     def schedule_roseface(
+         self, epoch: int, warmup_epochs: int = 15, s_start: float = 10.0, s_final: float = 30.0,
+         m_start: Optional[float] = None, m_final: Optional[float] = None
+     ):
+         """
+         Deterministic cosine ramp for the scale s (and optionally the margin m)
+         over the warmup epochs.
+         """
+         t = max(0.0, min(1.0, epoch / max(1, warmup_epochs)))
+         # Cosine ramp from s_start to s_final:
+         #     s(t) = s_final - 0.5 * (1 + cos(pi * t)) * (s_final - s_start)
+         self.scale_s = float(s_final - 0.5 * (1.0 + np.cos(np.pi * t)) * (s_final - s_start))
+         if (m_start is not None) and (m_final is not None):
+             self.margin_m = float(m_final - 0.5 * (1.0 + np.cos(np.pi * t)) * (m_final - m_start))
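+
+     # Worked values for the default ramp (s_start=10, s_final=30):
+     #     t = 0.0 -> s = 30 - 1.0 * 20 = 10.0
+     #     t = 0.5 -> s = 30 - 0.5 * 20 = 20.0
+     #     t = 1.0 -> s = 30 - 0.0 * 20 = 30.0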
+
+     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+         B = x.shape[0]
+         x = self.patch_embed(x)            # [B, embed_dim, H', W']
+         x = x.flatten(2).transpose(1, 2)   # [B, num_patches, embed_dim]
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat([cls_tokens, x], dim=1)
+         x = x + self.pos_embed
+         x = self.pos_drop(x)
+         for block in self.blocks:
+             x = block(x)
+         x = self.norm(x)
+         return x[:, 0]  # CLS token
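+
+     # Shape walk for the defaults (img_size=32, patch_size=4, embed_dim=512):
+     #     [B, 3, 32, 32] -> patch_embed -> [B, 512, 8, 8] -> [B, 64, 512]
+     #     + CLS token -> [B, 65, 512] -> blocks + norm -> CLS feature [B, 512]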
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         return_features: bool = False,
+         targets: Optional[torch.Tensor] = None  # NEW: required for the margin at train time
+     ) -> Dict[str, torch.Tensor]:
+
+         features = self.forward_features(x)
+         output: Dict[str, torch.Tensor] = {}
+
+         # Project to the pentachora dimension (the L1 law applies here)
+         features_proj = self.to_pentachora_dim(features)
+         if self.norm_type == 'l1':
+             features_proj = features_proj / (features_proj.abs().sum(dim=-1, keepdim=True) + 1e-8)
+
+         if self.head_type == 'roseface':
+             # L2 angles for the classification head (dual-norm bridge)
+             z_l2 = features_proj / (features_proj.norm(p=2, dim=-1, keepdim=True) + 1e-12)
+
+             # Pre-margin cosines [B, C]
+             cos_pre = self._cosine_matrix(z_l2)
+
+             # Apply the margin only when targets are available; skip it at
+             # eval time if apply_margin_train_only is set
+             if (targets is None) or (self.apply_margin_train_only and not self.training):
+                 cos_post = cos_pre
+             else:
+                 cos_post = self._apply_margin(cos_pre, targets, self.margin_m, self.margin_type)
+
+             # Scaled logits
+             logits = self.scale_s * cos_post
+
+             # Emit outputs
+             output['logits'] = logits         # for cross-entropy
+             output['similarities'] = cos_pre  # pre-margin (alignment / diagnostics)
+             if return_features:
+                 output['features'] = features
+                 output['features_proj'] = features_proj
+
+         else:
+             # Legacy path (kept for compatibility)
+             similarities = self.compute_pentachora_similarities(features_proj)
+             logits = similarities * self.temperature.exp()
+             output['logits'] = logits
+             output['similarities'] = similarities
+             if return_features:
+                 output['features'] = features
+                 output['features_proj'] = features_proj
+
+         return output
+
+
+ # Test - requires external setup
+ if __name__ == "__main__":
+     print("BaselineViT requires:")
+     print("  1. PentachoronStabilizer loaded externally")
+     print("  2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
+     print("\nNo random initialization. No fallbacks.")
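+
+     # Illustrative wiring sketch (commented out on purpose; the random
+     # pentachora below are stand-ins for the real external setup). With
+     # head_type='roseface' and prototype_mode='centroid', this path never
+     # touches PentachoronStabilizer:
+     #
+     #     pentachora = [torch.randn(5, 256) for _ in range(10)]
+     #     model = BaselineViT(pentachora, vocab_dim=256, img_size=32, patch_size=4)
+     #     x = torch.randn(4, 3, 32, 32)
+     #     y = torch.randint(0, 10, (4,))
+     #     out = model(x, targets=y)                 # margin applied (training mode)
+     #     loss = F.cross_entropy(out['logits'], y)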