AbstractPhil committed on
Commit
ec394f4
·
verified ·
1 Parent(s): daf8c5e

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +950 -0
model.py ADDED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PentachoraViT: Vision Transformer with Pentachoron Geometric Structure
3
+ Enhanced with Geometric Attention for improved head cohesion and generalization
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from einops import rearrange, repeat
11
+ import math
12
+ from typing import Optional, Dict, Tuple, List, Any
13
+ from dataclasses import dataclass
14
+ import warnings
15
+
16
+ # ============================================
17
+ # CONFIGURATION CLASSES
18
+ # ============================================
19
+
20
@dataclass
class PentachoraConfig:
    """Hyper-parameter bundle for PentachoraViT models.

    ``vocab_dim`` may differ from the model width ``dim``; when left as
    ``None`` the model falls back to ``dim``.  ``vocab`` is an optional
    external vocabulary object used to seed the class pentachora.
    """
    img_size: int = 32
    patch_size: int = 4
    num_classes: int = 100
    dim: int = 512
    vocab_dim: Optional[int] = None  # Vocabulary dimension (can differ from model dim)
    depth: int = 12
    heads: int = 8
    mlp_ratio: float = 4.0
    use_mesh_attention: bool = True
    preserve_structure_until_layer: int = 6
    dropout_rate: float = 0.1
    drop_path_rate: float = 0.1
    aux_loss_weight: float = 0.3
    geo_loss_weight: float = 0.1
    vocab: Optional[Any] = None

    @property
    def num_patches(self) -> int:
        """Number of patch tokens on the (img_size // patch_size) grid."""
        side = self.img_size // self.patch_size
        return side * side
42
+
43
+ # ============================================
44
+ # GEOMETRIC ATTENTION COMPONENTS
45
+ # ============================================
46
+
47
def perfect_4simplex(device):
    """Return the five vertices of a regular 4-simplex (pentachoron).

    Uses the classic symmetric embedding in R^4, scaled by one half so
    the figure stays equilateral at a convenient size.
    """
    r5 = 5 ** 0.5
    coords = [
        [1.0, 1.0, 1.0, -1.0 / r5],
        [1.0, -1.0, -1.0, -1.0 / r5],
        [-1.0, 1.0, -1.0, -1.0 / r5],
        [-1.0, -1.0, 1.0, -1.0 / r5],
        [0.0, 0.0, 0.0, 4.0 / r5],
    ]
    simplex = torch.tensor(coords, device=device, dtype=torch.float32)
    return simplex * 0.5  # normalize scale
58
+
59
def softmin_over_last(distances, tau):
    """Smooth (differentiable) minimum over the last dimension.

    Returns the softmin-weighted average of ``distances``: each entry is
    weighted by softmax(-d / tau), so as tau -> 0 the result approaches
    the hard minimum and as tau -> inf it approaches the plain mean.

    BUGFIX: the previous implementation summed the softmax weights
    themselves (``F.softmax(...).sum(dim=-1)``), which is identically 1
    for every input and therefore carried no information about the
    distances — all navigator scores downstream degenerated to a constant.

    Args:
        distances: Tensor [..., K] of distances.
        tau: Temperature > 0 controlling the sharpness of the softmin.

    Returns:
        Tensor [...] — the smoothed minimum along the last dimension.
    """
    weights = F.softmax(-distances / tau, dim=-1)
    return (weights * distances).sum(dim=-1)
62
+
63
@dataclass
class GeometricConfig:
    """Tunables for geometric (pentachoron-based) attention."""
    softmin_tau: float = 0.05        # temperature of the softmin scorer
    fuse_alpha: float = 0.7          # ribbon-vs-dispatcher mixing weight
    phases: Tuple[float, ...] = (0.0, math.pi / 2, math.pi, 3 * math.pi / 2)
    jitter: float = 0.02             # random perturbation added to simplex vertices
    shift: float = 0.25              # translation applied to the S simplices
    rotate_cycle: int = 11           # period of the per-region rotation angle
    use_phase_variance: bool = False
    geometry_type: str = "pentachoron"
