AbstractPhil commited on
Commit
7465a5c
·
verified ·
1 Parent(s): fa4e774

Added theta experimental head

Browse files
Files changed (1) hide show
  1. vit_zana_v3.py +214 -101
vit_zana_v3.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
  Baseline Vision Transformer with Frozen Pentachora Embeddings
3
- Clean architecture with geometric semantic anchors
4
- Assumes PentachoronStabilizer is loaded externally
5
  """
6
 
7
  import torch
@@ -16,17 +15,16 @@ from typing import Optional, Tuple, Dict, Any
16
  class PentachoraEmbedding(nn.Module):
17
  """
18
  A single frozen pentachora embedding (5 vertices in geometric space).
19
- Accepts pre-computed vertices only. No random initialization.
20
  """
21
 
22
  def __init__(self, vertices: torch.Tensor):
23
  super().__init__()
24
- #assert vertices.shape == (5, 128), f"Expected shape (5, 128), got {vertices.shape}"
25
 
26
  self.embed_dim = vertices.shape[-1]
27
 
28
  # Store provided vertices as frozen buffer
29
- self.register_buffer('vertices', vertices)
30
  self.vertices.requires_grad = False
31
 
32
  # Precompute normalized versions and centroid
@@ -34,26 +32,27 @@ class PentachoraEmbedding(nn.Module):
34
  self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
35
  self.register_buffer('centroid', self.vertices.mean(dim=0))
36
  self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
 
 
 
 
 
 
 
 
 
37
 
38
  def get_vertices(self) -> torch.Tensor:
39
- """Get all 5 vertices."""
40
  return self.vertices
41
 
42
  def get_centroid(self) -> torch.Tensor:
43
- """Get the centroid of the pentachora."""
44
  return self.centroid
45
 
46
  def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
47
- """
48
- Compute Rose similarity score with this pentachora.
49
- Uses external PentachoronStabilizer.rose_score_magnitude
50
- """
51
- # Prepare vertices for rose scoring
52
- verts = self.vertices.unsqueeze(0) # [1, 5, D]
53
  if features.dim() == 1:
54
  features = features.unsqueeze(0)
55
 
56
- # Expand vertices to batch size if needed
57
  B = features.shape[0]
58
  if B > 1:
59
  verts = verts.expand(B, -1, -1)
@@ -61,28 +60,110 @@ class PentachoraEmbedding(nn.Module):
61
  return PentachoronStabilizer.rose_score_magnitude(features, verts)
62
 
63
  def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
64
- """
65
- Compute similarity between features and this pentachora.
66
-
67
- Args:
68
- features: [batch, dim] or [batch, seq, dim]
69
- mode: 'centroid', 'max' (max over vertices), or 'rose' (Rose score)
70
-
71
- Returns:
72
- similarities: [batch] or [batch, seq]
73
- """
74
  if mode == 'rose':
75
  return self.compute_rose_score(features)
76
 
77
  features_norm = F.normalize(features, dim=-1)
78
 
79
  if mode == 'centroid':
80
- # Dot product with centroid
81
  return torch.matmul(features_norm, self.centroid_norm)
82
  else: # mode == 'max'
83
- # Max similarity across vertices
84
  sims = torch.matmul(features_norm, self.vertices_norm.T)
85
  return sims.max(dim=-1)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  class TransformerBlock(nn.Module):
@@ -117,25 +198,22 @@ class TransformerBlock(nn.Module):
117
  )
118
 
119
  def forward(self, x: torch.Tensor) -> torch.Tensor:
120
- # Self-attention
121
  x_norm = self.norm1(x)
122
  attn_out, _ = self.attn(x_norm, x_norm, x_norm)
123
  x = x + attn_out
124
-
125
- # MLP
126
  x = x + self.mlp(self.norm2(x))
127
-
128
  return x
129
 
130
 
131
  class BaselineViT(nn.Module):
132
  """
133
- Clean baseline Vision Transformer with frozen pentachora embeddings.
 
