AbstractPhil committed on
Commit
f3307b8
·
verified ·
1 Parent(s): 17c850b

Reverted the experimental theta head.

Browse files
Files changed (1) hide show
  1. vit_zana_v3.py +101 -214
vit_zana_v3.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Baseline Vision Transformer with Frozen Pentachora Embeddings
3
- Now with optional theta rotation head for better classification
 
4
  """
5
 
6
  import torch
@@ -15,16 +16,17 @@ from typing import Optional, Tuple, Dict, Any
15
  class PentachoraEmbedding(nn.Module):
16
  """
17
  A single frozen pentachora embedding (5 vertices in geometric space).
18
- Now with theta rotation capabilities.
19
  """
20
 
21
  def __init__(self, vertices: torch.Tensor):
22
  super().__init__()
 
23
 
24
  self.embed_dim = vertices.shape[-1]
25
 
26
  # Store provided vertices as frozen buffer
27
- self.register_buffer('vertices', vertices.cpu().contiguous().detach().clone().to(get_default_device()))
28
  self.vertices.requires_grad = False
29
 
30
  # Precompute normalized versions and centroid
@@ -32,27 +34,26 @@ class PentachoraEmbedding(nn.Module):
32
  self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
33
  self.register_buffer('centroid', self.vertices.mean(dim=0))
34
  self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
35
-
36
- # Compute theta bases for rotation
37
- self.register_buffer('theta_bases', self._compute_theta_bases().cpu().contiguous().detach().clone().to(get_default_device()))
38
-
39
- def _compute_theta_bases(self) -> torch.Tensor:
40
- """Compute orthogonal bases from vertices for theta rotation."""
41
- U, S, V = torch.svd(self.vertices)
42
- n_components = min(5, self.embed_dim)
43
- return V[:, :n_components] # [embed_dim, n_components]
44
 
45
  def get_vertices(self) -> torch.Tensor:
 
46
  return self.vertices
47
 
48
  def get_centroid(self) -> torch.Tensor:
 
49
  return self.centroid
50
 
51
  def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
52
- verts = self.vertices.unsqueeze(0)
 
 
 
 
 
53
  if features.dim() == 1:
54
  features = features.unsqueeze(0)
55
 
 
56
  B = features.shape[0]
57
  if B > 1:
58
  verts = verts.expand(B, -1, -1)
@@ -60,110 +61,28 @@ class PentachoraEmbedding(nn.Module):
60
  return PentachoronStabilizer.rose_score_magnitude(features, verts)
61
 
62
  def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
63
  if mode == 'rose':
64
  return self.compute_rose_score(features)
65
 
66
  features_norm = F.normalize(features, dim=-1)
67
 
68
  if mode == 'centroid':
 
69
  return torch.matmul(features_norm, self.centroid_norm)
70
  else: # mode == 'max'
 
71
  sims = torch.matmul(features_norm, self.vertices_norm.T)
72
  return sims.max(dim=-1)[0]
73
-
74
- def compute_theta_features(self, features: torch.Tensor) -> torch.Tensor:
75
- """
76
- Project features to theta space defined by this pentachora.
77
- Returns angular features for feedforward classification.
78
- """
79
- # Project onto pentachora bases
80
- projections = torch.matmul(features, self.theta_bases) # [batch, 5]
81
-
82
- # Compute angles relative to centroid
83
- centroid_proj = torch.matmul(self.centroid.unsqueeze(0), self.theta_bases)
84
- angles = torch.atan2(projections, centroid_proj + 1e-8)
85
-
86
- # Return sin/cos encoding
87
- return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).to(get_default_device()) # [batch, 10]
88
-
89
-
90
- class ThetaHead(nn.Module):
91
- """
92
- Theta-based classification head using angular representations.
93
- Replaces similarity matching with learned feedforward.
94
- """
95
-
96
- def __init__(
97
- self,
98
- embed_dim: int,
99
- num_classes: int,
100
- n_pentachora: int = 10, # Use subset of pentachora for theta
101
- hidden_dim: int = 256,
102
- dropout: float = 0.1
103
- ):
104
- super().__init__()
105
-
106
- self.n_pentachora = n_pentachora
107
- self.embed_dim = embed_dim
108
-
109
- # Each pentachora gives 10 theta features (5 sin + 5 cos)
110
- theta_dim = n_pentachora * 10
111
-
112
- # Project to theta space
113
- self.to_theta = nn.Sequential(
114
- nn.Linear(embed_dim, hidden_dim),
115
- nn.LayerNorm(hidden_dim),
116
- nn.GELU(),
117
- nn.Dropout(dropout),
118
- nn.Linear(hidden_dim, theta_dim)
119
- )
120
-
121
- # Classify from theta
122
- self.classifier = nn.Sequential(
123
- nn.LayerNorm(theta_dim),
124
- nn.Dropout(dropout),
125
- nn.Linear(theta_dim, num_classes)
126
- )
127
-
128
- # Learnable temperature
129
- self.temperature = nn.Parameter(torch.ones(1) * 0.1)
130
-
131
- def forward(self, features: torch.Tensor, pentachora_list: nn.ModuleList) -> Dict[str, torch.Tensor]:
132
- """
133
- Classify using theta rotation.
134
-
135
- Args:
136
- features: [batch, embed_dim] CLS features
137
- pentachora_list: List of PentachoraEmbedding modules
138
- """
139
- # Get theta features from first n pentachora
140
- theta_features = []
141
- for i in range(min(self.n_pentachora, len(pentachora_list))):
142
- theta = pentachora_list[i].compute_theta_features(features)
143
- theta_features.append(theta)
144
-
145
- # Concatenate all theta features
146
- theta_concat = torch.cat(theta_features, dim=-1) # [batch, n_pentachora * 10]
147
-
148
- # If we have fewer pentachora than expected, pad with zeros
149
- if len(theta_features) < self.n_pentachora:
150
- pad_size = (self.n_pentachora - len(theta_features)) * 10
151
- padding = torch.zeros(features.shape[0], pad_size, device=features.device)
152
- theta_concat = torch.cat([theta_concat, padding], dim=-1)
153
-
154
- # Project through MLP
155
- theta_proj = self.to_theta(features)
156
-
157
- # Combine with geometric theta (residual connection)
158
- theta_combined = theta_concat + 0.1 * theta_proj
159
-
160
- # Classify
161
- logits = self.classifier(theta_combined) / self.temperature.exp()
162
-
163
- return {
164
- 'logits': logits,
165
- 'theta_features': theta_combined
166
- }
167
 
168
 
169
  class TransformerBlock(nn.Module):
@@ -198,22 +117,25 @@ class TransformerBlock(nn.Module):
198
  )
199
 
200
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
201
  x_norm = self.norm1(x)
202
  attn_out, _ = self.attn(x_norm, x_norm, x_norm)
203
  x = x + attn_out
 
 
204
  x = x + self.mlp(self.norm2(x))
 
205
  return x
206
 
207
 
208
  class BaselineViT(nn.Module):
209
  """