74
+
75
class GeometricNavigator(nn.Module):
    """Maps inputs to geometric regions in 4D space.

    Each region owns two pentachora: a "dispatcher" simplex D (jittered
    copy of the canonical 4-simplex) and a rotated/shifted "ribbon"
    simplex S.  Inputs are projected to 4D and scored by proximity to
    the simplex vertices.
    """

    def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig):
        super().__init__()
        self.input_dim = input_dim
        self.num_regions = num_regions
        self.config = config

        # 4D navigation projection (no bias, so the origin stays fixed).
        self.to_nav = nn.Linear(input_dim, 4, bias=False)
        # Per-region, per-vertex weights pulling vertices toward the centroid.
        self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5))

        # Geometry is built lazily on the first forward pass so it lands on
        # the right device; until then these are placeholder parameters.
        self.register_parameter('D', None)
        self.register_parameter('S', None)

    def _lazy_init_geometry(self, device):
        """Create the D/S simplices once, on the given device."""
        if self.D is not None:
            return

        base = perfect_4simplex(device)

        D = torch.zeros(self.num_regions, 5, 4, device=device)
        S = torch.zeros(self.num_regions, 5, 4, device=device)

        for r in range(self.num_regions):
            # Dispatcher: jittered copy of the canonical simplex.
            D[r] = base + self.config.jitter * torch.randn_like(base)

            # Ribbon: rotate in the xy-plane, translate, then jitter.
            theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device)
            rot = torch.eye(4, device=device)
            c, s_val = torch.cos(theta), torch.sin(theta)
            rot[0, 0] = c
            rot[0, 1] = -s_val
            rot[1, 0] = s_val
            rot[1, 1] = c
            S[r] = (base @ rot) + self.config.shift
            S[r] += self.config.jitter * torch.randn_like(S[r])

        # NOTE(review): parameters created during the first forward pass are
        # invisible to any optimizer constructed beforehand — confirm the
        # training loop builds the optimizer after a warm-up forward.
        self.D = nn.Parameter(D)
        self.S = nn.Parameter(S)

    def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Score each input against every region's simplices.

        Args:
            x: [B, input_dim] flattened inputs.

        Returns:
            dict with 'scores' [B, num_regions] and detached 'diagnostics'.
        """
        self._lazy_init_geometry(x.device)

        coords = self.to_nav(x)
        coords_exp = coords[:, None, None, :]      # [B, 1, 1, 4]
        disp_vertices = self.D[None, :, :, :]      # [1, R, 5, 4]

        # Dispatcher score: negative softmin distance to the D vertices.
        disp_dist = torch.norm(coords_exp - disp_vertices, dim=-1)
        disp_score = -softmin_over_last(disp_dist, self.config.softmin_tau)

        # Ribbon score: average over phase-interpolated D/S blends, with each
        # blend pulled toward its centroid by the learned vertex weights.
        centroid_mix = F.softmax(self.vertex_w, dim=1).unsqueeze(-1)
        per_phase = []

        for phase in self.config.phases:
            angle = torch.tensor(phase, device=x.device)
            blend = torch.cos(angle) * self.D + torch.sin(angle) * self.S
            blend_mean = blend.mean(dim=1, keepdim=True)
            blend = (1.0 - centroid_mix) * blend + centroid_mix * blend_mean

            ribbon_dist = torch.norm(coords_exp - blend[None, :, :, :], dim=-1)
            per_phase.append(-softmin_over_last(ribbon_dist, self.config.softmin_tau))

        ribbon_score = torch.stack(per_phase).mean(dim=0)
        fused = self.config.fuse_alpha * ribbon_score + (1 - self.config.fuse_alpha) * disp_score

        diagnostics = {
            'dispatcher_scores': disp_score.detach(),
            'ribbon_scores': ribbon_score.detach()
        }

        return {'scores': fused, 'diagnostics': diagnostics}
153
+
154
class GeometricAttention(nn.Module):
    """Multi-head attention whose logits come from geometric region scores.

    Instead of raw q·k dot products, each head projects queries and keys
    through its own GeometricNavigator and compares the resulting
    region-score profiles.
    """

    def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                 config: Optional[GeometricConfig] = None, dropout: float = 0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Sensible defaults when the caller does not pin them down.
        if num_regions is None:
            num_regions = min(self.head_dim, 16)
        if config is None:
            config = GeometricConfig()

        self.config = config
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # One navigator pair (query-side / key-side) per head.
        self.q_navigators = nn.ModuleList([
            GeometricNavigator(self.head_dim, num_regions, config)
            for _ in range(num_heads)
        ])
        self.k_navigators = nn.ModuleList([
            GeometricNavigator(self.head_dim, num_regions, config)
            for _ in range(num_heads)
        ])

        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
        """Apply geometric attention.

        Args:
            x: [B, T, D] token embeddings.
            mask: optional mask; zero entries are excluded as keys.
            return_diagnostics: also return per-head navigator diagnostics.

        Returns:
            ([B, T, D] output, diagnostics dict or None).
        """
        batch, tokens, width = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # [B, T, D] -> [B, H, T, head_dim]
            return t.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        head_outputs = []
        all_diagnostics = [] if return_diagnostics else None

        for h, (q_nav_mod, k_nav_mod) in enumerate(zip(self.q_navigators, self.k_navigators)):
            q_nav = q_nav_mod.navigate(q[:, h].reshape(batch * tokens, self.head_dim))
            k_nav = k_nav_mod.navigate(k[:, h].reshape(batch * tokens, self.head_dim))

            q_scores = q_nav['scores'].reshape(batch, tokens, -1)
            k_scores = k_nav['scores'].reshape(batch, tokens, -1)

            # Attention logits: similarity of region-score profiles,
            # scaled like standard dot-product attention.
            logits = torch.bmm(q_scores, k_scores.transpose(1, 2))
            logits = logits / math.sqrt(q_scores.size(-1))

            if mask is not None:
                logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e9)

            weights = self.dropout(F.softmax(logits, dim=-1))
            head_outputs.append(torch.bmm(weights, v[:, h]))

            if return_diagnostics:
                all_diagnostics.append({'q': q_nav['diagnostics'], 'k': k_nav['diagnostics']})

        merged = torch.stack(head_outputs, dim=1).transpose(1, 2).reshape(batch, tokens, width)
        merged = self.dropout(self.out_proj(merged))

        if return_diagnostics:
            return merged, {'head_diagnostics': all_diagnostics}
        return merged, None
230
+
231
+ # ============================================
232
+ # DROP PATH (STOCHASTIC DEPTH)
233
+ # ============================================
234
+
235
class DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples' residual branches.

    During evaluation (or with drop_prob == 0) this is the identity.
    Surviving samples are rescaled by 1 / keep_prob so the expected
    value is unchanged.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device).add_(keep_prob).floor_()
        return x.div(keep_prob) * mask
250
+
251
+ # ============================================
252
+ # HIERARCHICAL CLS WITH PENTACHORA
253
+ # ============================================
254
+
255
class HierarchicalPentachoronCLS(nn.Module):
    """Hierarchical CLS structure built around pentachoron geometry.

    Holds one global CLS token, five vertex-level CLS tokens (model
    space), and one learnable pentachoron (5 vertices) per class in
    vocabulary space, plus a bridge projection between the two spaces.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
        super().__init__()
        self.dim = dim              # model's internal dimension
        self.vocab_dim = vocab_dim  # vocabulary's dimension
        self.num_classes = num_classes

        # Learnable CLS tokens (model space), small-normal init.
        self.global_cls = nn.Parameter(torch.randn(1, 1, dim) * 0.02)
        self.vertex_cls = nn.Parameter(torch.randn(1, 5, dim) * 0.02)

        # One pentachoron (5 vertices) per class, in vocabulary space.
        self.class_pentachora = nn.Parameter(torch.randn(num_classes, 5, vocab_dim) * 0.02)

        # Bridge vocabulary space to model space only when widths differ.
        self.vocab_projection = nn.Linear(vocab_dim, dim) if vocab_dim != dim else nn.Identity()

        # Fuse the five vertex tokens back into a single global token.
        self.vertex_to_global = nn.Linear(dim * 5, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (global_cls [B,1,dim], vertex_cls [B,5,dim]) for a batch."""
        return (self.global_cls.expand(batch_size, -1, -1),
                self.vertex_cls.expand(batch_size, -1, -1))

    def aggregate_vertices(self, vertex_cls: torch.Tensor) -> torch.Tensor:
        """Collapse [B, 5, dim] vertex tokens into a normalized [B, 1, dim] token."""
        fused = self.vertex_to_global(vertex_cls.reshape(vertex_cls.shape[0], -1))
        return self.norm(fused.unsqueeze(1))
295
+
296
+ # ============================================
297
+ # GEOMETRIC PROJECTION LAYER
298
+ # ============================================
299
+
300
class GeometricProjection(nn.Module):
    """Scores patches against every class pentachoron.

    Each of the five simplex vertices gets its own linear projection from
    model space into vocabulary space; cosine similarities to the matching
    vertex of every class are averaged across the five vertices.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
        super().__init__()
        self.dim = dim              # model dimension
        self.vocab_dim = vocab_dim  # vocabulary dimension
        self.num_classes = num_classes

        # One projection per pentachoron vertex (model dim -> vocab dim).
        self.vertex_projections = nn.ModuleList(
            nn.Linear(dim, vocab_dim, bias=False) for _ in range(5)
        )

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
        """
        Compute alignment between patches and class pentachora.

        Args:
            patches: [B, N, dim] patch embeddings.
            pentachora: [C, 5, vocab_dim] class pentachora.

        Returns:
            [B, N, C] alignment scores (averaged over the five vertices).
        """
        patches = self.norm(patches)

        vertex_scores = []
        for v, proj in enumerate(self.vertex_projections):
            # Cosine similarity between projected patches and vertex v of
            # every class.
            projected = F.normalize(proj(patches), dim=-1)
            class_vertex = F.normalize(pentachora[:, v, :], dim=-1)
            vertex_scores.append(projected @ class_vertex.T)

        mean_score = torch.stack(vertex_scores, dim=-1).mean(dim=-1)
        return self.dropout(mean_score)
350
+
351
+ # ============================================
352
+ # MLP BLOCK
353
+ # ============================================
354
+
355
class MLP(nn.Module):
    """Two-layer feed-forward block (Linear -> GELU -> Linear) with dropout."""

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.):
        super().__init__()
        # Both widths default to the input width when not given.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