134
  """
135
 
136
  def __init__(
137
  self,
138
- pentachora_list: list, # List of torch.Tensor, each [5, vocab_dim]
139
  vocab_dim: int = 256,
140
  img_size: int = 32,
141
  patch_size: int = 4,
@@ -145,25 +223,23 @@ class BaselineViT(nn.Module):
145
  mlp_ratio: float = 4.0,
146
  dropout: float = 0.0,
147
  attn_dropout: float = 0.0,
148
- similarity_mode: str = 'rose' # 'centroid', 'max', or 'rose'
 
 
 
149
  ):
150
  super().__init__()
151
 
152
- # Validate pentachora list
153
- assert isinstance(pentachora_list, list), f"Expected list, got {type(pentachora_list)}"
154
- assert len(pentachora_list) > 0, "Empty pentachora list"
155
-
156
- # Validate each pentachora
157
- for i, penta in enumerate(pentachora_list):
158
- assert isinstance(penta, torch.Tensor), f"Item {i} is not a tensor"
159
 
160
  self.num_classes = len(pentachora_list)
161
  self.embed_dim = embed_dim
162
  self.num_patches = (img_size // patch_size) ** 2
163
  self.similarity_mode = similarity_mode
164
  self.pentachora_dim = vocab_dim
 
165
 
166
- # Create individual pentachora embeddings from list
167
  self.class_pentachora = nn.ModuleList([
168
  PentachoraEmbedding(vertices=penta)
169
  for penta in pentachora_list
@@ -172,7 +248,7 @@ class BaselineViT(nn.Module):
172
  # Patch embedding
173
  self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
174
 
175
- # CLS token - learnable
176
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
177
 
178
  # Position embeddings
@@ -200,23 +276,33 @@ class BaselineViT(nn.Module):
200
  else:
201
  self.to_pentachora_dim = nn.Identity()
202
 
203
- # Temperature for similarity-based classification
204
- self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- self.register_buffer(
207
- 'all_centroids',
208
- torch.stack([penta.centroid for penta in self.class_pentachora])
209
- )
210
- self.register_buffer(
211
- 'all_centroids_norm',
212
- F.normalize(self.all_centroids, dim=-1)
213
- )
214
-
215
- # Initialize weights
216
  self.init_weights()
217
 
218
  def init_weights(self):
219
- """Initialize model weights."""
220
  nn.init.trunc_normal_(self.cls_token, std=0.02)
221
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
222
 
@@ -229,35 +315,33 @@ class BaselineViT(nn.Module):
229
  nn.init.ones_(m.weight)
230
  nn.init.zeros_(m.bias)
231
 
232
- # Then get_class_centroids becomes:
233
  def get_class_centroids(self) -> torch.Tensor:
234
- return self.all_centroids_norm
235
-
 
 
 
 
 
236
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
237
- """
238
- Compute similarities between features and all class pentachora (vectorized).
239
- """
240
  if self.similarity_mode == 'rose':
241
- # Stack all vertices into single tensor for batch Rose scoring
242
- all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, vocab_dim]
243
- # Expand features for batch computation
244
- features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, vocab_dim]
245
- # Compute Rose scores in parallel
246
- return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, self.embed_dim), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
247
  else:
248
- # Stack all centroids
249
- centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, vocab_dim]
250
- features_norm = F.normalize(features, dim=-1) # [B, vocab_dim]
251
- return torch.matmul(features_norm, centroids.T) # [B, 100]
252
-
253
 
254
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
255
- """Extract features from images."""
256
  B = x.shape[0]
257
 
258
  # Patch embedding
259
- x = self.patch_embed(x) # [B, embed_dim, H', W']
260
- x = x.flatten(2).transpose(1, 2) # [B, num_patches, embed_dim]
261
 
262
  # Add CLS token
263
  cls_tokens = self.cls_token.expand(B, -1, -1)
@@ -279,35 +363,31 @@ class BaselineViT(nn.Module):
279
 
280
  def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
281
  """
282
- Forward pass.
283
-
284
- Returns dict with:
285
- - logits: classification logits
286
- - features: CLS features (if return_features=True)
287
- - similarities: raw similarities to pentachora
288
  """
289
  features = self.forward_features(x)
290
-
291
  output = {}
292
 
293
  # Project to pentachora dimension
294
  features_proj = self.to_pentachora_dim(features)
295
 
296
- # Compute similarities based on mode
297
- if self.similarity_mode == 'rose':
298
- # Use Rose scoring
299
- similarities = self.compute_pentachora_similarities(features_proj)
 
 
 
 
 
 
300
  else:
301
- # Use centroid or max similarity
302
- features_norm = F.normalize(features_proj, dim=-1)
303
- centroids = self.get_class_centroids()
304
- similarities = torch.matmul(features_norm, centroids.T)
305
-
306
- # Scale by temperature
307
- logits = similarities * self.temperature.exp()
308
-
309
- output['logits'] = logits
310
- output['similarities'] = similarities
311
 
312
  if return_features:
313
  output['features'] = features
@@ -315,9 +395,42 @@ class BaselineViT(nn.Module):
315
  return output
316
 
317
 
318
- # Test - requires external setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  if __name__ == "__main__":
320
- print("BaselineViT requires:")
321
- print(" 1. PentachoronStabilizer loaded externally")
322
- print(" 2. pentachora_batch tensor [num_classes, 5, vocab_dim]")
323
- print("\nNo random initialization. No fallbacks.")
 
1
  """