210
- Vision Transformer with optional theta-based classification.
211
- Can switch between similarity-based and theta-based heads.
212
  """
213
 
214
  def __init__(
215
  self,
216
- pentachora_list: list,
217
  vocab_dim: int = 256,
218
  img_size: int = 32,
219
  patch_size: int = 4,
@@ -223,23 +145,25 @@ class BaselineViT(nn.Module):
223
  mlp_ratio: float = 4.0,
224
  dropout: float = 0.0,
225
  attn_dropout: float = 0.0,
226
- similarity_mode: str = 'rose',
227
- use_theta_head: bool = True, # NEW: Toggle theta head
228
- theta_n_pentachora: int = 2, # NEW: How many pentachora for theta
229
- theta_hidden_dim: int = 256 # NEW: Hidden dim for theta MLP
230
  ):
231
  super().__init__()
232
 
233
- assert isinstance(pentachora_list, list) and len(pentachora_list) > 0
 
 
 
 
 
 
234
 
235
  self.num_classes = len(pentachora_list)
236
  self.embed_dim = embed_dim
237
  self.num_patches = (img_size // patch_size) ** 2
238
  self.similarity_mode = similarity_mode
239
  self.pentachora_dim = vocab_dim
240
- self.use_theta_head = use_theta_head
241
 
242
- # Create pentachora embeddings
243
  self.class_pentachora = nn.ModuleList([
244
  PentachoraEmbedding(vertices=penta)
245
  for penta in pentachora_list
@@ -248,7 +172,7 @@ class BaselineViT(nn.Module):
248
  # Patch embedding
249
  self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
250
 
251
- # CLS token
252
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
253
 
254
  # Position embeddings
@@ -276,33 +200,23 @@ class BaselineViT(nn.Module):
276
  else:
277
  self.to_pentachora_dim = nn.Identity()
278
 
279
- # Classification heads
280
- if use_theta_head:
281
- # NEW: Theta-based classification
282
- self.theta_head = ThetaHead(
283
- embed_dim=self.pentachora_dim,
284
- num_classes=self.num_classes,
285
- n_pentachora=theta_n_pentachora,
286
- hidden_dim=theta_hidden_dim,
287
- dropout=dropout
288
- )
289
- else:
290
- # Original: Similarity-based classification
291
- self.theta_head = None
292
- self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
293
-
294
- self.register_buffer(
295
- 'all_centroids',
296
- torch.stack([penta.centroid for penta in self.class_pentachora])
297
- )
298
- self.register_buffer(
299
- 'all_centroids_norm',
300
- F.normalize(self.all_centroids, dim=-1)
301
- )
302
 
 
 
 
 
 
 
 
 
 
 
303
  self.init_weights()
304
 
305
  def init_weights(self):
 
306
  nn.init.trunc_normal_(self.cls_token, std=0.02)
307
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
308
 
@@ -315,33 +229,35 @@ class BaselineViT(nn.Module):
315
  nn.init.ones_(m.weight)
316
  nn.init.zeros_(m.bias)
317
 
 
318
  def get_class_centroids(self) -> torch.Tensor:
319
- if self.use_theta_head:
320
- # Return centroids from pentachora for compatibility
321
- centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
322
- return centroids
323
- else:
324
- return self.all_centroids_norm
325
-
326
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
 
 
 
327
  if self.similarity_mode == 'rose':
328
- all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
329
- features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
330
- return PentachoronStabilizer.rose_score_magnitude(
331
- features_exp.reshape(-1, self.pentachora_dim),
332
- all_vertices.repeat(features.shape[0], 1, 1)
333
- ).reshape(features.shape[0], -1)
334
  else:
335
- centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
336
- features_norm = F.normalize(features, dim=-1)
337
- return torch.matmul(features_norm, centroids.T)
 
 
338
 
339
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
 
340
  B = x.shape[0]
341
 
342
  # Patch embedding
343
- x = self.patch_embed(x)
344
- x = x.flatten(2).transpose(1, 2)
345
 
346
  # Add CLS token
347
  cls_tokens = self.cls_token.expand(B, -1, -1)
@@ -363,31 +279,35 @@ class BaselineViT(nn.Module):
363
 
364
  def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
365
  """
