AbstractPhil commited on
Commit
dd3435c
·
verified ·
1 Parent(s): 2ecfc2d

Create penta_vit_model_v1.py

Browse files
Files changed (1) hide show
  1. penta_vit_model_v1.py +1093 -0
penta_vit_model_v1.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
3
+ Enhanced with Geometric Attention for improved head cohesion and generalization
4
+
5
+ Author: AbstractPhil
6
+
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import numpy as np
13
+ from einops import rearrange, repeat
14
+ import math
15
+ from typing import Optional, Dict, Tuple, List, Any
16
+ from dataclasses import dataclass
17
+ import warnings
18
+
19
+ # ============================================
20
+ # CONFIGURATION CLASSES
21
+ # ============================================
22
+
23
@dataclass
class PentachoraConfig:
    """Configuration for PentachoraViT models."""
    img_size: int = 32                # input image height/width (assumed square)
    patch_size: int = 4               # side length of each square patch
    num_classes: int = 100
    dim: int = 512                    # transformer embedding dimension
    vocab_dim: Optional[int] = None  # Vocabulary dimension (can differ from model dim)
    depth: int = 12                   # number of transformer blocks
    heads: int = 8
    mlp_ratio: float = 4.0
    use_mesh_attention: bool = True   # use GeometricAttention in the early blocks
    preserve_structure_until_layer: int = 6  # blocks before this index use mesh attention
    dropout_rate: float = 0.0
    drop_path_rate: float = 0.0       # max stochastic-depth rate (scaled linearly per block)
    aux_loss_weight: float = 0.0      # not read in this file; presumably consumed by training code — TODO confirm
    geo_loss_weight: float = 0.0      # not read in this file; presumably consumed by training code — TODO confirm
    vocab: Optional[Any] = None       # optional vocabulary object exposing encode_batch()

    @property
    def num_patches(self) -> int:
        """Number of patch tokens per image."""
        return (self.img_size // self.patch_size) ** 2
45
+
46
+ # ============================================
47
+ # GEOMETRIC ATTENTION COMPONENTS (OPTIMIZED)
48
+ # ============================================
49
+
50
def perfect_4simplex(device):
    """Return the five vertices of a regular 4-simplex (pentachoron) in 4D.

    The vertices are mutually equidistant; after the final halving, every
    pairwise distance equals sqrt(2).

    Args:
        device: Torch device (or device string) for the returned tensor.

    Returns:
        Float32 tensor of shape [5, 4] with the vertex coordinates.
    """
    inv_sqrt5 = 1.0 / math.sqrt(5)
    coords = [
        [1.0, 1.0, 1.0, -inv_sqrt5],
        [1.0, -1.0, -1.0, -inv_sqrt5],
        [-1.0, 1.0, -1.0, -inv_sqrt5],
        [-1.0, -1.0, 1.0, -inv_sqrt5],
        [0.0, 0.0, 0.0, 4.0 * inv_sqrt5],
    ]
    simplex = torch.tensor(coords, device=device, dtype=torch.float32)
    # Halve the scale so the geometry stays compact.
    return simplex * 0.5
61
+
62
def softmin_over_last(distances, tau):
    """Differentiable soft minimum over the last dimension.

    Computes a softmax(-d / tau)-weighted average of the distances, which
    approaches min(distances) as tau -> 0.  BUG FIX: the previous version
    returned ``F.softmax(-d / tau, dim=-1).sum(dim=-1)`` — the sum of a
    softmax is identically 1, so the result was a constant independent of
    the input and every score built on it carried no information.

    Args:
        distances: Tensor of distances; the last dimension is reduced.
        tau: Softmin temperature; smaller values give a harder minimum.

    Returns:
        Tensor with the last dimension reduced to the soft minimum.
    """
    weights = F.softmax(-distances / tau, dim=-1)
    return (weights * distances).sum(dim=-1)
65
+
66
@dataclass
class GeometricConfig:
    """Configuration for geometric attention."""
    softmin_tau: float = 0.05  # temperature for softmin distance scoring
    fuse_alpha: float = 0.7  # blend weight: ribbon scores vs dispatcher scores
    phases: Tuple[float, ...] = (0.0, math.pi/2, math.pi, 3*math.pi/2)  # ribbon sampling phases
    jitter: float = 0.02  # random perturbation applied to simplex vertices
    shift: float = 0.25  # translation applied to the rotated ribbon simplex
    rotate_cycle: int = 11  # period of the per-region rotation angle
    use_phase_variance: bool = False  # not read anywhere in this file — TODO confirm intended use
    geometry_type: str = "pentachoron"  # geometry family label; not branched on in this file
77
+
78
class GeometricNavigator(nn.Module):
    """Maps inputs to geometric regions in 4D space - OPTIMIZED with vectorized operations.

    Each input vector is projected into a 4D "navigation" space and scored
    against ``num_regions`` pentachora (5-vertex simplices).  Two geometries
    are blended per region: a static "dispatcher" simplex ``D`` and a
    rotated/shifted "ribbon" simplex interpolated between ``D`` and ``S``
    at several phases.
    """

    def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig, num_heads: int = 1):
        super().__init__()
        self.input_dim = input_dim
        self.num_regions = num_regions
        self.config = config
        self.num_heads = num_heads

        # Create separate parameters for each head if num_heads > 1
        if num_heads > 1:
            # Per-head projection matrices [heads, input_dim, 4] and per-head vertex weights.
            self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4) * 0.02)
            self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5))
        else:
            self.to_nav = nn.Linear(input_dim, 4, bias=False)
            self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5))

        # Pre-compute phase tensors for vectorization
        self.register_buffer('phase_cos', torch.cos(torch.tensor(config.phases, dtype=torch.float32)))
        self.register_buffer('phase_sin', torch.sin(torch.tensor(config.phases, dtype=torch.float32)))

        # Initialize geometry after module is created
        # NOTE(review): D and S are filled in lazily on the first forward pass
        # so they land on the input's device; parameters created after the
        # optimizer is built will not be optimized — verify training code
        # constructs its optimizer after a warm-up forward, or handles this.
        self.register_parameter('D', None)
        self.register_parameter('S', None)

    def _lazy_init_geometry(self, device):
        """Initialize geometry on first forward pass."""
        if self.D is not None:
            return

        base = perfect_4simplex(device)

        if self.num_heads > 1:
            D = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)
            S = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)

            for h in range(self.num_heads):
                for r in range(self.num_regions):
                    # Dispatcher simplex: jittered copy of the base pentachoron.
                    D[h, r] = base + self.config.jitter * torch.randn_like(base)

                    # Ribbon simplex: base rotated in the (x0, x1) plane, shifted, then jittered.
                    theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device)
                    rot = torch.eye(4, device=device)
                    c, s_val = torch.cos(theta), torch.sin(theta)
                    rot[0, 0] = c; rot[0, 1] = -s_val
                    rot[1, 0] = s_val; rot[1, 1] = c
                    S[h, r] = (base @ rot) + self.config.shift
                    S[h, r] += self.config.jitter * torch.randn_like(S[h, r])
        else:
            D = torch.zeros(self.num_regions, 5, 4, device=device)
            S = torch.zeros(self.num_regions, 5, 4, device=device)

            for r in range(self.num_regions):
                D[r] = base + self.config.jitter * torch.randn_like(base)

                theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device)
                rot = torch.eye(4, device=device)
                c, s_val = torch.cos(theta), torch.sin(theta)
                rot[0, 0] = c; rot[0, 1] = -s_val
                rot[1, 0] = s_val; rot[1, 1] = c
                S[r] = (base @ rot) + self.config.shift
                S[r] += self.config.jitter * torch.randn_like(S[r])

        self.D = nn.Parameter(D)
        self.S = nn.Parameter(S)

    def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Navigate inputs through geometric space - OPTIMIZED with vectorized phase computation.

        Args:
            x: [BT, H, head_dim] when num_heads > 1, otherwise [N, input_dim].

        Returns:
            Dict with 'scores' (shape [BT*H, num_regions] when num_heads > 1,
            else [N, num_regions]) and a 'diagnostics' dict of detached
            dispatcher/ribbon score tensors.
        """
        self._lazy_init_geometry(x.device)

        if self.num_heads > 1:
            # Batched navigation for multiple heads
            BT, H, head_dim = x.shape

            # Batched transformation
            nav_x = torch.einsum('bhi,hio->bho', x, self.to_nav)  # [BT, H, 4]

            # Dispatcher scores
            nav_x_disp = nav_x.view(BT, H, 1, 1, 4)
            D_exp = self.D.unsqueeze(0)  # [1, H, regions, 5, 4]
            d_disp = torch.norm(nav_x_disp - D_exp, dim=-1)
            s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)

            # OPTIMIZED: Vectorized phase computation (no loop)
            cos_phases = self.phase_cos.to(x.device).view(-1, 1, 1, 1, 1)
            sin_phases = self.phase_sin.to(x.device).view(-1, 1, 1, 1, 1)

            # Compute all phase variants at once [phases, H, regions, 5, 4]
            Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)

            # Apply vertex weighting to all phases
            w = F.softmax(self.vertex_w, dim=-1)
            w_exp = w.unsqueeze(0).unsqueeze(-1)  # [1, H, regions, 5, 1]
            Vt_mean = Vt_all.mean(dim=3, keepdim=True)
            # Shrink each vertex toward the simplex centroid by its learned weight.
            Vt_all = (1.0 - w_exp) * Vt_all + w_exp * Vt_mean

            # Compute all ribbon distances at once
            nav_x_ribbon = nav_x.view(BT, 1, H, 1, 1, 4)
            Vt_exp = Vt_all.unsqueeze(0)  # [1, phases, H, regions, 5, 4]
            d_ribbon_all = torch.norm(nav_x_ribbon - Vt_exp, dim=-1)
            s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
            s_ribbon = s_ribbon_all.mean(dim=1)  # Average over phases

            # Fuse ribbon and dispatcher scores, then flatten heads into the batch dim.
            scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp
            scores = scores.reshape(BT * H, self.num_regions)

        else:
            # Original single-head navigation
            nav_x = self.to_nav(x)
            nav_x_exp = nav_x[:, None, None, :]
            D_exp = self.D[None, :, :, :]

            d_disp = torch.norm(nav_x_exp - D_exp, dim=-1)
            s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)

            w = F.softmax(self.vertex_w, dim=1)

            # OPTIMIZED: Vectorized phase computation for single head
            cos_phases = self.phase_cos.to(x.device).view(-1, 1, 1, 1)
            sin_phases = self.phase_sin.to(x.device).view(-1, 1, 1, 1)

            Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
            w_expanded = w.unsqueeze(0).unsqueeze(-1)
            Vt_mean = Vt_all.mean(dim=2, keepdim=True)
            Vt_all = (1.0 - w_expanded) * Vt_all + w_expanded * Vt_mean

            nav_x_phase = nav_x[:, None, None, None, :]
            Vt_exp = Vt_all.unsqueeze(0)
            d_ribbon_all = torch.norm(nav_x_phase - Vt_exp, dim=-1)
            s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
            s_ribbon = s_ribbon_all.mean(dim=1)

            scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp

        # BT/H exist only on the multi-head path; the conditional expressions
        # below evaluate the num_heads == 1 branch first, so they are never
        # referenced when undefined.
        diagnostics = {
            'dispatcher_scores': s_disp.detach() if self.num_heads == 1 else s_disp.reshape(BT * H, -1).detach(),
            'ribbon_scores': s_ribbon.detach() if self.num_heads == 1 else s_ribbon.reshape(BT * H, -1).detach()
        }

        return {'scores': scores, 'diagnostics': diagnostics}
218
+
219
class GeometricAttention(nn.Module):
    """Multi-head geometric attention with Q-K alignment - OPTIMIZED with batched processing.

    Instead of dot-product attention over raw q/k features, queries and keys
    are first mapped to per-region geometric scores by two GeometricNavigator
    modules; the attention logits are the similarity of those score vectors.
    """

    def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                 config: Optional[GeometricConfig] = None, dropout: float = 0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        # NOTE(review): assumes dim % num_heads == 0 — not validated here.
        self.head_dim = dim // num_heads

        if num_regions is None:
            num_regions = min(self.head_dim, 16)
        if config is None:
            config = GeometricConfig()

        self.config = config
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # Create batched navigators
        self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads)
        self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads)

        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
        """Apply geometric attention.

        Args:
            x: [B, T, dim] token embeddings.
            mask: optional key mask; 0 entries are masked out.  The unsqueeze
                pattern below assumes shape [B, T] — TODO confirm with callers.
            return_diagnostics: if True, also return navigator diagnostics.

        Returns:
            ([B, T, dim] output, diagnostics dict or None).
        """
        B, T, D = x.shape

        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # [B, heads, T, head_dim]
        q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Prepare for batched navigation
        q_batched = q.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)
        k_batched = k.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)

        # Navigate all heads at once
        q_nav = self.q_navigator.navigate(q_batched)
        k_nav = self.k_navigator.navigate(k_batched)

        # Reshape scores back to [B, heads, T, regions]
        q_scores = q_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)
        k_scores = k_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)

        # OPTIMIZED: Compute attention for all heads at once using einsum
        # Logits are region-score similarity, scaled by sqrt(num_regions).
        scale = math.sqrt(q_scores.size(-1))
        attn = torch.einsum('bhqr,bhkr->bhqk', q_scores, k_scores) / scale

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).unsqueeze(2)
            attn = attn.masked_fill(mask_expanded == 0, -1e9)

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = torch.einsum('bhqk,bhkd->bhqd', attn, v)
        out = out.transpose(1, 2).reshape(B, T, D)

        output = self.out_proj(out)
        output = self.dropout(output)

        if return_diagnostics:
            return output, {'q_diagnostics': q_nav['diagnostics'], 'k_diagnostics': k_nav['diagnostics']}
        return output, None
288
+
289
+ # ============================================
290
+ # DROP PATH (STOCHASTIC DEPTH)
291
+ # ============================================
292
+
293
class DropPath(nn.Module):
    """Stochastic depth: randomly drop whole samples' residual branches.

    During training each sample is zeroed with probability ``drop_prob`` and
    the survivors are rescaled by ``1 / (1 - drop_prob)`` so the expected
    output is unchanged.  At eval time (or with zero probability) this is
    the identity.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply per-sample stochastic depth to ``x``."""
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast across all trailing dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        mask.floor_()  # 1 with probability keep_prob, else 0
        return x.div(keep_prob) * mask
308
+
309
+ # ============================================
310
+ # HIERARCHICAL CLS WITH PENTACHORA
311
+ # ============================================
312
+
313
class HierarchicalPentachoronCLS(nn.Module):
    """
    Hierarchical CLS structure with pentachoron geometry.
    Uses vocabulary embeddings for CLS tokens.

    Produces one global CLS token plus five per-vertex CLS tokens per sample,
    derived from per-class pentachora stored in vocabulary space.
    """
    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Class-specific pentachora from vocabulary
        # Registered as a buffer: saved in state_dict and moved with .to(),
        # but it receives no gradients.
        self.register_buffer('class_pentachora', torch.randn(num_classes, 5, vocab_dim) * 0.02)

        # Projection from vocabulary dimension to model dimension
        if vocab_dim != dim:
            self.vocab_to_model = nn.Linear(vocab_dim, dim)
        else:
            self.vocab_to_model = nn.Identity()

        # Learnable aggregation weights (softmaxed over the 5 vertices in forward)
        self.vertex_weights = nn.Parameter(torch.ones(5) / 5)

        # Optional learnable offset
        self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))

        # Layer norms
        self.vertex_norm = nn.LayerNorm(dim)
        self.global_norm = nn.LayerNorm(dim)

    def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate CLS tokens for batch.

        Args:
            batch_size: Number of samples in the batch.
            class_indices: Optional [batch_size] class labels.  When given,
                each sample starts from its own class pentachoron; otherwise
                every sample shares the mean pentachoron across classes.

        Returns:
            (global_cls [B, 1, dim], vertex_cls [B, 5, dim]).
        """
        if class_indices is not None and class_indices.shape[0] == batch_size:
            vertex_cls_vocab = self.class_pentachora[class_indices]
        else:
            vertex_cls_vocab = self.class_pentachora.mean(dim=0, keepdim=True)
            vertex_cls_vocab = vertex_cls_vocab.expand(batch_size, -1, -1)

        # Project from vocabulary dimension to model dimension
        vertex_cls = self.vocab_to_model(vertex_cls_vocab)
        vertex_cls = self.vertex_norm(vertex_cls)

        # Create global CLS as weighted combination of the 5 vertex tokens
        weights = F.softmax(self.vertex_weights, dim=0)
        global_cls = torch.einsum('bvd,v->bd', vertex_cls, weights).unsqueeze(1)
        global_cls = global_cls + self.global_offset
        global_cls = self.global_norm(global_cls)

        return global_cls, vertex_cls

    def get_class_prototypes(self) -> torch.Tensor:
        """Get class prototypes in model dimension.

        Returns:
            [num_classes, dim] tensor: the vertex-weighted mean of each class
            pentachoron, projected into model space.
        """
        pentachora_model = self.vocab_to_model(self.class_pentachora)
        weights = F.softmax(self.vertex_weights, dim=0)
        prototypes = torch.einsum('cvd,v->cd', pentachora_model, weights)
        return prototypes