376
+
377
+ # ============================================
378
+ # VIT BLOCK WITH GEOMETRIC ATTENTION
379
+ # ============================================
380
+
381
class PentachoronViTBlock(nn.Module):
    """Pre-norm transformer block with geometric or standard attention.

    Structured (early) layers use GeometricAttention; later layers fall
    back to standard nn.MultiheadAttention.  The flavour is fixed at
    construction via ``use_mesh``.
    """

    def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
                 use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
                 drop_path: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        if use_mesh:
            # Geometric attention for structure-preserving layers.
            self.attn = GeometricAttention(
                dim=dim,
                num_heads=heads,
                num_regions=min(dim // heads, 16),
                config=GeometricConfig(),
                dropout=attn_dropout
            )
        else:
            # Plain multi-head self-attention for the remaining layers.
            self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)

        self.use_mesh = use_mesh
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
        # NOTE(review): `preserve_structure` is accepted but unused here;
        # the attention flavour was fixed in __init__ — confirm intent.
        pre = self.norm1(x)
        if self.use_mesh:
            attn_out, _ = self.attn(pre)
        else:
            attn_out, _ = self.attn(pre, pre, pre)
        x = x + self.drop_path1(attn_out)
        return x + self.drop_path2(self.mlp(self.norm2(x)))