2
  Baseline Vision Transformer with Frozen Pentachora Embeddings
3
+ Now with optional theta rotation head for better classification
 
4
  """
5
 
6
  import torch
 
15
  class PentachoraEmbedding(nn.Module):
16
  """
17
  A single frozen pentachora embedding (5 vertices in geometric space).
18
+ Now with theta rotation capabilities.
19
  """
20
 
21
  def __init__(self, vertices: torch.Tensor):
22
  super().__init__()
 
23
 
24
  self.embed_dim = vertices.shape[-1]
25
 
26
  # Store provided vertices as frozen buffer
27
+ self.register_buffer('vertices', vertices.cpu().contiguous().detach().clone().to(get_default_device()))
28
  self.vertices.requires_grad = False
29
 
30
  # Precompute normalized versions and centroid
 
32
  self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
33
  self.register_buffer('centroid', self.vertices.mean(dim=0))
34
  self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))
35
+
36
+ # Compute theta bases for rotation
37
+ self.register_buffer('theta_bases', self._compute_theta_bases().cpu().contiguous().detach().clone().to(get_default_device()))
38
+
39
+ def _compute_theta_bases(self) -> torch.Tensor:
40
+ """Compute orthogonal bases from vertices for theta rotation."""
41
+ U, S, V = torch.svd(self.vertices)
42
+ n_components = min(5, self.embed_dim)
43
+ return V[:, :n_components] # [embed_dim, n_components]
44
 
45
  def get_vertices(self) -> torch.Tensor:
 
46
  return self.vertices
47
 
48
  def get_centroid(self) -> torch.Tensor:
 
49
  return self.centroid
50
 
51
  def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
52
+ verts = self.vertices.unsqueeze(0)
 
 
 
 
 
53
  if features.dim() == 1:
54
  features = features.unsqueeze(0)
55
 
 
56
  B = features.shape[0]
57
  if B > 1:
58
  verts = verts.expand(B, -1, -1)
 
60
  return PentachoronStabilizer.rose_score_magnitude(features, verts)
61
 
62
  def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
63
  if mode == 'rose':
64
  return self.compute_rose_score(features)
65
 
66
  features_norm = F.normalize(features, dim=-1)
67
 
68
  if mode == 'centroid':
 
69
  return torch.matmul(features_norm, self.centroid_norm)
70
  else: # mode == 'max'
 
71
  sims = torch.matmul(features_norm, self.vertices_norm.T)
72
  return sims.max(dim=-1)[0]
73
+
74
+ def compute_theta_features(self, features: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Project features to theta space defined by this pentachora.
77
+ Returns angular features for feedforward classification.
78
+ """
79
+ # Project onto pentachora bases
80
+ projections = torch.matmul(features, self.theta_bases) # [batch, 5]
81
+
82
+ # Compute angles relative to centroid
83
+ centroid_proj = torch.matmul(self.centroid.unsqueeze(0), self.theta_bases)
84
+ angles = torch.atan2(projections, centroid_proj + 1e-8)
85
+
86
+ # Return sin/cos encoding
87
+ return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).to(get_default_device()) # [batch, 10]
88
+
89
+
90
+ class ThetaHead(nn.Module):
91
+ """
92
+ Theta-based classification head using angular representations.
93
+ Replaces similarity matching with learned feedforward.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ embed_dim: int,
99
+ num_classes: int,
100
+ n_pentachora: int = 10, # Use subset of pentachora for theta
101
+ hidden_dim: int = 256,
102
+ dropout: float = 0.1
103
+ ):
104
+ super().__init__()
105
+
106
+ self.n_pentachora = n_pentachora
107
+ self.embed_dim = embed_dim
108
+
109
+ # Each pentachora gives 10 theta features (5 sin + 5 cos)
110
+ theta_dim = n_pentachora * 10
111
+
112
+ # Project to theta space
113
+ self.to_theta = nn.Sequential(
114
+ nn.Linear(embed_dim, hidden_dim),
115
+ nn.LayerNorm(hidden_dim),
116
+ nn.GELU(),
117
+ nn.Dropout(dropout),
118
+ nn.Linear(hidden_dim, theta_dim)
119
+ )
120
+
121
+ # Classify from theta
122
+ self.classifier = nn.Sequential(
123
+ nn.LayerNorm(theta_dim),
124
+ nn.Dropout(dropout),
125
+ nn.Linear(theta_dim, num_classes)
126
+ )
127
+
128
+ # Learnable temperature
129
+ self.temperature = nn.Parameter(torch.ones(1) * 0.1)
130
+
131
+ def forward(self, features: torch.Tensor, pentachora_list: nn.ModuleList) -> Dict[str, torch.Tensor]:
132
+ """
133
+ Classify using theta rotation.
134
+
135
+ Args:
136
+ features: [batch, embed_dim] CLS features
137
+ pentachora_list: List of PentachoraEmbedding modules
138
+ """
139
+ # Get theta features from first n pentachora
140
+ theta_features = []
141
+ for i in range(min(self.n_pentachora, len(pentachora_list))):
142
+ theta = pentachora_list[i].compute_theta_features(features)
143
+ theta_features.append(theta)
144
+
145
+ # Concatenate all theta features
146
+ theta_concat = torch.cat(theta_features, dim=-1) # [batch, n_pentachora * 10]
147
+
148
+ # If we have fewer pentachora than expected, pad with zeros
149
+ if len(theta_features) < self.n_pentachora:
150
+ pad_size = (self.n_pentachora - len(theta_features)) * 10
151
+ padding = torch.zeros(features.shape[0], pad_size, device=features.device)
152
+ theta_concat = torch.cat([theta_concat, padding], dim=-1)
153
+
154
+ # Project through MLP
155
+ theta_proj = self.to_theta(features)
156
+
157
+ # Combine with geometric theta (residual connection)
158
+ theta_combined = theta_concat + 0.1 * theta_proj
159
+
160
+ # Classify
161
+ logits = self.classifier(theta_combined) / self.temperature.exp()
162
+
163
+ return {
164
+ 'logits': logits,
165
+ 'theta_features': theta_combined
166
+ }
167
 
168
 
169
  class TransformerBlock(nn.Module):
 
198
  )
199
 
200
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
201
  x_norm = self.norm1(x)
202
  attn_out, _ = self.attn(x_norm, x_norm, x_norm)
203
  x = x + attn_out
 
 
204
  x = x + self.mlp(self.norm2(x))
 
205
  return x
206
 
207
 
208
  class BaselineViT(nn.Module):
209
  """