369
+
370
+ # ============================================
371
+ # GEOMETRIC PROJECTION LAYER
372
+ # ============================================
373
+
374
class GeometricProjection(nn.Module):
    """Project patches onto pentachoron geometry.

    Scores every patch token against every class pentachoron by averaging
    five per-vertex cosine-similarity maps, each computed through its own
    linear projection in vocabulary space.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Projection from model dim to vocab dim
        self.to_vocab_space = nn.Linear(dim, vocab_dim)

        # Vertex-specific projections, one per simplex vertex
        self.vertex_projections = nn.ModuleList([
            nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
        ])

        # Temperature for alignment scores
        self.temperature = nn.Parameter(torch.ones(1))

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
        """Compute alignment between patches and class pentachora.

        Args:
            patches: [B, N, dim] patch tokens.
            pentachora: [C, 5, vocab_dim] class pentachora.

        Returns:
            [B, N, C] alignment logits with dropout applied.
        """
        B, N, D = patches.shape
        C = pentachora.shape[0]

        # Normalize tokens, then map them into vocabulary space.
        normed = self.norm(patches)
        vocab_tokens = F.normalize(self.to_vocab_space(normed), dim=-1)

        # One cosine-similarity map per vertex, then average the five maps.
        per_vertex = [
            torch.matmul(
                F.normalize(proj(vocab_tokens), dim=-1),
                F.normalize(pentachora[:, v, :], dim=-1).T,
            ) / self.temperature
            for v, proj in enumerate(self.vertex_projections)
        ]
        fused = torch.stack(per_vertex, dim=-1).mean(dim=-1)

        return self.dropout(fused)
421
+
422
+ # ============================================
423
+ # MLP BLOCK
424
+ # ============================================
425
+
426
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Hidden and output widths default to ``in_features`` when not supplied.
    """

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
447
+
448
+ # ============================================
449
+ # VIT BLOCK WITH GEOMETRIC ATTENTION
450
+ # ============================================
451
+
452
class PentachoronViTBlock(nn.Module):
    """ViT block with geometric attention for structured layers.

    Pre-norm transformer block: attention (GeometricAttention when
    ``use_mesh`` is True, otherwise standard nn.MultiheadAttention) with a
    DropPath residual, followed by an MLP with a DropPath residual.
    """
    def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
                 use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
                 drop_path: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        # Use GeometricAttention for structured layers, standard for others
        if use_mesh:
            self.attn = GeometricAttention(
                dim=dim,
                num_heads=heads,
                num_regions=min(dim // heads, 16),
                config=GeometricConfig(),
                dropout=attn_dropout
            )
        else:
            # Standard multi-head attention for later layers
            self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)

        self.use_mesh = use_mesh
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
        """Apply attention and MLP sub-blocks with residual connections.

        NOTE(review): ``preserve_structure`` is accepted (and passed by
        PentachoraViT.forward_features) but never read in this body — it
        currently has no effect.
        """
        if self.use_mesh:
            # GeometricAttention returns (output, diagnostics); diagnostics unused here.
            attn_out, _ = self.attn(self.norm1(x))
            x = x + self.drop_path1(attn_out)
        else:
            # Standard attention (self-attention: same tensor as q, k, v)
            normalized = self.norm1(x)
            attn_out, _ = self.attn(normalized, normalized, normalized)
            x = x + self.drop_path1(attn_out)

        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