366
- Forward pass with optional theta head.
 
 
 
 
 
367
  """
368
  features = self.forward_features(x)
 
369
  output = {}
370
 
371
  # Project to pentachora dimension
372
  features_proj = self.to_pentachora_dim(features)
373
 
374
- if self.use_theta_head:
375
- # NEW: Use theta-based classification
376
- theta_output = self.theta_head(features_proj, self.class_pentachora)
377
- output['logits'] = theta_output['logits']
378
- output['theta_features'] = theta_output['theta_features']
379
-
380
- # Still compute similarities for analysis
381
- with torch.no_grad():
382
- similarities = self.compute_pentachora_similarities(features_proj)
383
- output['similarities'] = similarities
384
- else:
385
- # Original: Use similarity-based classification
386
  similarities = self.compute_pentachora_similarities(features_proj)
387
- logits = similarities * self.temperature.exp()
388
-
389
- output['logits'] = logits
390
- output['similarities'] = similarities
 
 
 
 
 
 
 
391
 
392
  if return_features:
393
  output['features'] = features
@@ -395,42 +315,9 @@ class BaselineViT(nn.Module):
395
  return output
396
 
397
 
398
- # Helper function to convert existing model to theta
399
- def enable_theta_head(model: BaselineViT, n_pentachora: int = 10, hidden_dim: int = 256):
400
- """
401
- Convert an existing similarity-based model to use theta head.
402
- This modifies the model in-place.
403
- """
404
- if model.use_theta_head:
405
- print("Model already using theta head")
406
- return model
407
-
408
- print(f"Converting to theta head with {n_pentachora} pentachora...")
409
-
410
- # Create theta head
411
- model.theta_head = ThetaHead(
412
- embed_dim=model.pentachora_dim,
413
- num_classes=model.num_classes,
414
- n_pentachora=n_pentachora,
415
- hidden_dim=hidden_dim,
416
- dropout=0.1
417
- ).to(next(model.parameters()).device)
418
-
419
- # Set flag
420
- model.use_theta_head = True
421
-
422
- # Initialize new parameters
423
- for m in model.theta_head.modules():
424
- if isinstance(m, nn.Linear):
425
- nn.init.trunc_normal_(m.weight, std=0.02)
426
- if m.bias is not None:
427
- nn.init.zeros_(m.bias)
428
-
429
- print("✓ Theta head enabled")
430
- return model
431
-
432
-
433
  if __name__ == "__main__":
434
- print("BaselineViT with optional theta head")
435
- print("Use 'use_theta_head=True' to enable theta classification")
436
- print("Or call enable_theta_head() on existing model")
 
 
1
  """