210
+ Vision Transformer with optional theta-based classification.
211
+ Can switch between similarity-based and theta-based heads.
212
  """
213
 
214
  def __init__(
215
  self,
216
+ pentachora_list: list,
217
  vocab_dim: int = 256,
218
  img_size: int = 32,
219
  patch_size: int = 4,
 
223
  mlp_ratio: float = 4.0,
224
  dropout: float = 0.0,
225
  attn_dropout: float = 0.0,
226
+ similarity_mode: str = 'rose',
227
+ use_theta_head: bool = True, # NEW: Toggle theta head
228
+ theta_n_pentachora: int = 2, # NEW: How many pentachora for theta
229
+ theta_hidden_dim: int = 256 # NEW: Hidden dim for theta MLP
230
  ):
231
  super().__init__()
232
 
233
+ assert isinstance(pentachora_list, list) and len(pentachora_list) > 0
 
 
 
 
 
 
234
 
235
  self.num_classes = len(pentachora_list)
236
  self.embed_dim = embed_dim
237
  self.num_patches = (img_size // patch_size) ** 2
238
  self.similarity_mode = similarity_mode
239
  self.pentachora_dim = vocab_dim
240
+ self.use_theta_head = use_theta_head
241
 
242
+ # Create pentachora embeddings
243
  self.class_pentachora = nn.ModuleList([
244
  PentachoraEmbedding(vertices=penta)
245
  for penta in pentachora_list
 
248
  # Patch embedding
249
  self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
250
 
251
+ # CLS token
252
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
253
 
254
  # Position embeddings
 
276
  else:
277
  self.to_pentachora_dim = nn.Identity()
278
 
279
+ # Classification heads
280
+ if use_theta_head:
281
+ # NEW: Theta-based classification
282
+ self.theta_head = ThetaHead(
283
+ embed_dim=self.pentachora_dim,
284
+ num_classes=self.num_classes,
285
+ n_pentachora=theta_n_pentachora,
286
+ hidden_dim=theta_hidden_dim,
287
+ dropout=dropout
288
+ )
289
+ else:
290
+ # Original: Similarity-based classification
291
+ self.theta_head = None
292
+ self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
293
+
294
+ self.register_buffer(
295
+ 'all_centroids',
296
+ torch.stack([penta.centroid for penta in self.class_pentachora])
297
+ )
298
+ self.register_buffer(
299
+ 'all_centroids_norm',
300
+ F.normalize(self.all_centroids, dim=-1)
301
+ )
302
 
 
 
 
 
 
 
 
 
 
 
303
  self.init_weights()
304
 
305
  def init_weights(self):
 
306
  nn.init.trunc_normal_(self.cls_token, std=0.02)
307
  nn.init.trunc_normal_(self.pos_embed, std=0.02)
308
 
 
315
  nn.init.ones_(m.weight)
316
  nn.init.zeros_(m.bias)
317
 
 
318
  def get_class_centroids(self) -> torch.Tensor:
319
+ if self.use_theta_head:
320
+ # Return centroids from pentachora for compatibility
321
+ centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
322
+ return centroids
323
+ else:
324
+ return self.all_centroids_norm
325
+
326
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
 
 
 
327
  if self.similarity_mode == 'rose':
328
+ all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
329
+ features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
330
+ return PentachoronStabilizer.rose_score_magnitude(
331
+ features_exp.reshape(-1, self.pentachora_dim),
332
+ all_vertices.repeat(features.shape[0], 1, 1)
333
+ ).reshape(features.shape[0], -1)
334
  else:
335
+ centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
336
+ features_norm = F.normalize(features, dim=-1)
337
+ return torch.matmul(features_norm, centroids.T)
 
 
338
 
339
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
 
340
  B = x.shape[0]
341
 
342
  # Patch embedding
343
+ x = self.patch_embed(x)
344
+ x = x.flatten(2).transpose(1, 2)
345
 
346
  # Add CLS token
347
  cls_tokens = self.cls_token.expand(B, -1, -1)
 
363
 
364
  def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
365
  """