494
+
495
+ # ============================================
496
+ # PATCH EMBEDDING
497
+ # ============================================
498
+
499
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding.

    Splits an image into non-overlapping square patches with a strided
    convolution, flattens the spatial grid into a token sequence, and
    applies LayerNorm.

    Args:
        img_size: Input image height/width (assumed square).
        patch_size: Side length of each patch; should divide img_size evenly.
        in_chans: Number of input channels.
        embed_dim: Output embedding dimension per patch token.
    """
    def __init__(self, img_size: int = 32, patch_size: int = 4,
                 in_chans: int = 3, embed_dim: int = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        # kernel == stride == patch_size -> exactly one output position per patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed (B, C, H, W) images into (B, num_patches, embed_dim) tokens."""
        x = self.proj(x)  # (B, embed_dim, H/p, W/p)
        # Equivalent to einops rearrange 'b c h w -> b (h w) c', using native
        # tensor ops to drop the extra third-party dependency.
        x = x.flatten(2).transpose(1, 2)  # (B, H/p * W/p, embed_dim)
        x = self.norm(x)
        return x
516
+
517
+ # ============================================
518
+ # PENTACHORA VISION TRANSFORMER
519
+ # ============================================
520
+
521
class PentachoraViT(nn.Module):
    """
    Vision Transformer with pentachoron-based hierarchical CLS tokens
    and geometric vocabulary integration.

    Token layout per sample: [global CLS (1), vertex CLS (5), patches (N)].
    Classification is prototype-based by default: scaled cosine similarity
    between the global CLS token and per-class prototypes.
    """
    def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
        super().__init__()

        # Use config or kwargs
        if config is not None:
            cfg = config
        else:
            cfg = PentachoraConfig(**kwargs)

        self.config = cfg
        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer

        # Set vocabulary dimension
        if cfg.vocab_dim is not None:
            self.vocab_dim = cfg.vocab_dim
        elif 'vocab_dim' in kwargs:
            # Only reachable when an explicit config object is passed together
            # with a vocab_dim keyword argument.
            self.vocab_dim = kwargs['vocab_dim']
        else:
            self.vocab_dim = cfg.dim

        # Patch embedding
        self.patch_embed = PatchEmbed(
            cfg.img_size, cfg.patch_size, 3, cfg.dim
        )
        num_patches = self.patch_embed.num_patches

        # Positional embedding — covers patch tokens only; the CLS tokens
        # prepended in forward_features receive no positional embedding.
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)

        # CLS tokens with pentachoron structure
        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)

        # Geometric projection layer
        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)

        # Initialize from vocabulary if provided
        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)

        # Stochastic depth decay rule: linearly increasing drop-path per block
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]

        # Transformer blocks with geometric attention in the early layers only
        self.blocks = nn.ModuleList([
            PentachoronViTBlock(
                dim=cfg.dim,
                heads=cfg.heads,
                mlp_ratio=cfg.mlp_ratio,
                use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
                dropout=cfg.dropout_rate,
                attn_dropout=cfg.dropout_rate,
                drop_path=dpr[i]
            )
            for i in range(cfg.depth)
        ])

        # Final norm
        self.norm = nn.LayerNorm(cfg.dim)

        # Classification heads
        # Prototype classifier is hard-enabled, so self.head stays None and
        # the linear-head branch below is currently dead code.
        self.use_prototype_classifier = True
        if self.use_prototype_classifier:
            self.head = None
        else:
            self.head = nn.Linear(cfg.dim, cfg.num_classes)

        # Auxiliary head over the 5 flattened vertex CLS tokens
        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)

        # Initialize weights
        # NOTE(review): this re-initializes every nn.Linear in the model,
        # including the vocab projection layers xavier-initialized inside
        # _init_from_vocab above — that xavier init is overwritten here.
        # TODO confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialize model weights (trunc-normal Linear, unit LayerNorm, kaiming Conv2d)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _init_from_vocab(self, vocab):
        """Initialize class pentachora from geometric vocabulary.

        Best-effort: any failure falls back to the random buffer created by
        HierarchicalPentachoronCLS, with a printed warning (deliberate broad
        except so a bad vocabulary never blocks model construction).
        """
        try:
            print("Initializing pentachora from vocabulary...")

            if not hasattr(vocab, 'encode_batch'):
                print("Vocabulary provided but encode_batch method not found, using random initialization")
                return

            # Get CIFAR-100 class names
            class_names = self._get_cifar100_classes()

            # Generate pentachora for all classes
            pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
            pentachora = np.stack(pentachora_list, axis=0)

            # Get actual dimensions from the encoded data
            actual_vocab_dim = pentachora.shape[-1]

            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")

            # Validate basic shape requirements
            if pentachora.shape[0] != self.num_classes or pentachora.shape[1] != 5:
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return

            # Update vocabulary dimension
            self.vocab_dim = actual_vocab_dim
            self.cls_tokens.vocab_dim = actual_vocab_dim
            self.geometric_proj.vocab_dim = actual_vocab_dim

            # Replace class_pentachora with the loaded vocabulary
            # (the name is already a registered buffer, so this assignment
            # keeps it registered in state_dict).
            self.cls_tokens.class_pentachora = torch.tensor(pentachora, dtype=torch.float32)

            # Update/create projection layer if dimensions differ
            if actual_vocab_dim != self.dim:
                self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
            else:
                self.cls_tokens.vocab_to_model = nn.Identity()

            # Rebuild geometric projection components
            self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
            self.geometric_proj.vertex_projections = nn.ModuleList([
                nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
            ])

            # Re-initialize the new layers
            nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
            for proj in self.geometric_proj.vertex_projections:
                nn.init.xavier_uniform_(proj.weight)
            if actual_vocab_dim != self.dim:
                nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)

            print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
            print(f"  Vocabulary dimension: {actual_vocab_dim}")
            print(f"  Model internal dimension: {self.dim}")

        except Exception as e:
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")

    def _get_cifar100_classes(self):
        """Get CIFAR-100 class names (canonical fine-label order)."""
        return [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]

    def forward_features(self, x: torch.Tensor, class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Extract features from input.

        Args:
            x: [B, 3, img_size, img_size] image batch.
            class_indices: Optional [B] labels used to seed class-specific
                CLS tokens.

        Returns:
            Dict with 'global_cls' [B, dim], 'vertex_cls' [B, 5, dim] and
            'patches' [B, num_patches, dim].
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Get hierarchical CLS tokens
        global_cls, vertex_cls = self.cls_tokens(B, class_indices)

        # Concatenate CLS tokens with patches: [global(1), vertex(5), patches(N)]
        x = torch.cat([global_cls, vertex_cls, x], dim=1)

        # Apply transformer blocks
        for i, block in enumerate(self.blocks):
            preserve = i < self.preserve_structure_until_layer
            x = block(x, preserve_structure=preserve)

        # Apply final norm
        x = self.norm(x)

        # Split tokens back out in the concatenation order above
        global_cls = x[:, 0]
        vertex_cls = x[:, 1:6]
        patches = x[:, 6:]

        return {
            'global_cls': global_cls,
            'vertex_cls': vertex_cls,
            'patches': patches
        }

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Forward pass through the model.

        NOTE(review): during training the ground-truth labels seed the CLS
        tokens, while at eval time the class-mean pentachoron is used — the
        CLS input distribution therefore differs between train and eval.
        TODO confirm this asymmetry is intended.
        """
        # During training, use target labels for class-specific CLS initialization
        class_indices = targets if self.training and targets is not None else None

        features = self.forward_features(x, class_indices)

        # Primary classification using prototype matching
        # (cosine similarity between global CLS and class prototypes, scaled by 20)
        if self.use_prototype_classifier:
            prototypes = self.cls_tokens.get_class_prototypes()
            prototypes = F.normalize(prototypes, dim=-1)
            global_cls_norm = F.normalize(features['global_cls'], dim=-1)
            logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0
        else:
            logits = self.head(features['global_cls'])

        # Auxiliary classification using vertex tokens (flattened to [B, 5*dim])
        B = features['vertex_cls'].shape[0]
        vertex_flat = features['vertex_cls'].reshape(B, -1)
        aux_logits = self.head_aux(vertex_flat)

        # Geometric alignment scores - use class_pentachora directly
        geometric_alignments = self.geometric_proj(
            features['patches'],
            self.cls_tokens.class_pentachora  # Back to original
        )

        return {
            'logits': logits,
            'aux_logits': aux_logits,
            'geometric_alignments': geometric_alignments,
            'vertex_cls': features['vertex_cls'],
            'global_cls': features['global_cls'],
            'patches': features['patches']
        }
767
+
768
+ # ============================================
769
+ # LOSS FUNCTIONS
770
+ # ============================================
771
+
772
class PentachoraLoss(nn.Module):
    """Combined loss for PentachoraViT training.

    Sums the primary cross-entropy with optional weighted auxiliary and
    geometric-alignment cross-entropy terms.

    Args:
        aux_weight: Weight of the vertex-token auxiliary loss (0 disables).
        geo_weight: Weight of the geometric-alignment loss (0 disables).
        smoothing: Label smoothing passed to ``nn.CrossEntropyLoss``.
    """
    def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1,
                 smoothing: float = 0.0):
        super().__init__()
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Return the weighted sum of all enabled loss terms."""
        total = self.criterion(outputs['logits'], targets)

        # Auxiliary term from the vertex-token head, when present and enabled.
        if self.aux_weight > 0 and 'aux_logits' in outputs:
            total = total + self.aux_weight * self.criterion(outputs['aux_logits'], targets)

        # Geometric term: alignments are pooled over the token axis before CE.
        if self.geo_weight > 0 and 'geometric_alignments' in outputs:
            pooled = outputs['geometric_alignments'].mean(dim=1)
            total = total + self.geo_weight * self.criterion(pooled, targets)

        return total
798
+
799
+ # ============================================
800
+ # MODEL REGISTRY AND BUILDERS
801
+ # ============================================
802
+
803
# Registry of named model configurations.
# preserve_structure_until_layer: the first N transformer blocks run with
# preserve_structure=True (see forward_features' per-block flag).
MODEL_CONFIGS = {
    # "spark": shallow narrow probes, no regularization.
    'pentachora_spark_xs': PentachoraConfig(
        dim=100, depth=2, heads=10, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_spark': PentachoraConfig(
        dim=100, depth=5, heads=4, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # Only variant that overrides patch_size.
    'pentachora_shock': PentachoraConfig(
        dim=100, depth=10, heads=5, mlp_ratio=4.0,
        patch_size=5, preserve_structure_until_layer=5,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # "shock_xs" family: fixed depth=2 / heads=8, sweeping embedding width.
    # NOTE(review): preserve_structure_until_layer=4 exceeds depth=2, so every
    # block runs with preserve_structure=True — confirm this is intentional.
    'pentachora_shock_xs_32d': PentachoraConfig(
        dim=32, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_64d': PentachoraConfig(
        dim=64, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_128d': PentachoraConfig(
        dim=128, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_256d': PentachoraConfig(
        dim=256, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_512d': PentachoraConfig(
        dim=512, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    # Conventional ViT-style scalings with dropout / stochastic depth enabled.
    'pentachora_tiny': PentachoraConfig(
        dim=384, depth=12, heads=6, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_small': PentachoraConfig(
        dim=512, depth=12, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_base': PentachoraConfig(
        dim=768, depth=12, heads=12, mlp_ratio=4.0,
        preserve_structure_until_layer=8,
        dropout_rate=0.1, drop_path_rate=0.2
    ),
    'pentachora_large': PentachoraConfig(
        dim=1024, depth=24, heads=16, mlp_ratio=4.0,
        preserve_structure_until_layer=12,
        dropout_rate=0.1, drop_path_rate=0.3
    ),
}
865
+
866
def create_pentachora_vit(variant: str = 'pentachora_small',
                          pretrained: bool = False,
                          **kwargs) -> PentachoraViT:
    """Create a PentachoraViT model from a named registry configuration.

    Args:
        variant: Key into ``MODEL_CONFIGS``.
        pretrained: If True, attempt to load pretrained weights (none
            available yet; a warning is emitted).
        **kwargs: Config attribute overrides (e.g. ``img_size``,
            ``num_classes``, ``vocab_dim``).

    Returns:
        A freshly constructed ``PentachoraViT``.

    Raises:
        ValueError: If ``variant`` is not a registered configuration name.
    """
    import copy

    if variant not in MODEL_CONFIGS:
        raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")

    # BUGFIX: copy the registry entry before applying overrides. Previously
    # setattr() mutated the shared config object in MODEL_CONFIGS, so kwargs
    # from one call (e.g. num_classes=100) silently leaked into every later
    # create_pentachora_vit() call for the same variant.
    config = copy.deepcopy(MODEL_CONFIGS[variant])

    # Override config with kwargs
    for key, value in kwargs.items():
        setattr(config, key, value)

    model = PentachoraViT(config)

    if pretrained:
        warnings.warn("Pretrained weights not available yet")

    return model
885
+
886
+ # Convenience functions for each variant
887
def pentachora_vit_spark_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create spark variant (smallest). NOTE: maps to 'pentachora_spark_xs'."""
    return create_pentachora_vit('pentachora_spark_xs', pretrained=pretrained, **kwargs)

def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create spark variant."""
    return create_pentachora_vit('pentachora_spark', pretrained=pretrained, **kwargs)

def pentachora_vit_shock(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock variant."""
    return create_pentachora_vit('pentachora_shock', pretrained=pretrained, **kwargs)

def pentachora_shock_xs_32d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 32d variant."""
    return create_pentachora_vit('pentachora_shock_xs_32d', pretrained=pretrained, **kwargs)

def pentachora_shock_xs_64d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 64d variant."""
    return create_pentachora_vit('pentachora_shock_xs_64d', pretrained=pretrained, **kwargs)

def pentachora_shock_xs_128d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 128d variant.

    Added for consistency: 'pentachora_shock_xs_128d' is defined in
    MODEL_CONFIGS but previously had no convenience builder, unlike every
    other registered shock_xs width.
    """
    return create_pentachora_vit('pentachora_shock_xs_128d', pretrained=pretrained, **kwargs)

def pentachora_shock_xs_256d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 256d variant."""
    return create_pentachora_vit('pentachora_shock_xs_256d', pretrained=pretrained, **kwargs)

def pentachora_shock_xs_512d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 512d variant."""
    return create_pentachora_vit('pentachora_shock_xs_512d', pretrained=pretrained, **kwargs)

def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create tiny variant."""
    return create_pentachora_vit('pentachora_tiny', pretrained=pretrained, **kwargs)

def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create small variant."""
    return create_pentachora_vit('pentachora_small', pretrained=pretrained, **kwargs)

def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create base variant."""
    return create_pentachora_vit('pentachora_base', pretrained=pretrained, **kwargs)

def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create large variant."""
    return create_pentachora_vit('pentachora_large', pretrained=pretrained, **kwargs)
930
+
931
+ # ============================================
932
+ # TRAINING UTILITIES
933
+ # ============================================
934
+
935
def get_parameter_groups(model: nn.Module,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Get parameter groups for an optimizer with weight-decay handling.

    Parameters whose name contains 'bias', 'norm', or 'LayerNorm' receive
    zero weight decay; everything else uses ``weight_decay``. Frozen
    parameters (``requires_grad=False``) are excluded entirely.

    The annotation is ``nn.Module`` (generalized from ``PentachoraViT``)
    because the function only uses ``named_parameters()`` and works for any
    module; callers passing a PentachoraViT are unaffected.

    Args:
        model: Any module whose parameters should be grouped.
        weight_decay: Decay applied to the non-exempt group.

    Returns:
        Two optimizer parameter-group dicts: decayed and non-decayed.
    """
    # Substring match, so e.g. 'blocks.0.norm1.weight' is also exempt.
    no_decay = ['bias', 'norm', 'LayerNorm']

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
956
+
957
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Count model parameters.

    Returns:
        Dict with 'total', 'trainable', and 'non_trainable' element counts.
    """
    total = 0
    trainable = 0
    for p in model.parameters():
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable
    }
966
+
967
+ # ============================================
968
+ # INFERENCE UTILITIES
969
+ # ============================================
970
+
971
@torch.no_grad()
def extract_features(model: nn.Module,
                     images: torch.Tensor,
                     feature_type: str = 'global_cls') -> torch.Tensor:
    """Extract features from images using the model.

    Puts the model in eval mode, runs ``forward_features``, and returns the
    requested feature tensor. Unknown ``feature_type`` values fall back to
    'global_cls'. Runs under ``torch.no_grad()``, so no graph is built.

    The annotation is ``nn.Module`` (generalized from ``PentachoraViT``):
    the function only requires ``eval()`` and a ``forward_features`` method
    returning a dict, so any such module works.

    Args:
        model: Module exposing ``forward_features(images) -> dict``.
        images: Batch of input images.
        feature_type: Key to extract ('global_cls', 'vertex_cls', 'patches').

    Returns:
        The selected feature tensor.
    """
    model.eval()
    features = model.forward_features(images)
    return features.get(feature_type, features['global_cls'])
979
+
980
+ # ============================================
981
+ # EXAMPLE USAGE AND TESTING
982
+ # ============================================
983
+
984
def test_model():
    """Smoke-test model creation, forward pass, and loss for several variants.

    Prints parameter counts, output shapes, forward latency, and the combined
    training loss for each variant. Uses CUDA when available.
    """
    print("Testing Optimized PentachoraViT Model")
    print("=" * 50)

    # One small, one mid-width, one standard variant from MODEL_CONFIGS.
    variants = ['pentachora_spark', 'pentachora_shock_xs_256d', 'pentachora_small']

    for variant in variants:
        print(f"\nTesting {variant}:")

        # Create model with vocab_dim
        model = create_pentachora_vit(
            variant=variant,
            img_size=32,
            patch_size=4,
            num_classes=100,
            vocab_dim=64
        )

        # Count parameters
        params = count_parameters(model)
        print(f" Total parameters: {params['total']:,}")
        print(f" Trainable parameters: {params['trainable']:,}")

        # Dummy CIFAR-sized batch: B=2, 3x32x32.
        x = torch.randn(2, 3, 32, 32)

        # Move to GPU and synchronize BEFORE starting the clock so the
        # transfer cost is not attributed to the forward pass.
        if torch.cuda.is_available():
            model = model.cuda()
            x = x.cuda()
            torch.cuda.synchronize()

        import time
        start = time.time()
        outputs = model(x)
        if torch.cuda.is_available():
            # CUDA kernels launch asynchronously; sync again so time.time()
            # measures the completed forward pass, not just the launch.
            torch.cuda.synchronize()
        end = time.time()

        print(f" Output shapes:")
        print(f" Logits: {outputs['logits'].shape}")
        print(f" Aux logits: {outputs['aux_logits'].shape}")
        print(f" Geometric alignments: {outputs['geometric_alignments'].shape}")
        print(f" Forward pass time: {(end - start)*1000:.2f}ms")

        # Test loss computation with random targets (values only sanity-checked).
        loss_fn = PentachoraLoss()
        targets = torch.randint(0, 100, (2,))
        if torch.cuda.is_available():
            targets = targets.cuda()
        loss = loss_fn(outputs, targets)
        print(f" Loss: {loss.item():.4f}")

    print("\n" + "=" * 50)
    print("All tests passed!")
1041
+
1042
if __name__ == "__main__":
    # Run tests
    #test_model()

    # Example: build the 256-dim "shock xs" variant for full-precision training.
    print("\nExample: Creating optimized model for A100 training")
    model = pentachora_shock_xs_256d(
        img_size=32,
        num_classes=100,
        vocab_dim=100,
        dropout_rate=0.0,
        drop_path_rate=0.0
    )

    # Move model to CUDA first if available (before torch.compile, so the
    # compiled artifact is built against the device it will run on).
    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # Now try torch.compile (PyTorch 2.0+)
    # Model reformatted to allow eager compiling, speeds along training substantially.
    # NOTE(review): backend="eager" exercises graph capture without codegen;
    # confirm a codegen backend (e.g. "inductor") isn't intended for A100 runs.
    if hasattr(torch, 'compile'):
        print("Compiling model with torch.compile...")
        try:
            model = torch.compile(model, backend="eager")
            print("Model compiled successfully")
        except Exception as e:
            # Best-effort: compilation is an optimization, never a requirement.
            print(f"Compilation warning: {e}")
            print("Continuing without compilation - vectorized ops will still provide speedup")

    # Get parameter groups for optimizer (decay vs no-decay split).
    param_groups = get_parameter_groups(model, weight_decay=0.05)
    print(f"Number of parameter groups: {len(param_groups)}")

    # Example batch - FULL PRECISION
    images = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 100, (4,))

    if torch.cuda.is_available():
        images = images.cuda()
        targets = targets.cuda()

    # Forward pass in FULL PRECISION (no autocast)
    outputs = model(images)

    # Compute loss with the default auxiliary/geometric weights.
    criterion = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)
    loss = criterion(outputs, targets)

    print(f"Training loss (full precision): {loss.item():.4f}")
    print("\nModel ready for full precision A100 training!")
    print("Eager initialization ensures all parameters are created upfront")