AbstractPhil commited on
Commit
96bb4fe
·
verified ·
1 Parent(s): 855fbfb

Create penta_vit_model_v2.py

Browse files
Files changed (1) hide show
  1. penta_vit_model_v2.py +1158 -0
penta_vit_model_v2.py ADDED
@@ -0,0 +1,1158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
3
+ Enhanced with Geometric Attention for improved head cohesion and generalization
4
+ FIXED: All parameters initialized at module creation time (no lazy init)
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ from einops import rearrange, repeat
12
+ import math
13
+ from typing import Optional, Dict, Tuple, List, Any
14
+ from dataclasses import dataclass
15
+ import warnings
16
+
17
+ # ============================================
18
+ # CONFIGURATION CLASSES
19
+ # ============================================
20
+
21
+ @dataclass
22
+ class PentachoraConfig:
23
+ """Configuration for PentachoraViT models."""
24
+ img_size: int = 32
25
+ patch_size: int = 4
26
+ num_classes: int = 100
27
+ dim: int = 512
28
+ vocab_dim: Optional[int] = None # Vocabulary dimension (can differ from model dim)
29
+ depth: int = 12
30
+ heads: int = 8
31
+ mlp_ratio: float = 4.0
32
+ use_mesh_attention: bool = True
33
+ preserve_structure_until_layer: int = 6
34
+ dropout_rate: float = 0.0
35
+ drop_path_rate: float = 0.0
36
+ aux_loss_weight: float = 0.0
37
+ geo_loss_weight: float = 0.0
38
+ vocab: Optional[Any] = None
39
+
40
+ @property
41
+ def num_patches(self) -> int:
42
+ return (self.img_size // self.patch_size) ** 2
43
+
44
+ # ============================================
45
+ # GEOMETRIC ATTENTION COMPONENTS (FIXED INIT)
46
+ # ============================================
47
+
48
+ def perfect_4simplex(device):
49
+ """Create perfect 4-simplex (pentachoron) vertices in 4D."""
50
+ sqrt5 = math.sqrt(5)
51
+ vertices = torch.tensor([
52
+ [1, 1, 1, -1/sqrt5],
53
+ [1, -1, -1, -1/sqrt5],
54
+ [-1, 1, -1, -1/sqrt5],
55
+ [-1, -1, 1, -1/sqrt5],
56
+ [0, 0, 0, 4/sqrt5]
57
+ ], device=device, dtype=torch.float32)
58
+ return vertices / 2 # Normalize scale
59
+
60
+ def softmin_over_last(distances, tau):
61
+ """Softmin over last dimension."""
62
+ return F.softmax(-distances / tau, dim=-1).sum(dim=-1)
63
+
64
+ @dataclass
65
+ class GeometricConfig:
66
+ """Configuration for geometric attention."""
67
+ softmin_tau: float = 0.05
68
+ fuse_alpha: float = 0.7
69
+ phases: Tuple[float, ...] = (0.0, math.pi/2, math.pi, 3*math.pi/2)
70
+ jitter: float = 0.02
71
+ shift: float = 0.71
72
+ rotate_cycle: int = 11
73
+ use_phase_variance: bool = False
74
+ geometry_type: str = "pentachoron"
75
+
76
+ class GeometricNavigator(nn.Module):
77
+ """Maps inputs to geometric regions in 4D space - FIXED with immediate initialization."""
78
+
79
+ def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig, num_heads: int = 1, device=None):
80
+ super().__init__()
81
+ self.input_dim = input_dim
82
+ self.num_regions = num_regions
83
+ self.config = config
84
+ self.num_heads = num_heads
85
+
86
+ # Use CPU by default if device not specified
87
+ if device is None:
88
+ device = torch.device('cpu')
89
+
90
+ # Create separate parameters for each head if num_heads > 1
91
+ if num_heads > 1:
92
+ self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4, device=device) * 0.02)
93
+ self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5, device=device))
94
+ else:
95
+ self.to_nav = nn.Linear(input_dim, 4, bias=False)
96
+ self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5, device=device))
97
+
98
+ # Pre-compute phase tensors for vectorization
99
+ self.register_buffer('phase_cos', torch.cos(torch.tensor(config.phases, dtype=torch.float32, device=device)))
100
+ self.register_buffer('phase_sin', torch.sin(torch.tensor(config.phases, dtype=torch.float32, device=device)))
101
+
102
+ # Initialize geometry immediately at creation time
103
+ self._init_geometry(device)
104
+
105
+ def _init_geometry(self, device):
106
+ """Initialize geometry at module creation time."""
107
+ base = perfect_4simplex(device)
108
+
109
+ if self.num_heads > 1:
110
+ D = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)
111
+ S = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device)
112
+
113
+ for h in range(self.num_heads):
114
+ for r in range(self.num_regions):
115
+ D[h, r] = base + self.config.jitter * torch.randn_like(base)
116
+
117
+ theta = torch.tensor(0.2914 + 0.05 * (r % self.config.rotate_cycle), device=device)
118
+ rot = torch.eye(4, device=device)
119
+ c, s_val = torch.cos(theta), torch.sin(theta)
120
+ rot[0, 0] = c; rot[0, 1] = -s_val
121
+ rot[1, 0] = s_val; rot[1, 1] = c
122
+ S[h, r] = (base @ rot) + self.config.shift
123
+ S[h, r] += self.config.jitter * torch.randn_like(S[h, r])
124
+ else:
125
+ D = torch.zeros(self.num_regions, 5, 4, device=device)
126
+ S = torch.zeros(self.num_regions, 5, 4, device=device)
127
+
128
+ for r in range(self.num_regions):
129
+ D[r] = base + self.config.jitter * torch.randn_like(base)
130
+
131
+ theta = torch.tensor(0.2914 + 0.05 * (r % self.config.rotate_cycle), device=device)
132
+ rot = torch.eye(4, device=device)
133
+ c, s_val = torch.cos(theta), torch.sin(theta)
134
+ rot[0, 0] = c; rot[0, 1] = -s_val
135
+ rot[1, 0] = s_val; rot[1, 1] = c
136
+ S[r] = (base @ rot) + self.config.shift
137
+ S[r] += self.config.jitter * torch.randn_like(S[r])
138
+
139
+ self.D = nn.Parameter(D)
140
+ self.S = nn.Parameter(S)
141
+
142
+ def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
143
+ """Navigate inputs through geometric space - OPTIMIZED with vectorized phase computation."""
144
+ if self.num_heads > 1:
145
+ # Batched navigation for multiple heads
146
+ BT, H, head_dim = x.shape
147
+
148
+ # Batched transformation
149
+ nav_x = torch.einsum('bhi,hio->bho', x, self.to_nav) # [BT, H, 4]
150
+
151
+ # Dispatcher scores
152
+ nav_x_disp = nav_x.view(BT, H, 1, 1, 4)
153
+ D_exp = self.D.unsqueeze(0) # [1, H, regions, 5, 4]
154
+ d_disp = torch.norm(nav_x_disp - D_exp, dim=-1)
155
+ s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)
156
+
157
+ # OPTIMIZED: Vectorized phase computation (no loop)
158
+ cos_phases = self.phase_cos.view(-1, 1, 1, 1, 1)
159
+ sin_phases = self.phase_sin.view(-1, 1, 1, 1, 1)
160
+
161
+ # Compute all phase variants at once [phases, H, regions, 5, 4]
162
+ Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
163
+
164
+ # Apply vertex weighting to all phases
165
+ w = F.softmax(self.vertex_w, dim=-1)
166
+ w_exp = w.unsqueeze(0).unsqueeze(-1) # [1, H, regions, 5, 1]
167
+ Vt_mean = Vt_all.mean(dim=3, keepdim=True)
168
+ Vt_all = (1.0 - w_exp) * Vt_all + w_exp * Vt_mean
169
+
170
+ # Compute all ribbon distances at once
171
+ nav_x_ribbon = nav_x.view(BT, 1, H, 1, 1, 4)
172
+ Vt_exp = Vt_all.unsqueeze(0) # [1, phases, H, regions, 5, 4]
173
+ d_ribbon_all = torch.norm(nav_x_ribbon - Vt_exp, dim=-1)
174
+ s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
175
+ s_ribbon = s_ribbon_all.mean(dim=1) # Average over phases
176
+
177
+ scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp
178
+ scores = scores.reshape(BT * H, self.num_regions)
179
+
180
+ else:
181
+ # Original single-head navigation
182
+ nav_x = self.to_nav(x)
183
+ nav_x_exp = nav_x[:, None, None, :]
184
+ D_exp = self.D[None, :, :, :]
185
+
186
+ d_disp = torch.norm(nav_x_exp - D_exp, dim=-1)
187
+ s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)
188
+
189
+ w = F.softmax(self.vertex_w, dim=1)
190
+
191
+ # OPTIMIZED: Vectorized phase computation for single head
192
+ cos_phases = self.phase_cos.view(-1, 1, 1, 1)
193
+ sin_phases = self.phase_sin.view(-1, 1, 1, 1)
194
+
195
+ Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)
196
+ w_expanded = w.unsqueeze(0).unsqueeze(-1)
197
+ Vt_mean = Vt_all.mean(dim=2, keepdim=True)
198
+ Vt_all = (1.0 - w_expanded) * Vt_all + w_expanded * Vt_mean
199
+
200
+ nav_x_phase = nav_x[:, None, None, None, :]
201
+ Vt_exp = Vt_all.unsqueeze(0)
202
+ d_ribbon_all = torch.norm(nav_x_phase - Vt_exp, dim=-1)
203
+ s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
204
+ s_ribbon = s_ribbon_all.mean(dim=1)
205
+
206
+ scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp
207
+
208
+ diagnostics = {
209
+ 'dispatcher_scores': s_disp.detach() if self.num_heads == 1 else s_disp.reshape(BT * H, -1).detach(),
210
+ 'ribbon_scores': s_ribbon.detach() if self.num_heads == 1 else s_ribbon.reshape(BT * H, -1).detach()
211
+ }
212
+
213
+ return {'scores': scores, 'diagnostics': diagnostics}
214
+
215
+ class GeometricAttention(nn.Module):
216
+ """Multi-head geometric attention with Q-K alignment - FIXED with proper device handling."""
217
+
218
+ def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
219
+ config: Optional[GeometricConfig] = None, dropout: float = 0.0, device=None):
220
+ super().__init__()
221
+ self.dim = dim
222
+ self.num_heads = num_heads
223
+ self.head_dim = dim // num_heads
224
+
225
+ if num_regions is None:
226
+ num_regions = min(self.head_dim, 16)
227
+ if config is None:
228
+ config = GeometricConfig()
229
+
230
+ self.config = config
231
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
232
+
233
+ # Create batched navigators with device
234
+ self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads, device=device)
235
+ self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config, num_heads=num_heads, device=device)
236
+
237
+ self.out_proj = nn.Linear(dim, dim)
238
+ self.dropout = nn.Dropout(dropout)
239
+
240
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
241
+ return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
242
+ B, T, D = x.shape
243
+
244
+ qkv = self.to_qkv(x)
245
+ q, k, v = qkv.chunk(3, dim=-1)
246
+
247
+ q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
248
+ k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
249
+ v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
250
+
251
+ # Prepare for batched navigation
252
+ q_batched = q.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)
253
+ k_batched = k.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)
254
+
255
+ # Navigate all heads at once
256
+ q_nav = self.q_navigator.navigate(q_batched)
257
+ k_nav = self.k_navigator.navigate(k_batched)
258
+
259
+ # Reshape scores back
260
+ q_scores = q_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)
261
+ k_scores = k_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)
262
+
263
+ # OPTIMIZED: Compute attention for all heads at once using einsum
264
+ scale = math.sqrt(q_scores.size(-1))
265
+ attn = torch.einsum('bhqr,bhkr->bhqk', q_scores, k_scores) / scale
266
+
267
+ if mask is not None:
268
+ mask_expanded = mask.unsqueeze(1).unsqueeze(2)
269
+ attn = attn.masked_fill(mask_expanded == 0, -1e9)
270
+
271
+ attn = F.softmax(attn, dim=-1)
272
+ attn = self.dropout(attn)
273
+
274
+ # Apply attention to values
275
+ out = torch.einsum('bhqk,bhkd->bhqd', attn, v)
276
+ out = out.transpose(1, 2).reshape(B, T, D)
277
+
278
+ output = self.out_proj(out)
279
+ output = self.dropout(output)
280
+
281
+ if return_diagnostics:
282
+ return output, {'q_diagnostics': q_nav['diagnostics'], 'k_diagnostics': k_nav['diagnostics']}
283
+ return output, None
284
+
285
+ # ============================================
286
+ # DROP PATH (STOCHASTIC DEPTH)
287
+ # ============================================
288
+
289
+ class DropPath(nn.Module):
290
+ """Drop paths (Stochastic Depth) per sample."""
291
+ def __init__(self, drop_prob: float = 0.):
292
+ super().__init__()
293
+ self.drop_prob = drop_prob
294
+
295
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
296
+ if self.drop_prob == 0. or not self.training:
297
+ return x
298
+ keep_prob = 1 - self.drop_prob
299
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
300
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
301
+ random_tensor.floor_()
302
+ output = x.div(keep_prob) * random_tensor
303
+ return output
304
+
305
+ # ============================================
306
+ # HIERARCHICAL CLS WITH PENTACHORA
307
+ # ============================================
308
+
309
+ class HierarchicalPentachoronCLS(nn.Module):
310
+ """
311
+ Hierarchical CLS structure with pentachoron geometry.
312
+ Uses vocabulary embeddings for CLS tokens.
313
+ """
314
+ def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
315
+ super().__init__()
316
+ self.dim = dim
317
+ self.vocab_dim = vocab_dim
318
+ self.num_classes = num_classes
319
+
320
+ # Class-specific pentachora from vocabulary
321
+ self.register_buffer('class_pentachora', torch.randn(num_classes, 5, vocab_dim) * 0.02)
322
+
323
+ # Projection from vocabulary dimension to model dimension
324
+ if vocab_dim != dim:
325
+ self.vocab_to_model = nn.Linear(vocab_dim, dim)
326
+ else:
327
+ self.vocab_to_model = nn.Identity()
328
+
329
+ # Learnable aggregation weights
330
+ self.vertex_weights = nn.Parameter(torch.ones(5) / 5)
331
+
332
+ # Optional learnable offset
333
+ self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))
334
+
335
+ # Layer norms
336
+ self.vertex_norm = nn.LayerNorm(dim)
337
+ self.global_norm = nn.LayerNorm(dim)
338
+
339
+ def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
340
+ """Generate CLS tokens for batch."""
341
+ # Get class-specific pentachora
342
+ class_pentachora = self.class_pentachora # This is now a computed property
343
+
344
+ if class_indices is not None and class_indices.shape[0] == batch_size:
345
+ vertex_cls_vocab = class_pentachora[class_indices]
346
+ else:
347
+ vertex_cls_vocab = class_pentachora.mean(dim=0, keepdim=True)
348
+ vertex_cls_vocab = vertex_cls_vocab.expand(batch_size, -1, -1)
349
+
350
+ # Project from vocabulary dimension to model dimension
351
+ vertex_cls = self.vocab_to_model(vertex_cls_vocab)
352
+ vertex_cls = self.vertex_norm(vertex_cls)
353
+
354
+ # Create global CLS as weighted combination
355
+ weights = F.softmax(self.vertex_weights, dim=0)
356
+ global_cls = torch.einsum('bvd,v->bd', vertex_cls, weights).unsqueeze(1)
357
+ global_cls = global_cls + self.global_offset
358
+ global_cls = self.global_norm(global_cls)
359
+
360
+ return global_cls, vertex_cls
361
+
362
+ def get_class_prototypes(self) -> torch.Tensor:
363
+ """Get class prototypes in model dimension."""
364
+ class_pentachora = self.class_pentachora # Get computed pentachora
365
+ pentachora_model = self.vocab_to_model(class_pentachora)
366
+ weights = F.softmax(self.vertex_weights, dim=0)
367
+ prototypes = torch.einsum('cvd,v->cd', pentachora_model, weights)
368
+ return prototypes
369
+
370
+ # ============================================
371
+ # GEOMETRIC PROJECTION LAYER
372
+ # ============================================
373
+
374
+ class GeometricProjection(nn.Module):
375
+ """Project patches onto pentachoron geometry."""
376
+ def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
377
+ super().__init__()
378
+ self.dim = dim
379
+ self.vocab_dim = vocab_dim
380
+ self.num_classes = num_classes
381
+
382
+ # Projection from model dim to vocab dim
383
+ self.to_vocab_space = nn.Linear(dim, vocab_dim)
384
+
385
+ # Vertex-specific projections
386
+ self.vertex_projections = nn.ModuleList([
387
+ nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
388
+ ])
389
+
390
+ # Temperature for alignment scores
391
+ self.temperature = nn.Parameter(torch.ones(1))
392
+
393
+ self.norm = nn.LayerNorm(dim)
394
+ self.dropout = nn.Dropout(dropout)
395
+
396
+ def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
397
+ """Compute alignment between patches and class pentachora."""
398
+ B, N, D = patches.shape
399
+ C = pentachora.shape[0]
400
+
401
+ # Normalize patches
402
+ patches = self.norm(patches)
403
+
404
+ # Project patches to vocabulary space
405
+ patches_vocab = self.to_vocab_space(patches)
406
+ patches_vocab = F.normalize(patches_vocab, dim=-1)
407
+
408
+ # Compute alignment with each vertex
409
+ alignments = []
410
+ for v in range(5):
411
+ patches_v = self.vertex_projections[v](patches_vocab)
412
+ patches_v = F.normalize(patches_v, dim=-1)
413
+ vertex_v = F.normalize(pentachora[:, v, :], dim=-1)
414
+ alignment = torch.matmul(patches_v, vertex_v.T) / self.temperature
415
+ alignments.append(alignment)
416
+
417
+ # Average alignments across vertices
418
+ alignments = torch.stack(alignments, dim=-1).mean(dim=-1)
419
+
420
+ return self.dropout(alignments)
421
+
422
+ # ============================================
423
+ # MLP BLOCK
424
+ # ============================================
425
+
426
+ class MLP(nn.Module):
427
+ """MLP block with GELU activation."""
428
+ def __init__(self, in_features: int, hidden_features: Optional[int] = None,
429
+ out_features: Optional[int] = None, dropout: float = 0.):
430
+ super().__init__()
431
+ out_features = out_features or in_features
432
+ hidden_features = hidden_features or in_features
433
+
434
+ self.fc1 = nn.Linear(in_features, hidden_features)
435
+ self.act = nn.GELU()
436
+ self.drop1 = nn.Dropout(dropout)
437
+ self.fc2 = nn.Linear(hidden_features, out_features)
438
+ self.drop2 = nn.Dropout(dropout)
439
+
440
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
441
+ x = self.fc1(x)
442
+ x = self.act(x)
443
+ x = self.drop1(x)
444
+ x = self.fc2(x)
445
+ x = self.drop2(x)
446
+ return x
447
+
448
+ # ============================================
449
+ # VIT BLOCK WITH GEOMETRIC ATTENTION
450
+ # ============================================
451
+
452
+ class PentachoronViTBlock(nn.Module):
453
+ """ViT block with geometric attention for structured layers."""
454
+ def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
455
+ use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
456
+ drop_path: float = 0., device=None):
457
+ super().__init__()
458
+ self.norm1 = nn.LayerNorm(dim)
459
+
460
+ # Use GeometricAttention for structured layers, standard for others
461
+ if use_mesh:
462
+ self.attn = GeometricAttention(
463
+ dim=dim,
464
+ num_heads=heads,
465
+ num_regions=min(dim // heads, 16),
466
+ config=GeometricConfig(),
467
+ dropout=attn_dropout,
468
+ device=device
469
+ )
470
+ else:
471
+ # Standard multi-head attention for later layers
472
+ self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)
473
+
474
+ self.use_mesh = use_mesh
475
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
476
+
477
+ self.norm2 = nn.LayerNorm(dim)
478
+ mlp_hidden = int(dim * mlp_ratio)
479
+ self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden, dropout=dropout)
480
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
481
+
482
+ def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
483
+ if self.use_mesh:
484
+ # GeometricAttention
485
+ attn_out, _ = self.attn(self.norm1(x))
486
+ x = x + self.drop_path1(attn_out)
487
+ else:
488
+ # Standard attention
489
+ normalized = self.norm1(x)
490
+ attn_out, _ = self.attn(normalized, normalized, normalized)
491
+ x = x + self.drop_path1(attn_out)
492
+
493
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
494
+ return x
495
+
496
+ # ============================================
497
+ # PATCH EMBEDDING
498
+ # ============================================
499
+
500
+ class PatchEmbed(nn.Module):
501
+ """2D Image to Patch Embedding."""
502
+ def __init__(self, img_size: int = 32, patch_size: int = 4,
503
+ in_chans: int = 3, embed_dim: int = 512):
504
+ super().__init__()
505
+ self.img_size = img_size
506
+ self.patch_size = patch_size
507
+ self.num_patches = (img_size // patch_size) ** 2
508
+
509
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
510
+ self.norm = nn.LayerNorm(embed_dim)
511
+
512
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
513
+ x = self.proj(x)
514
+ x = rearrange(x, 'b c h w -> b (h w) c')
515
+ x = self.norm(x)
516
+ return x
517
+
518
+ # ============================================
519
+ # PENTACHORA VISION TRANSFORMER
520
+ # ============================================
521
+
522
+ class PentachoraViT(nn.Module):
523
+ """
524
+ Vision Transformer with pentachoron-based hierarchical CLS tokens
525
+ and geometric vocabulary integration.
526
+ """
527
+ def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
528
+ super().__init__()
529
+
530
+ # Use config or kwargs
531
+ if config is not None:
532
+ cfg = config
533
+ else:
534
+ cfg = PentachoraConfig(**kwargs)
535
+
536
+ self.config = cfg
537
+ self.num_classes = cfg.num_classes
538
+ self.dim = cfg.dim
539
+ self.depth = cfg.depth
540
+ self.preserve_structure_until_layer = cfg.preserve_structure_until_layer
541
+
542
+ # Set vocabulary dimension
543
+ if cfg.vocab_dim is not None:
544
+ self.vocab_dim = cfg.vocab_dim
545
+ elif 'vocab_dim' in kwargs:
546
+ self.vocab_dim = kwargs['vocab_dim']
547
+ else:
548
+ self.vocab_dim = cfg.dim
549
+
550
+ # Patch embedding
551
+ self.patch_embed = PatchEmbed(
552
+ cfg.img_size, cfg.patch_size, 3, cfg.dim
553
+ )
554
+ num_patches = self.patch_embed.num_patches
555
+
556
+ # Positional embedding
557
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
558
+ self.pos_drop = nn.Dropout(cfg.dropout_rate)
559
+
560
+ # CLS tokens with pentachoron structure
561
+ self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)
562
+
563
+ # Geometric projection layer
564
+ self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)
565
+
566
+ # Initialize from vocabulary if provided
567
+ if cfg.vocab is not None:
568
+ self._init_from_vocab(cfg.vocab)
569
+
570
+ # Stochastic depth decay rule
571
+ dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]
572
+
573
+ # Transformer blocks with geometric attention
574
+ self.blocks = nn.ModuleList([
575
+ PentachoronViTBlock(
576
+ dim=cfg.dim,
577
+ heads=cfg.heads,
578
+ mlp_ratio=cfg.mlp_ratio,
579
+ use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
580
+ dropout=cfg.dropout_rate,
581
+ attn_dropout=cfg.dropout_rate,
582
+ drop_path=dpr[i],
583
+ device=torch.device('cpu') # Initialize on CPU, will be moved later
584
+ )
585
+ for i in range(cfg.depth)
586
+ ])
587
+
588
+ # Final norm
589
+ self.norm = nn.LayerNorm(cfg.dim)
590
+
591
+ # Classification heads
592
+ self.use_prototype_classifier = True
593
+ if self.use_prototype_classifier:
594
+ self.head = None
595
+ else:
596
+ self.head = nn.Linear(cfg.dim, cfg.num_classes)
597
+
598
+ # Auxiliary head for vertex tokens
599
+ self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)
600
+
601
+ # Initialize weights
602
+ self.apply(self._init_weights)
603
+
604
+ def _init_weights(self, m: nn.Module):
605
+ """Initialize model weights."""
606
+ if isinstance(m, nn.Linear):
607
+ nn.init.trunc_normal_(m.weight, std=0.02)
608
+ if m.bias is not None:
609
+ nn.init.constant_(m.bias, 0)
610
+ elif isinstance(m, nn.LayerNorm):
611
+ nn.init.constant_(m.bias, 0)
612
+ nn.init.constant_(m.weight, 1.0)
613
+ elif isinstance(m, nn.Conv2d):
614
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
615
+ if m.bias is not None:
616
+ nn.init.constant_(m.bias, 0)
617
+
618
+ def _init_from_vocab(self, vocab):
619
+ """Initialize class pentachora from geometric vocabulary."""
620
+ try:
621
+ print("Initializing pentachora from vocabulary...")
622
+
623
+ if not hasattr(vocab, 'encode_batch'):
624
+ print("Vocabulary provided but encode_batch method not found, using random initialization")
625
+ return
626
+
627
+ # Get CIFAR-100 class names
628
+ class_names = self._get_cifar100_classes()
629
+
630
+ # Generate pentachora for all classes
631
+ pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
632
+ pentachora = np.stack(pentachora_list, axis=0)
633
+
634
+ # Get actual dimensions from the encoded data
635
+ actual_vocab_dim = pentachora.shape[-1]
636
+
637
+ print(f"Encoded pentachora shape: {pentachora.shape}")
638
+ print(f"Detected vocabulary dimension: {actual_vocab_dim}")
639
+
640
+ # Validate basic shape requirements
641
+ if pentachora.shape[0] != self.num_classes or pentachora.shape[1] != 5:
642
+ print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
643
+ print("Using random initialization")
644
+ return
645
+
646
+ # Update vocabulary dimension
647
+ self.vocab_dim = actual_vocab_dim
648
+ self.cls_tokens.vocab_dim = actual_vocab_dim
649
+ self.geometric_proj.vocab_dim = actual_vocab_dim
650
+
651
+ # Replace class_pentachora with the loaded vocabulary
652
+ self.cls_tokens.class_pentachora = torch.tensor(pentachora, dtype=torch.float32)
653
+
654
+ # Update/create projection layer if dimensions differ
655
+ if actual_vocab_dim != self.dim:
656
+ self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
657
+ else:
658
+ self.cls_tokens.vocab_to_model = nn.Identity()
659
+
660
+ # Rebuild geometric projection components
661
+ self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
662
+ self.geometric_proj.vertex_projections = nn.ModuleList([
663
+ nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
664
+ ])
665
+
666
+ # Re-initialize the new layers
667
+ nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
668
+ for proj in self.geometric_proj.vertex_projections:
669
+ nn.init.xavier_uniform_(proj.weight)
670
+ if actual_vocab_dim != self.dim:
671
+ nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)
672
+
673
+ print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
674
+ print(f" Vocabulary dimension: {actual_vocab_dim}")
675
+ print(f" Model internal dimension: {self.dim}")
676
+
677
+ except Exception as e:
678
+ print(f"Error initializing from vocabulary: {e}")
679
+ print("Using random initialization")
680
+
681
+ def _get_cifar100_classes(self):
682
+ """Get CIFAR-100 class names."""
683
+ return [
684
+ 'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
685
+ 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
686
+ 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
687
+ 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
688
+ 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
689
+ 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
690
+ 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
691
+ 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
692
+ 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
693
+ 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
694
+ 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
695
+ 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
696
+ 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
697
+ 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
698
+ ]
699
+
700
+ def forward_features(self, x: torch.Tensor, class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
701
+ """Extract features from input."""
702
+ B = x.shape[0]
703
+
704
+ # Patch embedding
705
+ x = self.patch_embed(x)
706
+ x = x + self.pos_embed
707
+ x = self.pos_drop(x)
708
+
709
+ # Get hierarchical CLS tokens
710
+ global_cls, vertex_cls = self.cls_tokens(B, class_indices)
711
+
712
+ # Concatenate CLS tokens with patches
713
+ x = torch.cat([global_cls, vertex_cls, x], dim=1)
714
+
715
+ # Apply transformer blocks
716
+ for i, block in enumerate(self.blocks):
717
+ preserve = i < self.preserve_structure_until_layer
718
+ x = block(x, preserve_structure=preserve)
719
+
720
+ # Apply final norm
721
+ x = self.norm(x)
722
+
723
+ # Split tokens
724
+ global_cls = x[:, 0]
725
+ vertex_cls = x[:, 1:6]
726
+ patches = x[:, 6:]
727
+
728
+ return {
729
+ 'global_cls': global_cls,
730
+ 'vertex_cls': vertex_cls,
731
+ 'patches': patches
732
+ }
733
+
734
+ def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
735
+ """Forward pass through the model."""
736
+ # During training, use target labels for class-specific CLS initialization
737
+ class_indices = targets if self.training and targets is not None else None
738
+
739
+ features = self.forward_features(x, class_indices)
740
+
741
+ # Primary classification using prototype matching
742
+ if self.use_prototype_classifier:
743
+ prototypes = self.cls_tokens.get_class_prototypes()
744
+ prototypes = F.normalize(prototypes, dim=-1)
745
+ global_cls_norm = F.normalize(features['global_cls'], dim=-1)
746
+ logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0
747
+ else:
748
+ logits = self.head(features['global_cls'])
749
+
750
+ # Auxiliary classification using vertex tokens
751
+ B = features['vertex_cls'].shape[0]
752
+ vertex_flat = features['vertex_cls'].reshape(B, -1)
753
+ aux_logits = self.head_aux(vertex_flat)
754
+
755
+ # Geometric alignment scores
756
+ geometric_alignments = self.geometric_proj(
757
+ features['patches'],
758
+ self.cls_tokens.class_pentachora
759
+ )
760
+
761
+ return {
762
+ 'logits': logits,
763
+ 'aux_logits': aux_logits,
764
+ 'geometric_alignments': geometric_alignments,
765
+ 'vertex_cls': features['vertex_cls'],
766
+ 'global_cls': features['global_cls'],
767
+ 'patches': features['patches']
768
+ }
769
+
770
+ # ============================================
771
+ # LOSS FUNCTIONS
772
+ # ============================================
773
+
774
+ class PentachoraLoss(nn.Module):
775
+ """Combined loss for PentachoraViT training."""
776
+ def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1,
777
+ smoothing: float = 0.0):
778
+ super().__init__()
779
+ self.aux_weight = aux_weight
780
+ self.geo_weight = geo_weight
781
+ self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)
782
+
783
+ def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
784
+ """Compute combined loss."""
785
+ # Primary classification loss
786
+ loss = self.criterion(outputs['logits'], targets)
787
+
788
+ # Auxiliary loss from vertex tokens
789
+ if 'aux_logits' in outputs and self.aux_weight > 0:
790
+ aux_loss = self.criterion(outputs['aux_logits'], targets)
791
+ loss = loss + self.aux_weight * aux_loss
792
+
793
+ # Geometric alignment loss
794
+ if 'geometric_alignments' in outputs and self.geo_weight > 0:
795
+ geo_logits = outputs['geometric_alignments'].mean(dim=1)
796
+ geo_loss = self.criterion(geo_logits, targets)
797
+ loss = loss + self.geo_weight * geo_loss
798
+
799
+ return loss
800
+
801
+ # ============================================
802
+ # MODEL REGISTRY AND BUILDERS
803
+ # ============================================
804
+
805
+ MODEL_CONFIGS = {
806
+ 'pentachora_spark_xs': PentachoraConfig(
807
+ dim=100, depth=2, heads=10, mlp_ratio=4.0,
808
+ preserve_structure_until_layer=1,
809
+ dropout_rate=0.0, drop_path_rate=0.0
810
+ ),
811
+ 'pentachora_spark': PentachoraConfig(
812
+ dim=100, depth=5, heads=4, mlp_ratio=4.0,
813
+ preserve_structure_until_layer=1,
814
+ dropout_rate=0.0, drop_path_rate=0.0
815
+ ),
816
+ 'pentachora_shock': PentachoraConfig(
817
+ dim=100, depth=10, heads=5, mlp_ratio=4.0,
818
+ patch_size=5, preserve_structure_until_layer=5,
819
+ dropout_rate=0.0, drop_path_rate=0.0
820
+ ),
821
+ 'pentachora_shock_xs_32d': PentachoraConfig(
822
+ dim=32, depth=2, heads=8, mlp_ratio=4.0,
823
+ preserve_structure_until_layer=4,
824
+ dropout_rate=0.0, drop_path_rate=0.0
825
+ ),
826
+ 'pentachora_shock_xs_64d': PentachoraConfig(
827
+ dim=64, depth=2, heads=8, mlp_ratio=1.0,
828
+ preserve_structure_until_layer=4,
829
+ dropout_rate=0.0, drop_path_rate=0.0
830
+ ),
831
+ 'pentachora_shock_xs_128d': PentachoraConfig(
832
+ dim=128, depth=2, heads=8, mlp_ratio=2.0,
833
+ preserve_structure_until_layer=4,
834
+ vocab_dim=256,
835
+ dropout_rate=0.0, drop_path_rate=0.0
836
+ ),
837
+ 'vit_tinkerbell_128_patch8_h128_shallow': PentachoraConfig(
838
+ dim=128, depth=4, heads=128, mlp_ratio=4.0,
839
+ preserve_structure_until_layer=4,
840
+ vocab_dim=128, patch_size=8,
841
+ dropout_rate=0.0, drop_path_rate=0.0
842
+ ),
843
+ 'vit_tinkerbell_128_patch8_h128_base': PentachoraConfig(
844
+ dim=128, depth=8, heads=128, mlp_ratio=4.0,
845
+ preserve_structure_until_layer=8,
846
+ vocab_dim=128, patch_size=8,
847
+ dropout_rate=0.0, drop_path_rate=0.0
848
+ ),
849
+ 'vit_tinkerbell_128_patch8_h128_deep': PentachoraConfig(
850
+ dim=128, depth=16, heads=128, mlp_ratio=4.0,
851
+ preserve_structure_until_layer=16,
852
+ vocab_dim=128, patch_size=8,
853
+ dropout_rate=0.0, drop_path_rate=0.0
854
+ ),
855
+ 'vit_pixie_128_patch4_echo': PentachoraConfig(
856
+ dim=128, depth=5, heads=32, mlp_ratio=1.0,
857
+ preserve_structure_until_layer=5,
858
+ vocab_dim=128, patch_size=4,
859
+ dropout_rate=0.0, drop_path_rate=0.0
860
+ ),
861
+ 'vit_pixie_128_patch4_echo_h64': PentachoraConfig(
862
+ dim=128, depth=5, heads=64, mlp_ratio=1.0,
863
+ preserve_structure_until_layer=5,
864
+ vocab_dim=128, patch_size=4,
865
+ dropout_rate=0.0, drop_path_rate=0.0
866
+ ),
867
+ 'vit_pixie_128_patch4_echo_h128': PentachoraConfig(
868
+ dim=128, depth=5, heads=128, mlp_ratio=1.0,
869
+ preserve_structure_until_layer=5,
870
+ vocab_dim=128, patch_size=4,
871
+ dropout_rate=0.0, drop_path_rate=0.0
872
+ ),
873
+ 'vit_pixie_256_patch4_echo_h64': PentachoraConfig(
874
+ dim=256, depth=5, heads=64, mlp_ratio=1.0,
875
+ preserve_structure_until_layer=5,
876
+ vocab_dim=256, patch_size=4,
877
+ dropout_rate=0.0, drop_path_rate=0.0
878
+ ),
879
+ 'vit_pixie_256_patch4_echo_h256': PentachoraConfig(
880
+ dim=256, depth=5, heads=256, mlp_ratio=2.0,
881
+ preserve_structure_until_layer=5,
882
+ vocab_dim=256, patch_size=4,
883
+ dropout_rate=0.0, drop_path_rate=0.0
884
+ ),
885
+ 'vit_pixie_128_patch4': PentachoraConfig(
886
+ dim=128, depth=10, heads=16, mlp_ratio=1.0,
887
+ preserve_structure_until_layer=10,
888
+ vocab_dim=128, patch_size=4,
889
+ dropout_rate=0.0, drop_path_rate=0.0
890
+ ),
891
+ 'vit_pixie_256_patch4': PentachoraConfig(
892
+ dim=256, depth=10, heads=16, mlp_ratio=1.0,
893
+ preserve_structure_until_layer=10,
894
+ vocab_dim=256, patch_size=4,
895
+ dropout_rate=0.0, drop_path_rate=0.0
896
+ ),
897
+ 'vit_pixie_256_patch2': PentachoraConfig(
898
+ dim=256, depth=10, heads=16, mlp_ratio=1.0,
899
+ preserve_structure_until_layer=10,
900
+ vocab_dim=256, patch_size=2,
901
+ dropout_rate=0.0, drop_path_rate=0.0
902
+ ),
903
+ 'vit_pixie_256_patch8': PentachoraConfig(
904
+ dim=256, depth=10, heads=16, mlp_ratio=4.0,
905
+ preserve_structure_until_layer=10,
906
+ vocab_dim=256, patch_size=8,
907
+ dropout_rate=0.0, drop_path_rate=0.0
908
+ ),
909
+ 'vit_pixie_512_patch4': PentachoraConfig(
910
+ dim=512, depth=10, heads=8, mlp_ratio=4.0,
911
+ preserve_structure_until_layer=10,
912
+ vocab_dim=512,
913
+ dropout_rate=0.0, drop_path_rate=0.0
914
+ ),
915
+ 'pentachora_shock_xs_256d': PentachoraConfig(
916
+ dim=256, depth=2, heads=8, mlp_ratio=4.0,
917
+ preserve_structure_until_layer=4,
918
+ vocab_dim=128,
919
+ dropout_rate=0.0, drop_path_rate=0.0
920
+ ),
921
+
922
+ 'pentachora_shock_xs_512d': PentachoraConfig(
923
+ dim=512, depth=2, heads=8, mlp_ratio=4.0,
924
+ preserve_structure_until_layer=4,
925
+ dropout_rate=0.0, drop_path_rate=0.0
926
+ ),
927
+ 'pentachora_tiny': PentachoraConfig(
928
+ dim=384, depth=12, heads=6, mlp_ratio=4.0,
929
+ preserve_structure_until_layer=6,
930
+ dropout_rate=0.1, drop_path_rate=0.1
931
+ ),
932
+ 'pentachora_small': PentachoraConfig(
933
+ dim=512, depth=12, heads=8, mlp_ratio=4.0,
934
+ preserve_structure_until_layer=6,
935
+ dropout_rate=0.1, drop_path_rate=0.1
936
+ ),
937
+ 'pentachora_base': PentachoraConfig(
938
+ dim=768, depth=12, heads=12, mlp_ratio=4.0,
939
+ preserve_structure_until_layer=8,
940
+ dropout_rate=0.1, drop_path_rate=0.2
941
+ ),
942
+ 'pentachora_large': PentachoraConfig(
943
+ dim=1024, depth=24, heads=16, mlp_ratio=4.0,
944
+ preserve_structure_until_layer=12,
945
+ dropout_rate=0.1, drop_path_rate=0.3
946
+ ),
947
+ }
948
+
949
+ def create_pentachora_vit(variant: str = 'pentachora_small',
950
+ pretrained: bool = False,
951
+ **kwargs) -> PentachoraViT:
952
+ """Create PentachoraViT model."""
953
+ if variant not in MODEL_CONFIGS:
954
+ raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")
955
+
956
+ config = MODEL_CONFIGS[variant]
957
+
958
+ # Override config with kwargs
959
+ for key, value in kwargs.items():
960
+ setattr(config, key, value)
961
+
962
+ model = PentachoraViT(config)
963
+
964
+ if pretrained:
965
+ warnings.warn("Pretrained weights not available yet")
966
+
967
+ return model
968
+
969
+ # Convenience functions for each variant
970
+ def pentachora_vit_spark_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
971
+ """Create spark variant (smallest)."""
972
+ return create_pentachora_vit('pentachora_spark_xs', pretrained=pretrained, **kwargs)
973
+
974
+ def pentachora_shock_xs_64d(pretrained: bool = False, **kwargs) -> PentachoraViT:
975
+ """Create shock xs 64d variant."""
976
+ return create_pentachora_vit('pentachora_shock_xs_64d', pretrained=pretrained, **kwargs)
977
+
978
+ def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
979
+ """Create spark variant."""
980
+ return create_pentachora_vit('pentachora_spark', pretrained=pretrained, **kwargs)
981
+
982
+ def pentachora_shock_xs_32d(pretrained: bool = False, **kwargs) -> PentachoraViT:
983
+ """Create shock xs 32d variant."""
984
+ return create_pentachora_vit('pentachora_shock_xs_32d', pretrained=pretrained, **kwargs)
985
+
986
+ def pentachora_shock_xs_256d(pretrained: bool = False, **kwargs) -> PentachoraViT:
987
+ """Create shock xs 256d variant."""
988
+ return create_pentachora_vit('pentachora_shock_xs_256d', pretrained=pretrained, **kwargs)
989
+
990
+ def pentachora_shock_xs_512d(pretrained: bool = False, **kwargs) -> PentachoraViT:
991
+ """Create shock xs 512d variant."""
992
+ return create_pentachora_vit('pentachora_shock_xs_512d', pretrained=pretrained, **kwargs)
993
+
994
+ def pentachora_vit_shock(pretrained: bool = False, **kwargs) -> PentachoraViT:
995
+ """Create shock variant."""
996
+ return create_pentachora_vit('pentachora_shock', pretrained=pretrained, **kwargs)
997
+
998
+ def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
999
+ """Create tiny variant."""
1000
+ return create_pentachora_vit('pentachora_tiny', pretrained=pretrained, **kwargs)
1001
+
1002
+ def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
1003
+ """Create small variant."""
1004
+ return create_pentachora_vit('pentachora_small', pretrained=pretrained, **kwargs)
1005
+
1006
+ def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
1007
+ """Create base variant."""
1008
+ return create_pentachora_vit('pentachora_base', pretrained=pretrained, **kwargs)
1009
+
1010
+ def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
1011
+ """Create large variant."""
1012
+ return create_pentachora_vit('pentachora_large', pretrained=pretrained, **kwargs)
1013
+
1014
+ # ============================================
1015
+ # TRAINING UTILITIES
1016
+ # ============================================
1017
+
1018
+ def get_parameter_groups(model: PentachoraViT,
1019
+ weight_decay: float = 0.05) -> List[Dict[str, Any]]:
1020
+ """Get parameter groups for optimizer with weight decay handling."""
1021
+ no_decay = ['bias', 'norm', 'LayerNorm']
1022
+
1023
+ decay_params = []
1024
+ no_decay_params = []
1025
+
1026
+ for name, param in model.named_parameters():
1027
+ if not param.requires_grad:
1028
+ continue
1029
+
1030
+ if any(nd in name for nd in no_decay):
1031
+ no_decay_params.append(param)
1032
+ else:
1033
+ decay_params.append(param)
1034
+
1035
+ return [
1036
+ {'params': decay_params, 'weight_decay': weight_decay},
1037
+ {'params': no_decay_params, 'weight_decay': 0.0}
1038
+ ]
1039
+
1040
+ def count_parameters(model: nn.Module) -> Dict[str, int]:
1041
+ """Count model parameters."""
1042
+ total = sum(p.numel() for p in model.parameters())
1043
+ trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
1044
+ return {
1045
+ 'total': total,
1046
+ 'trainable': trainable,
1047
+ 'non_trainable': total - trainable
1048
+ }
1049
+
1050
+ # ============================================
1051
+ # INFERENCE UTILITIES
1052
+ # ============================================
1053
+
1054
+ @torch.no_grad()
1055
+ def extract_features(model: PentachoraViT,
1056
+ images: torch.Tensor,
1057
+ feature_type: str = 'global_cls') -> torch.Tensor:
1058
+ """Extract features from images using the model."""
1059
+ model.eval()
1060
+ features = model.forward_features(images)
1061
+ return features.get(feature_type, features['global_cls'])
1062
+
1063
+ # ============================================
1064
+ # EXAMPLE USAGE AND TESTING
1065
+ # ============================================
1066
+
1067
+ def test_model():
1068
+ """Test model creation and forward pass."""
1069
+ print("Testing Fixed PentachoraViT Model")
1070
+ print("=" * 50)
1071
+
1072
+ # Test different variants
1073
+ variants = ['pentachora_spark', 'pentachora_shock_xs_256d', 'pentachora_small']
1074
+
1075
+ for variant in variants:
1076
+ print(f"\nTesting {variant}:")
1077
+
1078
+ # Create model with vocab_dim
1079
+ model = create_pentachora_vit(
1080
+ variant=variant,
1081
+ img_size=32,
1082
+ patch_size=4,
1083
+ num_classes=100,
1084
+ vocab_dim=64
1085
+ )
1086
+
1087
+ # Count parameters
1088
+ params = count_parameters(model)
1089
+ print(f" Total parameters: {params['total']:,}")
1090
+ print(f" Trainable parameters: {params['trainable']:,}")
1091
+
1092
+ # Test forward pass
1093
+ x = torch.randn(2, 3, 32, 32)
1094
+
1095
+ # Time the forward pass
1096
+ if torch.cuda.is_available():
1097
+ model = model.cuda()
1098
+ x = x.cuda()
1099
+ torch.cuda.synchronize()
1100
+
1101
+ import time
1102
+ start = time.time()
1103
+ outputs = model(x)
1104
+ if torch.cuda.is_available():
1105
+ torch.cuda.synchronize()
1106
+ end = time.time()
1107
+
1108
+ print(f" Output shapes:")
1109
+ print(f" Logits: {outputs['logits'].shape}")
1110
+ print(f" Aux logits: {outputs['aux_logits'].shape}")
1111
+ print(f" Geometric alignments: {outputs['geometric_alignments'].shape}")
1112
+ print(f" Forward pass time: {(end - start)*1000:.2f}ms")
1113
+
1114
+ # Test loss computation
1115
+ loss_fn = PentachoraLoss()
1116
+ targets = torch.randint(0, 100, (2,))
1117
+ if torch.cuda.is_available():
1118
+ targets = targets.cuda()
1119
+ loss = loss_fn(outputs, targets)
1120
+ print(f" Loss: {loss.item():.4f}")
1121
+
1122
+ print("\n" + "=" * 50)
1123
+ print("All tests passed!")
1124
+
1125
+ if __name__ == "__main__":
1126
+ # Run tests
1127
+ test_model()
1128
+
1129
+ # Example: Create model for training
1130
+ print("\nExample: Creating model with proper initialization")
1131
+ model = pentachora_shock_xs_256d(
1132
+ img_size=32,
1133
+ num_classes=100,
1134
+ vocab_dim=100,
1135
+ dropout_rate=0.0,
1136
+ drop_path_rate=0.0
1137
+ )
1138
+
1139
+ # All parameters are initialized immediately
1140
+ print(f"Model has {count_parameters(model)['total']:,} parameters")
1141
+ print("All geometric parameters initialized at creation time")
1142
+
1143
+ # Move model to CUDA if available
1144
+ if torch.cuda.is_available():
1145
+ model = model.cuda()
1146
+ print("Model moved to CUDA")
1147
+
1148
+ # Now torch.compile should work without issues
1149
+ if hasattr(torch, 'compile'):
1150
+ print("Compiling model with torch.compile...")
1151
+ try:
1152
+ model = torch.compile(model)
1153
+ print("✓ Model compiled successfully")
1154
+ except Exception as e:
1155
+ print(f"Compilation warning: {e}")
1156
+ print("Continuing without compilation")
1157
+
1158
+ print("\nModel ready for training with all parameters properly initialized!")