423
+
424
+ # ============================================
425
+ # PATCH EMBEDDING
426
+ # ============================================
427
+
428
class PatchEmbed(nn.Module):
    """2D image -> sequence of patch embeddings.

    A strided Conv2d cuts the image into non-overlapping patch_size
    squares and linearly embeds each one; LayerNorm is applied per token.

    The flatten/transpose in ``forward`` replaces the previous
    ``einops.rearrange(x, 'b c h w -> b (h w) c')`` with equivalent
    torch-native ops, removing the third-party dependency for this block.
    """

    def __init__(self, img_size: int = 32, patch_size: int = 4,
                 in_chans: int = 3, embed_dim: int = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: [B, in_chans, img_size, img_size] images.

        Returns:
            [B, num_patches, embed_dim] patch tokens.
        """
        x = self.proj(x)                   # [B, E, H', W']
        x = x.flatten(2).transpose(1, 2)   # [B, H'*W', E]  ('b c h w -> b (h w) c')
        return self.norm(x)
445
+
446
+ # ============================================
447
+ # PENTACHORA VISION TRANSFORMER
448
+ # ============================================
449
+
450
class PentachoraViT(nn.Module):
    """
    Vision Transformer with pentachoron-based hierarchical CLS tokens
    and geometric vocabulary integration.

    Token layout after embedding: [global CLS | 5 vertex CLS | patches].
    Early blocks may use geometric (mesh) attention; later blocks use
    standard multi-head attention.
    """
    def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
        super().__init__()

        # Use config or kwargs
        if config is not None:
            cfg = config
        else:
            cfg = PentachoraConfig(**kwargs)

        self.config = cfg
        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer

        # Set vocabulary dimension - from config, kwargs, or default to model dim
        # NOTE(review): when cfg was built from kwargs, PentachoraConfig already
        # consumed 'vocab_dim', so the elif branch below looks unreachable —
        # confirm whether it guards some other construction path.
        if cfg.vocab_dim is not None:
            self.vocab_dim = cfg.vocab_dim
        elif 'vocab_dim' in kwargs:
            self.vocab_dim = kwargs['vocab_dim']
        else:
            self.vocab_dim = cfg.dim

        # Patch embedding (3-channel input assumed)
        self.patch_embed = PatchEmbed(
            cfg.img_size, cfg.patch_size, 3, cfg.dim
        )
        num_patches = self.patch_embed.num_patches

        # Positional embedding (patches only; the CLS tokens get none)
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)

        # CLS tokens with pentachoron structure
        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)

        # Geometric projection layer - CREATE BEFORE vocab init
        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)

        # Initialize from vocabulary AFTER creating all components
        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)

        # Stochastic depth decay rule: drop-path rate grows linearly with depth
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]

        # Transformer blocks with geometric attention in the early
        # ("structure-preserving") layers only
        self.blocks = nn.ModuleList([
            PentachoronViTBlock(
                dim=cfg.dim,
                heads=cfg.heads,
                mlp_ratio=cfg.mlp_ratio,
                use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
                dropout=cfg.dropout_rate,
                attn_dropout=cfg.dropout_rate,
                drop_path=dpr[i]
            )
            for i in range(cfg.depth)
        ])

        # Final norm
        self.norm = nn.LayerNorm(cfg.dim)

        # Classification heads: primary (global CLS) and auxiliary
        # (concatenation of the 5 vertex CLS tokens)
        self.head = nn.Linear(cfg.dim, cfg.num_classes)
        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)

        # Initialize weights
        # NOTE(review): this re-initializes every Linear, including the
        # projections created by _init_from_vocab above (the pentachora
        # Parameter itself is untouched, since apply() visits Modules only)
        # — confirm this ordering is intended.
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialize model weights (truncated normal for Linear, Kaiming for Conv2d)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _init_from_vocab(self, vocab):
        """Initialize class pentachora from geometric vocabulary.

        Best-effort: any failure falls back to the random initialization
        already in place.  May change self.vocab_dim and rebuild the
        vocab-facing projection layers to match the encoded dimension.
        """
        try:
            print("Initializing pentachora from vocabulary...")

            if not hasattr(vocab, 'encode_batch'):
                print("Vocabulary provided but encode_batch method not found, using random initialization")
                return

            # Get CIFAR-100 class names
            class_names = self._get_cifar100_classes()

            # Generate pentachora for all classes
            # (assumes encode_batch returns one [5, vocab_dim] array per name — TODO confirm)
            pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
            pentachora = np.stack(pentachora_list, axis=0)

            # Get actual dimensions from the encoded data
            actual_vocab_dim = pentachora.shape[-1]

            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")

            # Validate basic shape requirements
            if pentachora.shape[0] != self.num_classes or pentachora.shape[1] != 5:
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return

            # Update all components to use the actual vocabulary dimension
            self.vocab_dim = actual_vocab_dim
            self.cls_tokens.vocab_dim = actual_vocab_dim
            self.geometric_proj.vocab_dim = actual_vocab_dim

            # Replace class_pentachora with the loaded vocabulary
            self.cls_tokens.class_pentachora = nn.Parameter(
                torch.tensor(pentachora, dtype=torch.float32)
            )

            # Update/create projection layer if dimensions differ
            if actual_vocab_dim != self.dim:
                self.cls_tokens.vocab_projection = nn.Linear(actual_vocab_dim, self.dim)
            else:
                self.cls_tokens.vocab_projection = nn.Identity()

            # Rebuild geometric projection layers with correct dimensions
            self.geometric_proj.vertex_projections = nn.ModuleList([
                nn.Linear(self.dim, actual_vocab_dim, bias=False) for _ in range(5)
            ])

            # Re-initialize the new layers
            # NOTE(review): self.apply(self._init_weights) runs after this in
            # __init__ and overwrites these xavier inits — confirm intent.
            for proj in self.geometric_proj.vertex_projections:
                nn.init.xavier_uniform_(proj.weight)
            if actual_vocab_dim != self.dim:
                nn.init.xavier_uniform_(self.cls_tokens.vocab_projection.weight)

            print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
            print(f"  Vocabulary dimension: {actual_vocab_dim}")
            print(f"  Model internal dimension: {self.dim}")
            print(f"  Projection: {actual_vocab_dim} → {self.dim}")

        except Exception as e:
            # Deliberately broad: vocabulary seeding is best-effort.
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")

    def _get_cifar100_classes(self):
        """Get CIFAR-100 class names (canonical fine-label order)."""
        return [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]

    def forward_features(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Extract features from input.

        Args:
            x: [B, 3, img_size, img_size] images.

        Returns:
            dict with 'global_cls' [B, dim], 'vertex_cls' [B, 5, dim]
            and 'patches' [B, num_patches, dim].
        """
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Get hierarchical CLS tokens
        global_cls, vertex_cls = self.cls_tokens(B)

        # Concatenate CLS tokens with patches: [global | 5 vertex | patches]
        x = torch.cat([global_cls, vertex_cls, x], dim=1)

        # Apply transformer blocks
        for i, block in enumerate(self.blocks):
            preserve = i < self.preserve_structure_until_layer
            x = block(x, preserve_structure=preserve)

        # Apply final norm
        x = self.norm(x)

        # Split tokens back apart (indices mirror the cat above)
        global_cls = x[:, 0]
        vertex_cls = x[:, 1:6]
        patches = x[:, 6:]

        return {
            'global_cls': global_cls,
            'vertex_cls': vertex_cls,
            'patches': patches
        }

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass through the model.

        Returns a dict with 'logits', 'aux_logits', 'geometric_alignments'
        ([B, N, num_classes]) plus the raw feature tensors.
        """
        features = self.forward_features(x)

        # Primary classification using global CLS
        logits = self.head(features['global_cls'])

        # Auxiliary classification using the concatenated vertex tokens
        B = features['vertex_cls'].shape[0]
        vertex_flat = features['vertex_cls'].reshape(B, -1)
        aux_logits = self.head_aux(vertex_flat)

        # Geometric alignment scores between patches and class pentachora
        geometric_alignments = self.geometric_proj(
            features['patches'],
            self.cls_tokens.class_pentachora
        )

        return {
            'logits': logits,
            'aux_logits': aux_logits,
            'geometric_alignments': geometric_alignments,
            'vertex_cls': features['vertex_cls'],
            'global_cls': features['global_cls'],
            'patches': features['patches']
        }
682
+
683
+ # ============================================
684
+ # LOSS FUNCTIONS
685
+ # ============================================
686
+
687
class PentachoraLoss(nn.Module):
    """Weighted sum of primary, auxiliary, and geometric cross-entropies."""

    def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1,
                 smoothing: float = 0.0):
        super().__init__()
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Combine whichever loss terms are present in `outputs`.

        Args:
            outputs: model output dict; must contain 'logits', may contain
                'aux_logits' and 'geometric_alignments'.
            targets: [B] class indices.

        Returns:
            Scalar loss tensor.
        """
        # Primary classification loss
        total = self.criterion(outputs['logits'], targets)

        # Auxiliary head over the vertex CLS tokens
        if self.aux_weight > 0 and 'aux_logits' in outputs:
            total = total + self.aux_weight * self.criterion(outputs['aux_logits'], targets)

        # Patch-level geometric alignment, averaged over patches
        if self.geo_weight > 0 and 'geometric_alignments' in outputs:
            patch_logits = outputs['geometric_alignments'].mean(dim=1)
            total = total + self.geo_weight * self.criterion(patch_logits, targets)

        return total
714
+
715
+ # ============================================
716
+ # MODEL REGISTRY AND BUILDERS
717
+ # ============================================
718
+
719
# Registry of named model variants.  NOTE: these are shared module-level
# instances — builders must copy them before applying per-call overrides.
MODEL_CONFIGS = {
    # Smallest sanity-check variant (no regularization).
    'pentachora_spark': PentachoraConfig(
        dim=64, depth=5, heads=4, mlp_ratio=4.0,
        preserve_structure_until_layer=2,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_tiny': PentachoraConfig(
        dim=384, depth=12, heads=6, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_small': PentachoraConfig(
        dim=512, depth=12, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_base': PentachoraConfig(
        dim=768, depth=12, heads=12, mlp_ratio=4.0,
        preserve_structure_until_layer=8,
        dropout_rate=0.1, drop_path_rate=0.2
    ),
    'pentachora_large': PentachoraConfig(
        dim=1024, depth=24, heads=16, mlp_ratio=4.0,
        preserve_structure_until_layer=12,
        dropout_rate=0.1, drop_path_rate=0.3
    ),
}
746
+
747
def create_pentachora_vit(variant: str = 'pentachora_small',
                          pretrained: bool = False,
                          **kwargs) -> PentachoraViT:
    """
    Create PentachoraViT model.

    Args:
        variant: Model variant name (a key of MODEL_CONFIGS)
        pretrained: Whether to load pretrained weights
        **kwargs: Override config parameters (including vocab_dim)

    Returns:
        PentachoraViT model

    Raises:
        ValueError: if `variant` is not a known configuration name.
    """
    from dataclasses import fields, replace  # local import; file already uses dataclasses

    if variant not in MODEL_CONFIGS:
        raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")

    # BUGFIX: the previous implementation applied kwargs with setattr()
    # directly on the shared MODEL_CONFIGS entry, so overrides leaked into
    # every later model built from the same variant.  Copy the config
    # (dataclasses.replace) and mutate only the private copy.
    field_names = {f.name for f in fields(PentachoraConfig)}
    known = {k: v for k, v in kwargs.items() if k in field_names}
    config = replace(MODEL_CONFIGS[variant], **known)

    # Preserve legacy behaviour for unknown keys: set them as plain
    # attributes on the (now private) copy instead of raising.
    for key, value in kwargs.items():
        if key not in field_names:
            setattr(config, key, value)

    model = PentachoraViT(config)

    if pretrained:
        warnings.warn("Pretrained weights not available yet")

    return model
776
+
777
def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create spark variant (smallest); kwargs override config fields."""
    return create_pentachora_vit('pentachora_spark', pretrained=pretrained, **kwargs)

def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create tiny variant; kwargs override config fields."""
    return create_pentachora_vit('pentachora_tiny', pretrained=pretrained, **kwargs)

def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create small variant; kwargs override config fields."""
    return create_pentachora_vit('pentachora_small', pretrained=pretrained, **kwargs)

def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create base variant; kwargs override config fields."""
    return create_pentachora_vit('pentachora_base', pretrained=pretrained, **kwargs)

def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create large variant; kwargs override config fields."""
    return create_pentachora_vit('pentachora_large', pretrained=pretrained, **kwargs)
796
+
797
+ # ============================================
798
+ # TRAINING UTILITIES
799
+ # ============================================
800
+
801
def get_parameter_groups(model: PentachoraViT,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """
    Split trainable parameters into decay / no-decay optimizer groups.

    Parameters whose names contain 'bias', 'norm', or 'LayerNorm' get
    zero weight decay; everything else gets `weight_decay`.

    Args:
        model: PentachoraViT model (any nn.Module works)
        weight_decay: Weight decay value for the regular group

    Returns:
        List of parameter group dictionaries (decay group first).
    """
    skip_markers = ('bias', 'norm', 'LayerNorm')

    with_decay, without_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        bucket = without_decay if any(m in name for m in skip_markers) else with_decay
        bucket.append(param)

    return [
        {'params': with_decay, 'weight_decay': weight_decay},
        {'params': without_decay, 'weight_decay': 0.0}
    ]
831
+
832
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Return total, trainable, and frozen parameter counts for *model*."""
    trainable = 0
    frozen = 0
    # Single pass over the parameters, bucketing by requires_grad.
    for p in model.parameters():
        n = p.numel()
        if p.requires_grad:
            trainable += n
        else:
            frozen += n
    return {
        'total': trainable + frozen,
        'trainable': trainable,
        'non_trainable': frozen
    }
841
+
842
+ # ============================================
843
+ # INFERENCE UTILITIES
844
+ # ============================================
845
+
846
@torch.no_grad()
def extract_features(model: PentachoraViT,
                     images: torch.Tensor,
                     feature_type: str = 'global_cls') -> torch.Tensor:
    """
    Run the backbone in eval mode and pull one entry from its feature dict.

    Args:
        model: PentachoraViT model
        images: Input images [B, 3, H, W]
        feature_type: Which feature tensor to return
            - 'global_cls': Global CLS token
            - 'vertex_cls': Vertex CLS tokens
            - 'patches': Patch embeddings

    Returns:
        The requested feature tensor; unknown ``feature_type`` values fall
        back to the global CLS token.
    """
    model.eval()
    feature_dict = model.forward_features(images)
    # NOTE: the fallback is looked up unconditionally, mirroring dict.get
    # semantics — a missing 'global_cls' key raises even for valid keys.
    fallback = feature_dict['global_cls']
    return feature_dict.get(feature_type, fallback)
867
+
868
+ # ============================================
869
+ # EXAMPLE USAGE AND TESTING
870
+ # ============================================
871
+
872
def _exercise_variant(variant: str) -> None:
    """Build one model variant and smoke-test forward pass, loss, and features."""
    print(f"\nTesting {variant}:")

    # Create model with vocab_dim
    model = create_pentachora_vit(
        variant=variant,
        img_size=32,
        patch_size=4,
        num_classes=100,
        vocab_dim=64  # Test with 64-dim vocabulary
    )

    # Count parameters
    params = count_parameters(model)
    print(f"  Total parameters: {params['total']:,}")
    print(f"  Trainable parameters: {params['trainable']:,}")

    # Test forward pass
    x = torch.randn(2, 3, 32, 32)
    outputs = model(x)

    print(f"  Output shapes:")
    print(f"    Logits: {outputs['logits'].shape}")
    print(f"    Aux logits: {outputs['aux_logits'].shape}")
    print(f"    Geometric alignments: {outputs['geometric_alignments'].shape}")

    # Test loss computation
    loss_fn = PentachoraLoss()
    targets = torch.randint(0, 100, (2,))
    loss = loss_fn(outputs, targets)
    print(f"  Loss: {loss.item():.4f}")

    # Test feature extraction
    features = extract_features(model, x, 'global_cls')
    print(f"  Extracted features shape: {features.shape}")


def test_model():
    """Test model creation and forward pass."""
    print("Testing PentachoraViT Model with Geometric Attention")
    print("=" * 50)

    # Test different variants
    for variant in ('pentachora_spark', 'pentachora_tiny', 'pentachora_small'):
        _exercise_variant(variant)

    print("\n" + "=" * 50)
    print("All tests passed!")
918
+
919
if __name__ == "__main__":
    # Run tests
    test_model()

    # Example: Create model for training with vocabulary
    print("\nExample: Creating model for training with 64-dim vocabulary")
    # Regularization disabled (dropout_rate/drop_path_rate = 0.0) for the demo.
    model = pentachora_vit_spark(
        img_size=32,
        patch_size=4,
        num_classes=100,
        vocab_dim=64,  # Specify vocabulary dimension
        dropout_rate=0.0,
        drop_path_rate=0.0
    )

    # Get parameter groups for optimizer (decay vs. no-decay split)
    param_groups = get_parameter_groups(model, weight_decay=0.05)
    print(f"Number of parameter groups: {len(param_groups)}")

    # Example batch: 4 RGB images of 32x32 with random class targets
    images = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 100, (4,))

    # Forward pass — returns a dict of outputs (logits, aux_logits, ...)
    outputs = model(images)

    # Compute loss with auxiliary and geometric terms weighted in
    criterion = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)
    loss = criterion(outputs, targets)

    print(f"Training loss: {loss.item():.4f}")
    print("\nModel ready for training with geometric attention!")