""" PentachoraViT: Vision Transformer with Pentachoron Geometric Structure Enhanced with Geometric Attention for improved head cohesion and generalization FIXED: All parameters initialized at module creation time (no lazy init) """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from einops import rearrange, repeat import math from typing import Optional, Dict, Tuple, List, Any from dataclasses import dataclass import warnings # ============================================ # CONFIGURATION CLASSES # ============================================ @dataclass class PentachoraConfig: """Configuration for PentachoraViT models.""" img_size: int = 32 patch_size: int = 4 num_classes: int = 100 dim: int = 512 vocab_dim: Optional[int] = None # Vocabulary dimension (can differ from model dim) depth: int = 12 heads: int = 8 mlp_ratio: float = 4.0 use_mesh_attention: bool = True preserve_structure_until_layer: int = 6 dropout_rate: float = 0.0 drop_path_rate: float = 0.0 aux_loss_weight: float = 0.0 geo_loss_weight: float = 0.0 vocab: Optional[Any] = None @property def num_patches(self) -> int: return (self.img_size // self.patch_size) ** 2 # ============================================ # GEOMETRIC ATTENTION COMPONENTS (FIXED INIT) # ============================================ def perfect_4simplex(device): """Create perfect 4-simplex (pentachoron) vertices in 4D.""" sqrt5 = math.sqrt(5) vertices = torch.tensor([ [1, 1, 1, -1/sqrt5], [1, -1, -1, -1/sqrt5], [-1, 1, -1, -1/sqrt5], [-1, -1, 1, -1/sqrt5], [0, 0, 0, 4/sqrt5] ], device=device, dtype=torch.float32) return vertices / 2 # Normalize scale def softmin_over_last(distances, tau): """Softmin over last dimension.""" return F.softmax(-distances / tau, dim=-1).sum(dim=-1) @dataclass class GeometricConfig: """Configuration for geometric attention.""" softmin_tau: float = 0.05 fuse_alpha: float = 0.7 phases: Tuple[float, ...] 
= (0.0, math.pi/2, math.pi, 3*math.pi/2) jitter: float = 0.02 shift: float = 0.71 rotate_cycle: int = 11 use_phase_variance: bool = False geometry_type: str = "pentachoron" class GeometricNavigator(nn.Module): """Maps inputs to geometric regions in 4D space - FIXED with immediate initialization.""" def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig, num_heads: int = 1, device=None): super().__init__() self.input_dim = input_dim self.num_regions = num_regions self.config = config self.num_heads = num_heads # Use CPU by default if device not specified if device is None: device = torch.device('cpu') # Create separate parameters for each head if num_heads > 1 if num_heads > 1: self.to_nav = nn.Parameter(torch.randn(num_heads, input_dim, 4, device=device) * 0.02) self.vertex_w = nn.Parameter(torch.zeros(num_heads, num_regions, 5, device=device)) else: self.to_nav = nn.Linear(input_dim, 4, bias=False) self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5, device=device)) # Pre-compute phase tensors for vectorization self.register_buffer('phase_cos', torch.cos(torch.tensor(config.phases, dtype=torch.float32, device=device))) self.register_buffer('phase_sin', torch.sin(torch.tensor(config.phases, dtype=torch.float32, device=device))) # Initialize geometry immediately at creation time self._init_geometry(device) def _init_geometry(self, device): """Initialize geometry at module creation time.""" base = perfect_4simplex(device) if self.num_heads > 1: D = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device) S = torch.zeros(self.num_heads, self.num_regions, 5, 4, device=device) for h in range(self.num_heads): for r in range(self.num_regions): D[h, r] = base + self.config.jitter * torch.randn_like(base) theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device) rot = torch.eye(4, device=device) c, s_val = torch.cos(theta), torch.sin(theta) rot[0, 0] = c; rot[0, 1] = -s_val rot[1, 0] = s_val; rot[1, 1] = c S[h, r] = (base @ rot) + self.config.shift S[h, r] += self.config.jitter * torch.randn_like(S[h, r]) else: D = torch.zeros(self.num_regions, 5, 4, device=device) S = torch.zeros(self.num_regions, 5, 4, device=device) for r in range(self.num_regions): D[r] = base + self.config.jitter * torch.randn_like(base) theta = torch.tensor(0.27 + 0.05 * (r % self.config.rotate_cycle), device=device) rot = torch.eye(4, device=device) c, s_val = torch.cos(theta), torch.sin(theta) rot[0, 0] = c; rot[0, 1] = -s_val rot[1, 0] = s_val; rot[1, 1] = c S[r] = (base @ rot) + self.config.shift S[r] += self.config.jitter * torch.randn_like(S[r]) self.D = nn.Parameter(D) self.S = nn.Parameter(S) def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: """Navigate inputs through geometric space - OPTIMIZED with vectorized phase computation.""" if self.num_heads > 1: # Batched navigation for multiple heads BT, H, head_dim = x.shape # Batched transformation nav_x = torch.einsum('bhi,hio->bho', x, self.to_nav) # [BT, H, 4] # Dispatcher scores nav_x_disp = nav_x.view(BT, H, 1, 1, 4) D_exp = self.D.unsqueeze(0) # [1, H, regions, 5, 4] d_disp = torch.norm(nav_x_disp - D_exp, dim=-1) s_disp = -softmin_over_last(d_disp, self.config.softmin_tau) # OPTIMIZED: Vectorized phase computation (no loop) cos_phases = self.phase_cos.view(-1, 1, 1, 1, 1) sin_phases = self.phase_sin.view(-1, 1, 1, 1, 1) # Compute all phase variants at once [phases, H, regions, 5, 4] Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0) # Apply vertex weighting to all 
            # Apply vertex weighting to all phases
            w = F.softmax(self.vertex_w, dim=-1)
            w_exp = w.unsqueeze(0).unsqueeze(-1)  # [1, H, regions, 5, 1]
            Vt_mean = Vt_all.mean(dim=3, keepdim=True)
            Vt_all = (1.0 - w_exp) * Vt_all + w_exp * Vt_mean

            # Compute all ribbon distances at once
            nav_x_ribbon = nav_x.view(BT, 1, H, 1, 1, 4)
            Vt_exp = Vt_all.unsqueeze(0)  # [1, phases, H, regions, 5, 4]
            d_ribbon_all = torch.norm(nav_x_ribbon - Vt_exp, dim=-1)
            s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
            s_ribbon = s_ribbon_all.mean(dim=1)  # Average over phases

            scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp
            scores = scores.reshape(BT * H, self.num_regions)
        else:
            # Original single-head navigation
            nav_x = self.to_nav(x)
            nav_x_exp = nav_x[:, None, None, :]
            D_exp = self.D[None, :, :, :]
            d_disp = torch.norm(nav_x_exp - D_exp, dim=-1)
            s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)

            w = F.softmax(self.vertex_w, dim=1)

            # OPTIMIZED: Vectorized phase computation for single head
            cos_phases = self.phase_cos.view(-1, 1, 1, 1)
            sin_phases = self.phase_sin.view(-1, 1, 1, 1)
            Vt_all = cos_phases * self.D.unsqueeze(0) + sin_phases * self.S.unsqueeze(0)

            w_expanded = w.unsqueeze(0).unsqueeze(-1)
            Vt_mean = Vt_all.mean(dim=2, keepdim=True)
            Vt_all = (1.0 - w_expanded) * Vt_all + w_expanded * Vt_mean

            nav_x_phase = nav_x[:, None, None, None, :]
            Vt_exp = Vt_all.unsqueeze(0)
            d_ribbon_all = torch.norm(nav_x_phase - Vt_exp, dim=-1)
            s_ribbon_all = -softmin_over_last(d_ribbon_all, self.config.softmin_tau)
            s_ribbon = s_ribbon_all.mean(dim=1)

            scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp

        diagnostics = {
            'dispatcher_scores': s_disp.detach() if self.num_heads == 1 else s_disp.reshape(BT * H, -1).detach(),
            'ribbon_scores': s_ribbon.detach() if self.num_heads == 1 else s_ribbon.reshape(BT * H, -1).detach()
        }

        return {'scores': scores, 'diagnostics': diagnostics}


class GeometricAttention(nn.Module):
    """Multi-head geometric attention with Q-K alignment - FIXED with proper device handling."""

    def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                 config: Optional[GeometricConfig] = None, dropout: float = 0.0, device=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        if num_regions is None:
            num_regions = min(self.head_dim, 16)
        if config is None:
            config = GeometricConfig()
        self.config = config

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        # Create batched navigators with device
        self.q_navigator = GeometricNavigator(self.head_dim, num_regions, config,
                                              num_heads=num_heads, device=device)
        self.k_navigator = GeometricNavigator(self.head_dim, num_regions, config,
                                              num_heads=num_heads, device=device)

        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
        B, T, D = x.shape

        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Prepare for batched navigation
        q_batched = q.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)
        k_batched = k.transpose(1, 2).reshape(B * T, self.num_heads, self.head_dim)

        # Navigate all heads at once
        q_nav = self.q_navigator.navigate(q_batched)
        k_nav = self.k_navigator.navigate(k_batched)

        # Reshape scores back
        q_scores = q_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)
        k_scores = k_nav['scores'].reshape(B, T, self.num_heads, -1).transpose(1, 2)

        # OPTIMIZED: Compute attention for all heads at once using einsum
        scale = math.sqrt(q_scores.size(-1))
        attn = torch.einsum('bhqr,bhkr->bhqk', q_scores, k_scores) / scale

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).unsqueeze(2)
            attn = attn.masked_fill(mask_expanded == 0, -1e9)

        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = torch.einsum('bhqk,bhkd->bhqd', attn, v)
        out = out.transpose(1, 2).reshape(B, T, D)

        output = self.out_proj(out)
        output = self.dropout(output)

        if return_diagnostics:
            return output, {'q_diagnostics': q_nav['diagnostics'], 'k_diagnostics': k_nav['diagnostics']}
        return output, None
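

# --------------------------------------------------------------------------
# Illustrative sketch (not from the original module): a minimal smoke test for
# the GeometricAttention block defined above. The tensor sizes are arbitrary
# example values; `_demo_geometric_attention` is a hypothetical helper name.
# --------------------------------------------------------------------------
def _demo_geometric_attention() -> None:
    """Run GeometricAttention on random tokens and print the resulting shapes."""
    dim, num_heads, batch, tokens = 64, 8, 2, 16  # example sizes (assumption)
    attn = GeometricAttention(dim=dim, num_heads=num_heads, config=GeometricConfig())
    x = torch.randn(batch, tokens, dim)
    out, diag = attn(x, return_diagnostics=True)
    print(out.shape)  # torch.Size([2, 16, 64]) - output keeps the input shape
    # Navigator diagnostics are flattened over (batch * tokens * heads, regions)
    print(diag['q_diagnostics']['dispatcher_scores'].shape)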
# ============================================
# DROP PATH (STOCHASTIC DEPTH)
# ============================================

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample."""

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        output = x.div(keep_prob) * random_tensor
        return output


# ============================================
# HIERARCHICAL CLS WITH PENTACHORA
# ============================================

class HierarchicalPentachoronCLS(nn.Module):
    """
    Hierarchical CLS structure with pentachoron geometry.
    Uses vocabulary embeddings for CLS tokens.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Class-specific pentachora from vocabulary
        self.register_buffer('class_pentachora', torch.randn(num_classes, 5, vocab_dim) * 0.02)

        # Projection from vocabulary dimension to model dimension
        if vocab_dim != dim:
            self.vocab_to_model = nn.Linear(vocab_dim, dim)
        else:
            self.vocab_to_model = nn.Identity()

        # Learnable aggregation weights
        self.vertex_weights = nn.Parameter(torch.ones(5) / 5)

        # Optional learnable offset
        self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))

        # Layer norms
        self.vertex_norm = nn.LayerNorm(dim)
        self.global_norm = nn.LayerNorm(dim)

    def forward(self, batch_size: int,
                class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate CLS tokens for batch."""
        # Get class-specific pentachora (registered buffer; may be replaced from a vocabulary)
        class_pentachora = self.class_pentachora

        if class_indices is not None and class_indices.shape[0] == batch_size:
            vertex_cls_vocab = class_pentachora[class_indices]
        else:
            vertex_cls_vocab = class_pentachora.mean(dim=0, keepdim=True)
            vertex_cls_vocab = vertex_cls_vocab.expand(batch_size, -1, -1)

        # Project from vocabulary dimension to model dimension
        vertex_cls = self.vocab_to_model(vertex_cls_vocab)
        vertex_cls = self.vertex_norm(vertex_cls)

        # Create global CLS as weighted combination
        weights = F.softmax(self.vertex_weights, dim=0)
        global_cls = torch.einsum('bvd,v->bd', vertex_cls, weights).unsqueeze(1)
        global_cls = global_cls + self.global_offset
        global_cls = self.global_norm(global_cls)

        return global_cls, vertex_cls

    def get_class_prototypes(self) -> torch.Tensor:
        """Get class prototypes in model dimension."""
        class_pentachora = self.class_pentachora
        pentachora_model = self.vocab_to_model(class_pentachora)
        weights = F.softmax(self.vertex_weights, dim=0)
        prototypes = torch.einsum('cvd,v->cd', pentachora_model, weights)
        return prototypes


# ============================================
# GEOMETRIC PROJECTION LAYER
# ============================================

class GeometricProjection(nn.Module):
    """Project patches onto pentachoron geometry."""

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Projection from model dim to vocab dim
        self.to_vocab_space = nn.Linear(dim, vocab_dim)

        # Vertex-specific projections
        self.vertex_projections = nn.ModuleList([
            nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
        ])

        # Temperature for alignment scores
        self.temperature = nn.Parameter(torch.ones(1))

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
        """Compute alignment between patches and class pentachora."""
        B, N, D = patches.shape

        # Normalize patches
        patches = self.norm(patches)

        # Project patches to vocabulary space
        patches_vocab = self.to_vocab_space(patches)
        patches_vocab = F.normalize(patches_vocab, dim=-1)

        # Compute alignment with each vertex
        alignments = []
        for v in range(5):
            patches_v = self.vertex_projections[v](patches_vocab)
            patches_v = F.normalize(patches_v, dim=-1)
            vertex_v = F.normalize(pentachora[:, v, :], dim=-1)
            alignment = torch.matmul(patches_v, vertex_v.T) / self.temperature
            alignments.append(alignment)

        # Average alignments across vertices
        alignments = torch.stack(alignments, dim=-1).mean(dim=-1)

        return self.dropout(alignments)


# ============================================
# MLP BLOCK
# ============================================

class MLP(nn.Module):
    """MLP block with GELU activation."""

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# ============================================
# VIT BLOCK WITH GEOMETRIC ATTENTION
# ============================================

class PentachoronViTBlock(nn.Module):
    """ViT block with geometric attention for structured layers."""

    def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0, use_mesh: bool = True,
                 dropout: float = 0., attn_dropout: float = 0., drop_path: float = 0., device=None):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        # Use GeometricAttention for structured layers, standard for others
        if use_mesh:
            self.attn = GeometricAttention(
                dim=dim,
                num_heads=heads,
                num_regions=min(dim // heads, 16),
                config=GeometricConfig(),
                dropout=attn_dropout,
                device=device
            )
        else:
            # Standard multi-head attention for later layers
            self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)
        self.use_mesh = use_mesh

        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
        if self.use_mesh:
            # GeometricAttention
            attn_out, _ = self.attn(self.norm1(x))
            x = x + self.drop_path1(attn_out)
        else:
            # Standard attention
            normalized = self.norm1(x)
            attn_out, _ = self.attn(normalized, normalized, normalized)
            x = x + self.drop_path1(attn_out)

        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x


# ============================================
# PATCH EMBEDDING
# ============================================

class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding."""

    def __init__(self, img_size: int = 32, patch_size: int = 4, in_chans: int = 3, embed_dim: int = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.norm(x)
        return x


# ============================================
# PENTACHORA VISION TRANSFORMER
# ============================================

class PentachoraViT(nn.Module):
    """
    Vision Transformer with pentachoron-based hierarchical CLS tokens
    and geometric vocabulary integration.
    """

    def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
        super().__init__()

        # Use config or kwargs
        if config is not None:
            cfg = config
        else:
            cfg = PentachoraConfig(**kwargs)
        self.config = cfg

        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer

        # Set vocabulary dimension
        if cfg.vocab_dim is not None:
            self.vocab_dim = cfg.vocab_dim
        elif 'vocab_dim' in kwargs:
            self.vocab_dim = kwargs['vocab_dim']
        else:
            self.vocab_dim = cfg.dim

        # Patch embedding
        self.patch_embed = PatchEmbed(
            cfg.img_size, cfg.patch_size, 3, cfg.dim
        )
        num_patches = self.patch_embed.num_patches

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)

        # CLS tokens with pentachoron structure
        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)

        # Geometric projection layer
        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)

        # Initialize from vocabulary if provided
        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)

        # Stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]

        # Transformer blocks with geometric attention
        self.blocks = nn.ModuleList([
            PentachoronViTBlock(
                dim=cfg.dim,
                heads=cfg.heads,
                mlp_ratio=cfg.mlp_ratio,
                use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
                dropout=cfg.dropout_rate,
                attn_dropout=cfg.dropout_rate,
                drop_path=dpr[i],
                device=torch.device('cpu')  # Initialize on CPU, will be moved later
            )
            for i in range(cfg.depth)
        ])

        # Final norm
        self.norm = nn.LayerNorm(cfg.dim)

        # Classification heads
        self.use_prototype_classifier = True
        if self.use_prototype_classifier:
            self.head = None
        else:
            self.head = nn.Linear(cfg.dim, cfg.num_classes)
        # Auxiliary head for vertex tokens
        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """Initialize model weights."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _init_from_vocab(self, vocab):
        """Initialize class pentachora from geometric vocabulary."""
        try:
            print("Initializing pentachora from vocabulary...")

            if not hasattr(vocab, 'encode_batch'):
                print("Vocabulary provided but encode_batch method not found, using random initialization")
                return

            # Get CIFAR-100 class names
            class_names = self._get_cifar100_classes()

            # Generate pentachora for all classes
            pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
            pentachora = np.stack(pentachora_list, axis=0)

            # Get actual dimensions from the encoded data
            actual_vocab_dim = pentachora.shape[-1]
            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")

            # Validate basic shape requirements
            if pentachora.shape[0] != self.num_classes or pentachora.shape[1] != 5:
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return

            # Update vocabulary dimension
            self.vocab_dim = actual_vocab_dim
            self.cls_tokens.vocab_dim = actual_vocab_dim
            self.geometric_proj.vocab_dim = actual_vocab_dim

            # Replace class_pentachora with the loaded vocabulary
            self.cls_tokens.class_pentachora = torch.tensor(pentachora, dtype=torch.float32)

            # Update/create projection layer if dimensions differ
            if actual_vocab_dim != self.dim:
                self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
            else:
                self.cls_tokens.vocab_to_model = nn.Identity()

            # Rebuild geometric projection components
            self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
            self.geometric_proj.vertex_projections = nn.ModuleList([
                nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
            ])

            # Re-initialize the new layers
            nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
            for proj in self.geometric_proj.vertex_projections:
                nn.init.xavier_uniform_(proj.weight)
            if actual_vocab_dim != self.dim:
                nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)

            print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
            print(f"  Vocabulary dimension: {actual_vocab_dim}")
            print(f"  Model internal dimension: {self.dim}")

        except Exception as e:
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")

    def _get_cifar100_classes(self):
        """Get CIFAR-100 class names."""
        return [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud',
            'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant',
            'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
            'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree',
            'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter',
            'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea',
            'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
            'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television',
            'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale',
            'willow_tree', 'wolf', 'woman', 'worm'
        ]

    def forward_features(self, x: torch.Tensor,
                         class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Extract features from input."""
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Get hierarchical CLS tokens
        global_cls, vertex_cls = self.cls_tokens(B, class_indices)

        # Concatenate CLS tokens with patches
        x = torch.cat([global_cls, vertex_cls, x], dim=1)

        # Apply transformer blocks
        for i, block in enumerate(self.blocks):
            preserve = i < self.preserve_structure_until_layer
            x = block(x, preserve_structure=preserve)

        # Apply final norm
        x = self.norm(x)

        # Split tokens
        global_cls = x[:, 0]
        vertex_cls = x[:, 1:6]
        patches = x[:, 6:]

        return {
            'global_cls': global_cls,
            'vertex_cls': vertex_cls,
            'patches': patches
        }

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Forward pass through the model."""
        # During training, use target labels for class-specific CLS initialization
        class_indices = targets if self.training and targets is not None else None

        features = self.forward_features(x, class_indices)

        # Primary classification using prototype matching
        if self.use_prototype_classifier:
            prototypes = self.cls_tokens.get_class_prototypes()
            prototypes = F.normalize(prototypes, dim=-1)
            global_cls_norm = F.normalize(features['global_cls'], dim=-1)
            logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0
        else:
            logits = self.head(features['global_cls'])

        # Auxiliary classification using vertex tokens
        B = features['vertex_cls'].shape[0]
        vertex_flat = features['vertex_cls'].reshape(B, -1)
        aux_logits = self.head_aux(vertex_flat)

        # Geometric alignment scores
        geometric_alignments = self.geometric_proj(
            features['patches'],
            self.cls_tokens.class_pentachora
        )

        return {
            'logits': logits,
            'aux_logits': aux_logits,
            'geometric_alignments': geometric_alignments,
            'vertex_cls': features['vertex_cls'],
            'global_cls': features['global_cls'],
            'patches': features['patches']
        }


# ============================================
# LOSS FUNCTIONS
# ============================================

class PentachoraLoss(nn.Module):
    """Combined loss for PentachoraViT training."""

    def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1, smoothing: float = 0.0):
        super().__init__()
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Compute combined loss."""
        # Primary classification loss
        loss = self.criterion(outputs['logits'], targets)

        # Auxiliary loss from vertex tokens
        if 'aux_logits' in outputs and self.aux_weight > 0:
            aux_loss = self.criterion(outputs['aux_logits'], targets)
            loss = loss + self.aux_weight * aux_loss

        # Geometric alignment loss
        if 'geometric_alignments' in outputs and self.geo_weight > 0:
            geo_logits = outputs['geometric_alignments'].mean(dim=1)
            geo_loss = self.criterion(geo_logits, targets)
            loss = loss + self.geo_weight * geo_loss

        return loss
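

# --------------------------------------------------------------------------
# Illustrative sketch (not from the original module): how PentachoraLoss
# consumes the dictionary produced by PentachoraViT.forward. The tensors below
# are random stand-ins with the expected shapes; `_demo_pentachora_loss` is a
# hypothetical helper name.
# --------------------------------------------------------------------------
def _demo_pentachora_loss() -> None:
    """Compute the combined loss on synthetic model outputs."""
    batch, num_classes, num_patches = 4, 100, 64  # example sizes (assumption)
    outputs = {
        'logits': torch.randn(batch, num_classes),
        'aux_logits': torch.randn(batch, num_classes),
        # Per-patch class alignments; averaged over patches inside the loss
        'geometric_alignments': torch.randn(batch, num_patches, num_classes),
    }
    targets = torch.randint(0, num_classes, (batch,))
    loss_fn = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)
    print(f"combined loss: {loss_fn(outputs, targets).item():.4f}")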
# ============================================
# MODEL REGISTRY AND BUILDERS
# ============================================

MODEL_CONFIGS = {
    'pentachora_spark_xs': PentachoraConfig(
        dim=100, depth=2, heads=10, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_spark': PentachoraConfig(
        dim=100, depth=5, heads=4, mlp_ratio=4.0,
        preserve_structure_until_layer=1,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock': PentachoraConfig(
        dim=100, depth=10, heads=5, mlp_ratio=4.0, patch_size=5,
        preserve_structure_until_layer=5,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_32d': PentachoraConfig(
        dim=32, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_64d': PentachoraConfig(
        dim=64, depth=2, heads=8, mlp_ratio=1.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_128d': PentachoraConfig(
        dim=128, depth=2, heads=8, mlp_ratio=2.0,
        preserve_structure_until_layer=4, vocab_dim=256,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'vit_pixie_256_patch4': PentachoraConfig(
        dim=256, depth=10, heads=16, mlp_ratio=1.0,
        preserve_structure_until_layer=10, vocab_dim=256, patch_size=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'vit_pixie_256_patch2': PentachoraConfig(
        dim=256, depth=10, heads=16, mlp_ratio=1.0,
        preserve_structure_until_layer=10, vocab_dim=256, patch_size=2,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_256d': PentachoraConfig(
        dim=256, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4, vocab_dim=128,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_shock_xs_512d': PentachoraConfig(
        dim=512, depth=2, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=4,
        dropout_rate=0.0, drop_path_rate=0.0
    ),
    'pentachora_tiny': PentachoraConfig(
        dim=384, depth=12, heads=6, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_small': PentachoraConfig(
        dim=512, depth=12, heads=8, mlp_ratio=4.0,
        preserve_structure_until_layer=6,
        dropout_rate=0.1, drop_path_rate=0.1
    ),
    'pentachora_base': PentachoraConfig(
        dim=768, depth=12, heads=12, mlp_ratio=4.0,
        preserve_structure_until_layer=8,
        dropout_rate=0.1, drop_path_rate=0.2
    ),
    'pentachora_large': PentachoraConfig(
        dim=1024, depth=24, heads=16, mlp_ratio=4.0,
        preserve_structure_until_layer=12,
        dropout_rate=0.1, drop_path_rate=0.3
    ),
}


def create_pentachora_vit(variant: str = 'pentachora_small', pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create PentachoraViT model."""
    if variant not in MODEL_CONFIGS:
        raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")

    # Copy the registry config so kwargs overrides do not mutate the shared defaults
    from dataclasses import replace
    config = replace(MODEL_CONFIGS[variant])

    # Override config with kwargs
    for key, value in kwargs.items():
        setattr(config, key, value)

    model = PentachoraViT(config)

    if pretrained:
        warnings.warn("Pretrained weights not available yet")

    return model


# Convenience functions for each variant

def pentachora_vit_spark_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create spark variant (smallest)."""
    return create_pentachora_vit('pentachora_spark_xs', pretrained=pretrained, **kwargs)


def pentachora_shock_xs_64d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 64d variant."""
    return create_pentachora_vit('pentachora_shock_xs_64d', pretrained=pretrained, **kwargs)


def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create spark variant."""
    return create_pentachora_vit('pentachora_spark', pretrained=pretrained, **kwargs)


def pentachora_shock_xs_32d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 32d variant."""
    return create_pentachora_vit('pentachora_shock_xs_32d', pretrained=pretrained, **kwargs)


def pentachora_shock_xs_256d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 256d variant."""
    return create_pentachora_vit('pentachora_shock_xs_256d', pretrained=pretrained, **kwargs)


def pentachora_shock_xs_512d(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock xs 512d variant."""
    return create_pentachora_vit('pentachora_shock_xs_512d', pretrained=pretrained, **kwargs)


def pentachora_vit_shock(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create shock variant."""
    return create_pentachora_vit('pentachora_shock', pretrained=pretrained, **kwargs)


def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create tiny variant."""
    return create_pentachora_vit('pentachora_tiny', pretrained=pretrained, **kwargs)


def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create small variant."""
    return create_pentachora_vit('pentachora_small', pretrained=pretrained, **kwargs)


def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create base variant."""
    return create_pentachora_vit('pentachora_base', pretrained=pretrained, **kwargs)


def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Create large variant."""
    return create_pentachora_vit('pentachora_large', pretrained=pretrained, **kwargs)


# ============================================
# TRAINING UTILITIES
# ============================================

def get_parameter_groups(model: PentachoraViT, weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """Get parameter groups for optimizer with weight decay handling."""
    no_decay = ['bias', 'norm', 'LayerNorm']

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]


def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Count model parameters."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        'total': total,
        'trainable': trainable,
        'non_trainable': total - trainable
    }
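

# --------------------------------------------------------------------------
# Illustrative sketch (not from the original module): wiring the parameter
# groups above into torch.optim.AdamW so that biases and norm weights skip
# weight decay. The variant, learning rate, and helper name are example
# choices, not prescribed values.
# --------------------------------------------------------------------------
def _demo_parameter_groups() -> None:
    """Build an AdamW optimizer from get_parameter_groups for a small variant."""
    model = create_pentachora_vit('pentachora_spark_xs', num_classes=100)
    param_groups = get_parameter_groups(model, weight_decay=0.05)
    optimizer = torch.optim.AdamW(param_groups, lr=3e-4)  # lr is an example value
    print(f"decay params: {len(param_groups[0]['params'])}, "
          f"no-decay params: {len(param_groups[1]['params'])}")
    print(count_parameters(model))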
# ============================================
# INFERENCE UTILITIES
# ============================================

@torch.no_grad()
def extract_features(model: PentachoraViT, images: torch.Tensor,
                     feature_type: str = 'global_cls') -> torch.Tensor:
    """Extract features from images using the model."""
    model.eval()
    features = model.forward_features(images)
    return features.get(feature_type, features['global_cls'])


# ============================================
# EXAMPLE USAGE AND TESTING
# ============================================

def test_model():
    """Test model creation and forward pass."""
    print("Testing Fixed PentachoraViT Model")
    print("=" * 50)

    # Test different variants
    variants = ['pentachora_spark', 'pentachora_shock_xs_256d', 'pentachora_small']

    for variant in variants:
        print(f"\nTesting {variant}:")

        # Create model with vocab_dim
        model = create_pentachora_vit(
            variant=variant,
            img_size=32,
            patch_size=4,
            num_classes=100,
            vocab_dim=64
        )

        # Count parameters
        params = count_parameters(model)
        print(f"  Total parameters: {params['total']:,}")
        print(f"  Trainable parameters: {params['trainable']:,}")

        # Test forward pass
        x = torch.randn(2, 3, 32, 32)

        # Time the forward pass
        if torch.cuda.is_available():
            model = model.cuda()
            x = x.cuda()
            torch.cuda.synchronize()

        import time
        start = time.time()
        outputs = model(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end = time.time()

        print("  Output shapes:")
        print(f"    Logits: {outputs['logits'].shape}")
        print(f"    Aux logits: {outputs['aux_logits'].shape}")
        print(f"    Geometric alignments: {outputs['geometric_alignments'].shape}")
        print(f"  Forward pass time: {(end - start) * 1000:.2f}ms")

        # Test loss computation
        loss_fn = PentachoraLoss()
        targets = torch.randint(0, 100, (2,))
        if torch.cuda.is_available():
            targets = targets.cuda()
        loss = loss_fn(outputs, targets)
        print(f"  Loss: {loss.item():.4f}")

    print("\n" + "=" * 50)
    print("All tests passed!")


if __name__ == "__main__":
    # Run tests
    test_model()

    # Example: Create model for training
    print("\nExample: Creating model with proper initialization")
    model = pentachora_shock_xs_256d(
        img_size=32,
        num_classes=100,
        vocab_dim=100,
        dropout_rate=0.0,
        drop_path_rate=0.0
    )

    # All parameters are initialized immediately
    print(f"Model has {count_parameters(model)['total']:,} parameters")
    print("All geometric parameters initialized at creation time")

    # Move model to CUDA if available
    if torch.cuda.is_available():
        model = model.cuda()
        print("Model moved to CUDA")

    # Now torch.compile should work without issues
    if hasattr(torch, 'compile'):
        print("Compiling model with torch.compile...")
        try:
            model = torch.compile(model)
            print("✓ Model compiled successfully")
        except Exception as e:
            print(f"Compilation warning: {e}")
            print("Continuing without compilation")

    print("\nModel ready for training with all parameters properly initialized!")
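

# --------------------------------------------------------------------------
# Illustrative sketch (not from the original module): prototype-based
# inference that mirrors the prototype classifier in PentachoraViT.forward,
# built from extract_features and the class prototypes exposed by the
# hierarchical CLS module. `predict_with_prototypes` is a hypothetical name.
# --------------------------------------------------------------------------
@torch.no_grad()
def predict_with_prototypes(model: PentachoraViT, images: torch.Tensor) -> torch.Tensor:
    """Return predicted class indices by matching global CLS features to class prototypes."""
    feats = extract_features(model, images, feature_type='global_cls')  # [B, dim]
    prototypes = model.cls_tokens.get_class_prototypes()                # [num_classes, dim]
    feats = F.normalize(feats, dim=-1)
    prototypes = F.normalize(prototypes, dim=-1)
    # Cosine similarity; the 20.0 logit scale used in forward() does not change the argmax
    return torch.matmul(feats, prototypes.T).argmax(dim=-1)             # [B]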