""" Baseline Vision Transformer with Frozen Pentachora Embeddings Now with optional theta rotation head for better classification """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from einops import rearrange import math from typing import Optional, Tuple, Dict, Any class PentachoraEmbedding(nn.Module): """ A single frozen pentachora embedding (5 vertices in geometric space). Now with theta rotation capabilities. """ def __init__(self, vertices: torch.Tensor): super().__init__() self.embed_dim = vertices.shape[-1] # Store provided vertices as frozen buffer self.register_buffer('vertices', vertices.cpu().contiguous().detach().clone().to(get_default_device())) self.vertices.requires_grad = False # Precompute normalized versions and centroid with torch.no_grad(): self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1)) self.register_buffer('centroid', self.vertices.mean(dim=0)) self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1)) # Compute theta bases for rotation self.register_buffer('theta_bases', self._compute_theta_bases().cpu().contiguous().detach().clone().to(get_default_device())) def _compute_theta_bases(self) -> torch.Tensor: """Compute orthogonal bases from vertices for theta rotation.""" U, S, V = torch.svd(self.vertices) n_components = min(5, self.embed_dim) return V[:, :n_components] # [embed_dim, n_components] def get_vertices(self) -> torch.Tensor: return self.vertices def get_centroid(self) -> torch.Tensor: return self.centroid def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor: verts = self.vertices.unsqueeze(0) if features.dim() == 1: features = features.unsqueeze(0) B = features.shape[0] if B > 1: verts = verts.expand(B, -1, -1) return PentachoronStabilizer.rose_score_magnitude(features, verts) def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor: if mode == 'rose': return self.compute_rose_score(features) features_norm = F.normalize(features, dim=-1) if mode == 'centroid': return torch.matmul(features_norm, self.centroid_norm) else: # mode == 'max' sims = torch.matmul(features_norm, self.vertices_norm.T) return sims.max(dim=-1)[0] def compute_theta_features(self, features: torch.Tensor) -> torch.Tensor: """ Project features to theta space defined by this pentachora. Returns angular features for feedforward classification. """ # Project onto pentachora bases projections = torch.matmul(features, self.theta_bases) # [batch, 5] # Compute angles relative to centroid centroid_proj = torch.matmul(self.centroid.unsqueeze(0), self.theta_bases) angles = torch.atan2(projections, centroid_proj + 1e-8) # Return sin/cos encoding return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).to(get_default_device()) # [batch, 10] class ThetaHead(nn.Module): """ Theta-based classification head using angular representations. Replaces similarity matching with learned feedforward. """ def __init__( self, embed_dim: int, num_classes: int, n_pentachora: int = 10, # Use subset of pentachora for theta hidden_dim: int = 256, dropout: float = 0.1 ): super().__init__() self.n_pentachora = n_pentachora self.embed_dim = embed_dim # Each pentachora gives 10 theta features (5 sin + 5 cos) theta_dim = n_pentachora * 10 # Project to theta space self.to_theta = nn.Sequential( nn.Linear(embed_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, theta_dim) ) # Classify from theta self.classifier = nn.Sequential( nn.LayerNorm(theta_dim), nn.Dropout(dropout), nn.Linear(theta_dim, num_classes) ) # Learnable temperature self.temperature = nn.Parameter(torch.ones(1) * 0.1) def forward(self, features: torch.Tensor, pentachora_list: nn.ModuleList) -> Dict[str, torch.Tensor]: """ Classify using theta rotation. Args: features: [batch, embed_dim] CLS features pentachora_list: List of PentachoraEmbedding modules """ # Get theta features from first n pentachora theta_features = [] for i in range(min(self.n_pentachora, len(pentachora_list))): theta = pentachora_list[i].compute_theta_features(features) theta_features.append(theta) # Concatenate all theta features theta_concat = torch.cat(theta_features, dim=-1) # [batch, n_pentachora * 10] # If we have fewer pentachora than expected, pad with zeros if len(theta_features) < self.n_pentachora: pad_size = (self.n_pentachora - len(theta_features)) * 10 padding = torch.zeros(features.shape[0], pad_size, device=features.device) theta_concat = torch.cat([theta_concat, padding], dim=-1) # Project through MLP theta_proj = self.to_theta(features) # Combine with geometric theta (residual connection) theta_combined = theta_concat + 0.1 * theta_proj # Classify logits = self.classifier(theta_combined) / self.temperature.exp() return { 'logits': logits, 'theta_features': theta_combined } class TransformerBlock(nn.Module): """Standard transformer block with multi-head attention and MLP.""" def __init__( self, dim: int, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.0, attn_dropout: float = 0.0 ): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention( dim, num_heads, dropout=attn_dropout, batch_first=True ) self.norm2 = nn.LayerNorm(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, dim), nn.Dropout(dropout) ) def forward(self, x: torch.Tensor) -> torch.Tensor: x_norm = self.norm1(x) attn_out, _ = self.attn(x_norm, x_norm, x_norm) x = x + attn_out x = x + self.mlp(self.norm2(x)) return x class BaselineViT(nn.Module): """ Vision Transformer with optional theta-based classification. Can switch between similarity-based and theta-based heads. """ def __init__( self, pentachora_list: list, vocab_dim: int = 256, img_size: int = 32, patch_size: int = 4, embed_dim: int = 512, depth: int = 12, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.0, attn_dropout: float = 0.0, similarity_mode: str = 'rose', use_theta_head: bool = True, # NEW: Toggle theta head theta_n_pentachora: int = 2, # NEW: How many pentachora for theta theta_hidden_dim: int = 256 # NEW: Hidden dim for theta MLP ): super().__init__() assert isinstance(pentachora_list, list) and len(pentachora_list) > 0 self.num_classes = len(pentachora_list) self.embed_dim = embed_dim self.num_patches = (img_size // patch_size) ** 2 self.similarity_mode = similarity_mode self.pentachora_dim = vocab_dim self.use_theta_head = use_theta_head # Create pentachora embeddings self.class_pentachora = nn.ModuleList([ PentachoraEmbedding(vertices=penta) for penta in pentachora_list ]) # Patch embedding self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size) # CLS token self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # Position embeddings self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim)) self.pos_drop = nn.Dropout(dropout) # Transformer blocks self.blocks = nn.ModuleList([ TransformerBlock( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout, attn_dropout=attn_dropout ) for i in range(depth) ]) # Final norm self.norm = nn.LayerNorm(embed_dim) # Project to pentachora dimension if needed if self.pentachora_dim != embed_dim: self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim) else: self.to_pentachora_dim = nn.Identity() # Classification heads if use_theta_head: # NEW: Theta-based classification self.theta_head = ThetaHead( embed_dim=self.pentachora_dim, num_classes=self.num_classes, n_pentachora=theta_n_pentachora, hidden_dim=theta_hidden_dim, dropout=dropout ) else: # Original: Similarity-based classification self.theta_head = None self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07)) self.register_buffer( 'all_centroids', torch.stack([penta.centroid for penta in self.class_pentachora]) ) self.register_buffer( 'all_centroids_norm', F.normalize(self.all_centroids, dim=-1) ) self.init_weights() def init_weights(self): nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.pos_embed, std=0.02) for m in self.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) def get_class_centroids(self) -> torch.Tensor: if self.use_theta_head: # Return centroids from pentachora for compatibility centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) return centroids else: return self.all_centroids_norm def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor: if self.similarity_mode == 'rose': all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) return PentachoronStabilizer.rose_score_magnitude( features_exp.reshape(-1, self.pentachora_dim), all_vertices.repeat(features.shape[0], 1, 1) ).reshape(features.shape[0], -1) else: centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) features_norm = F.normalize(features, dim=-1) return torch.matmul(features_norm, centroids.T) def forward_features(self, x: torch.Tensor) -> torch.Tensor: B = x.shape[0] # Patch embedding x = self.patch_embed(x) x = x.flatten(2).transpose(1, 2) # Add CLS token cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat([cls_tokens, x], dim=1) # Add position embeddings x = x + self.pos_embed x = self.pos_drop(x) # Apply transformer blocks for block in self.blocks: x = block(x) # Final norm x = self.norm(x) # Return CLS token return x[:, 0] def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]: """ Forward pass with optional theta head. """ features = self.forward_features(x) output = {} # Project to pentachora dimension features_proj = self.to_pentachora_dim(features) if self.use_theta_head: # NEW: Use theta-based classification theta_output = self.theta_head(features_proj, self.class_pentachora) output['logits'] = theta_output['logits'] output['theta_features'] = theta_output['theta_features'] # Still compute similarities for analysis with torch.no_grad(): similarities = self.compute_pentachora_similarities(features_proj) output['similarities'] = similarities else: # Original: Use similarity-based classification similarities = self.compute_pentachora_similarities(features_proj) logits = similarities * self.temperature.exp() output['logits'] = logits output['similarities'] = similarities if return_features: output['features'] = features return output # Helper function to convert existing model to theta def enable_theta_head(model: BaselineViT, n_pentachora: int = 10, hidden_dim: int = 256): """ Convert an existing similarity-based model to use theta head. This modifies the model in-place. """ if model.use_theta_head: print("Model already using theta head") return model print(f"Converting to theta head with {n_pentachora} pentachora...") # Create theta head model.theta_head = ThetaHead( embed_dim=model.pentachora_dim, num_classes=model.num_classes, n_pentachora=n_pentachora, hidden_dim=hidden_dim, dropout=0.1 ).to(next(model.parameters()).device) # Set flag model.use_theta_head = True # Initialize new parameters for m in model.theta_head.modules(): if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) print("✓ Theta head enabled") return model if __name__ == "__main__": print("BaselineViT with optional theta head") print("Use 'use_theta_head=True' to enable theta classification") print("Or call enable_theta_head() on existing model")