2
  Baseline Vision Transformer with Frozen Pentachora Embeddings
3
+ Clean architecture with geometric semantic anchors
4
+ Assumes PentachoronStabilizer is loaded externally
5
  """
6
 
7
  import torch
 
16
  class PentachoraEmbedding(nn.Module):
17
  """
18
  A single frozen pentachora embedding (5 vertices in geometric space).
19
+ Accepts pre-computed vertices only. No random initialization.
20
  """
21
 
22
  def __init__(self, vertices: torch.Tensor):
23
  super().__init__()
24
+ #assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
25
 
26
  self.embed_dim = vertices.shape[-1]
27
 
28
  # Store provided vertices as frozen buffer
29
+ self.register_buffer('vertices', vertices)
30
  self.vertices.requires_grad = False
31
 
32
  # Precompute normalized versions and centroid
 
34
  self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
35
  self.register_buffer('centroid', self.vertices.mean(dim=0))
36
  self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
 
 
 
 
 
 
 
 
 
37
 
38
  def get_vertices(self) -> torch.Tensor:
39
+ """Get all 5 vertices."""
40
  return self.vertices
41
 
42
  def get_centroid(self) -> torch.Tensor:
43
+ """Get the centroid of the pentachora."""
44
  return self.centroid
45
 
46
  def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Compute Rose similarity score with this pentachora.
49
+ Uses external PentachoronStabilizer.rose_score_magnitude
50
+ """
51
+ # Prepare vertices for rose scoring
52
+ verts = self.vertices.unsqueeze(0) # [1, 5, D]
53
  if features.dim() == 1:
54
  features = features.unsqueeze(0)
55
 
56
+ # Expand vertices to batch size if needed
57
  B = features.shape[0]
58
  if B > 1:
59
  verts = verts.expand(B, -1, -1)
 
61
  return PentachoronStabilizer.rose_score_magnitude(features, verts)
62
 
63
  def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
64
+ """
65
+ Compute similarity between features and this pentachora.
66
+
67
+ Args:
68
+ features: [batch, dim] or [batch, seq, dim]
69
+ mode: 'centroid', 'max' (max over vertices), or 'rose' (Rose score)
70
+
71
+ Returns:
72
+ similarities: [batch] or [batch, seq]
73
+ """
74
  if mode == 'rose':
75
  return self.compute_rose_score(features)
76
 
77
  features_norm = F.normalize(features, dim=-1)
78
 
79
  if mode == 'centroid':
80
+ # Dot product with centroid
81
  return torch.matmul(features_norm, self.centroid_norm)
82
  else: # mode == 'max'
83
+ # Max similarity across vertices
84
  sims = torch.matmul(features_norm, self.vertices_norm.T)
85
  return sims.max(dim=-1)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  class TransformerBlock(nn.Module):
 
117
  )
118
 
119
  def forward(self, x: torch.Tensor) -> torch.Tensor:
120
+ # Self-attention
121
  x_norm = self.norm1(x)
122
  attn_out, _ = self.attn(x_norm, x_norm, x_norm)
123
  x = x + attn_out
124
+
125
+ # MLP
126
  x = x + self.mlp(self.norm2(x))
127
+
128
  return x
129
 
130
 
131
  class BaselineViT(nn.Module):
132
  """
133
+ Clean baseline Vision Transformer with frozen pentachora embeddings.
 
134
  """
135
 
136
  def __init__(
137
  self,
138
+ pentachora_list: list, # List of torch.Tensor, each [5, vocab_dim]
139
  vocab_dim: int = 256,
140
  img_size: int = 32,
141
  patch_size: int = 4,
 
145
  mlp_ratio: float = 4.0,
146
  dropout: float = 0.0,
147
  attn_dropout: float = 0.0,
148
+ similarity_mode: str = 'rose' # 'centroid', 'max', or 'rose'
 
 
 
149
  ):
150
  super().__init__()
151
 