366
+ Forward pass with optional theta head.
 
 
 
 
 
367
  """
368
  features = self.forward_features(x)
 
369
  output = {}
370
 
371
  # Project to pentachora dimension
372
  features_proj = self.to_pentachora_dim(features)
373
 
374
+ if self.use_theta_head:
375
+ # NEW: Use theta-based classification
376
+ theta_output = self.theta_head(features_proj, self.class_pentachora)
377
+ output['logits'] = theta_output['logits']
378
+ output['theta_features'] = theta_output['theta_features']
379
+
380
+ # Still compute similarities for analysis
381
+ with torch.no_grad():
382
+ similarities = self.compute_pentachora_similarities(features_proj)
383
+ output['similarities'] = similarities
384
  else:
385
+ # Original: Use similarity-based classification
386
+ similarities = self.compute_pentachora_similarities(features_proj)
387
+ logits = similarities * self.temperature.exp()
388
+
389
+ output['logits'] = logits
390
+ output['similarities'] = similarities
 
 
 
 
391
 
392
  if return_features:
393
  output['features'] = features
 
395
  return output
396
 
397
 
398
+ # Helper function to convert existing model to theta
399
+ def enable_theta_head(model: BaselineViT, n_pentachora: int = 10, hidden_dim: int = 256):
400
+ """
401
+ Convert an existing similarity-based model to use theta head.
402
+ This modifies the model in-place.
403
+ """
404
+ if model.use_theta_head:
405
+ print("Model already using theta head")
406
+ return model
407
+
408
+ print(f"Converting to theta head with {n_pentachora} pentachora...")
409
+
410
+ # Create theta head
411
+ model.theta_head = ThetaHead(
412
+ embed_dim=model.pentachora_dim,
413
+ num_classes=model.num_classes,
414
+ n_pentachora=n_pentachora,
415
+ hidden_dim=hidden_dim,
416
+ dropout=0.1
417
+ ).to(next(model.parameters()).device)
418
+
419
+ # Set flag
420
+ model.use_theta_head = True
421
+
422
+ # Initialize new parameters
423
+ for m in model.theta_head.modules():
424
+ if isinstance(m, nn.Linear):
425
+ nn.init.trunc_normal_(m.weight, std=0.02)
426
+ if m.bias is not None:
427
+ nn.init.zeros_(m.bias)
428
+
429
+ print("✓ Theta head enabled")
430
+ return model
431
+
432
+
433
  if __name__ == "__main__":
434
+ print("BaselineViT with optional theta head")
435
+ print("Use 'use_theta_head=True' to enable theta classification")
436
+ print("Or call enable_theta_head() on existing model")