152
+ # Validate pentachora list
153
+ assert isinstance(pentachora_list, list), f"Expected list, got {type(pentachora_list)}"
154
+ assert len(pentachora_list) > 0, "Empty pentachora list"
155
+
156
+ # Validate each pentachora
157
+ for i, penta in enumerate(pentachora_list):
158
+ assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
159
 
160
  self.num_classes = len(pentachora_list)
161
  self.embed_dim = embed_dim
162
  self.num_patches = (img_size // patch_size) ** 2
163
  self.similarity_mode = similarity_mode
164
  self.pentachora_dim = vocab_dim
 
165
 
166
+ # Create individual pentachora embeddings from list
167
  self.class_pentachora = nn.ModuleList([
168
  PentachoraEmbedding(vertices=penta)
169
  for penta in pentachora_list
 
172
  # Patch embedding
173
  self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
174
 
175
+ # CLS token - learnable
176
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
177
 
178
  # Position embeddings
 
200
  else:
201
  self.to_pentachora_dim = nn.Identity()
202
 
203
+ # Temperature for similarity-based classification
204
+ self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ self.register_buffer(
207
+ 'all_centroids',
208
+ torch.stack([penta.centroid for penta in self.class_pentachora])
209
+ )
210
+ self.register_buffer(
211
+ 'all_centroids_norm',
212
+ F.normalize(self.all_centroids, dim=-1)
213
+ )
214
+
215
+ # Initialize weights
216
  self.init_weights()
217
 
218
  def init_weights(self):
219
+ """Initialize model weights."""
220
  nn.init.trunc_normal_(self.cls_token, std=0.02)
221
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
222
 
 
229
  nn.init.ones_(m.weight)
230
  nn.init.zeros_(m.bias)
231
 
232
+ # Return the precomputed, L2-normalized centroids of all class pentachora.
233
  def get_class_centroids(self) -> torch.Tensor:
234
+ return self.all_centroids_norm
235
+
 
 
 
 
 
236
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
237
+ """
238
+ Compute similarities between features and all class pentachora (vectorized).
239
+ """
240
  if self.similarity_mode == 'rose':
241
+ # Stack all vertices into single tensor for batch Rose scoring
242
+ all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, vocab_dim]
243
+ # Expand features for batch computation
244
+ features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, vocab_dim]
245
+ # Compute Rose scores in parallel
246
+ return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, self.embed_dim), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
247
  else:
248
+ # Stack all centroids
249
+ centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, vocab_dim]
250
+ features_norm = F.normalize(features, dim=-1) # [B, vocab_dim]
251
+ return torch.matmul(features_norm, centroids.T) # [B, 100]
252
+
253
 
254
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
255
+ """Extract features from images."""
256
  B = x.shape[0]
257
 
258
  # Patch embedding
259
+ x = self.patch_embed(x) # [B, embed_dim, H', W']
260
+ x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
261
 
262
  # Add CLS token
263
  cls_tokens = self.cls_token.expand(B, -1, -1)
 
279
 
280
  def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
281
  """
282
+ Forward pass.
283
+
284
+ Returns dict with:
285
+ - logits: classification logits
286
+ - features: CLS features (if return_features=True)
287
+ - similarities: raw similarities to pentachora
288
  """
289
  features = self.forward_features(x)
290
+
291
  output = {}
292
 
293
  # Project to pentachora dimension
294
  features_proj = self.to_pentachora_dim(features)
295
 
296
+ # Compute similarities based on mode
297
+ if self.similarity_mode == 'rose':
298
+ # Use Rose scoring
 
 
 
 
 
 
 
 
 
299
  similarities = self.compute_pentachora_similarities(features_proj)
300
+ else:
301
+ # Use centroid or max similarity
302
+ features_norm = F.normalize(features_proj, dim=-1)
303
+ centroids = self.get_class_centroids()
304
+ similarities = torch.matmul(features_norm, centroids.T)
305
+
306
+ # Scale by temperature
307
+ logits = similarities * self.temperature.exp()
308
+
309
+ output['logits'] = logits
310
+ output['similarities'] = similarities
311
 
312
  if return_features:
313
  output['features'] = features
 
315
  return output
316
 
317
 
318
+ # Test - requires external setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  if __name__ == "__main__":
320
+ print("BaselineViT requires:")
321
+ print(" 1. PentachoronStabilizer loaded externally")
322
+ print(" 2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
323
+ print("\nNo random initialization. No fallbacks.")