| """ |
| Baseline Vision Transformer with Frozen Pentachora Embeddings |
| Clean architecture with geometric semantic anchors |
| Assumes PentachoronStabilizer is loaded externally |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from einops import rearrange |
| import math |
| from typing import Optional, Tuple, Dict, Any |
|
|
|
|
class PentachoraEmbedding(nn.Module):
    """
    A single frozen pentachoron embedding (vertices in geometric space).

    Accepts pre-computed vertices only; there is no random initialization.
    The vertices, their row-normalized copy, the centroid and the normalized
    centroid are all registered as buffers, so they move with ``.to()``,
    are saved in ``state_dict`` and never receive gradients.

    Args:
        vertices: Pre-computed vertex tensor. The last dimension is used as
            the embedding dimension and the centroid is the mean over dim 0,
            so a ``[5, dim]`` tensor is expected (not validated here).
    """

    def __init__(self, vertices: torch.Tensor):
        super().__init__()

        self.embed_dim = vertices.shape[-1]

        # detach() yields a gradient-free leaf without touching the caller's
        # tensor. Setting ``requires_grad = False`` on the registered tensor
        # would instead mutate the caller's object, and raises RuntimeError
        # when the vertices are a non-leaf result of a differentiable op.
        self.register_buffer('vertices', vertices.detach())

        with torch.no_grad():
            self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
            self.register_buffer('centroid', self.vertices.mean(dim=0))
            self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))

    def get_vertices(self) -> torch.Tensor:
        """Return the (unnormalized) vertex buffer."""
        return self.vertices

    def get_centroid(self) -> torch.Tensor:
        """Return the (unnormalized) centroid: the mean of the vertices over dim 0."""
        return self.centroid

    def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute the Rose similarity score between features and this pentachoron.

        Delegates to ``PentachoronStabilizer.rose_score_magnitude``, which is
        provided externally and must be in scope when this is called.

        Args:
            features: ``[dim]`` or ``[batch, dim]``; a 1-D input is treated
                as a batch of one.

        Returns:
            The stabilizer's score for each feature row, each paired with a
            copy of this pentachoron's vertices (presumably ``[batch]`` --
            TODO confirm against ``rose_score_magnitude``).
        """
        if features.dim() == 1:
            features = features.unsqueeze(0)

        # Give every feature row the same vertex set; expand() creates a view
        # (no copy) and is a no-op when the batch size is 1.
        verts = self.vertices.unsqueeze(0).expand(features.shape[0], -1, -1)
        return PentachoronStabilizer.rose_score_magnitude(features, verts)

    def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
        """
        Compute similarity between features and this pentachoron.

        Args:
            features: ``[batch, dim]`` or ``[batch, seq, dim]``.
            mode: ``'centroid'`` (cosine similarity to the centroid),
                ``'max'`` (best cosine similarity over the vertices) or
                ``'rose'`` (Rose score, see ``compute_rose_score``).

        Returns:
            Similarities with the trailing feature dimension reduced away
            (``[batch]`` or ``[batch, seq]`` for the cosine modes).

        Raises:
            ValueError: if ``mode`` is not one of the modes listed above.
        """
        if mode == 'rose':
            return self.compute_rose_score(features)
        if mode not in ('centroid', 'max'):
            raise ValueError(
                f"Unknown similarity mode {mode!r}; expected 'centroid', 'max' or 'rose'"
            )

        features_norm = F.normalize(features, dim=-1)

        if mode == 'centroid':
            return torch.matmul(features_norm, self.centroid_norm)

        # Cosine similarity to every vertex; keep the best-matching one.
        sims = torch.matmul(features_norm, self.vertices_norm.T)
        return sims.max(dim=-1).values
|
|
|
|
class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block: LayerNorm -> multi-head self-attention ->
    residual add, then LayerNorm -> MLP -> residual add.

    Args:
        dim: Token embedding dimension (``nn.MultiheadAttention`` requires it
            to be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: Dropout applied inside the MLP (after the activation and
            after the output projection).
        attn_dropout: Dropout applied to the attention weights.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim,
            num_heads,
            dropout=attn_dropout,
            batch_first=True
        )

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block.

        Args:
            x: ``[batch, seq, dim]`` tokens (the attention is ``batch_first``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_norm = self.norm1(x)
        # The attention weights are discarded, so don't compute them.
        # need_weights=False also lets PyTorch use its fused attention path.
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out

        x = x + self.mlp(self.norm2(x))

        return x
|
|
|
|
class BaselineViT(nn.Module):
    """
    Baseline Vision Transformer that classifies by similarity to frozen
    per-class pentachora.

    Images are cut into non-overlapping patches, embedded, prefixed with a
    learnable CLS token, run through ``depth`` transformer blocks, and the
    final normalized CLS token is projected into the pentachora space.
    Logits are the similarities to each class pentachoron multiplied by a
    learnable scale ``exp(temperature)``.

    Args:
        pentachora_list: One vertex tensor per class. All must share a shape
            whose last dimension equals ``vocab_dim``. A single stacked tensor
            ``[num_classes, ..., vocab_dim]`` is also accepted and is split
            along dim 0.
        vocab_dim: Dimension of the pentachora vertices.
        img_size: Input image side length used to size the positional
            embedding (``(img_size // patch_size) ** 2`` patches).
        patch_size: Patch side length (kernel size and stride of the patch conv).
        embed_dim: Transformer token dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout after the positional embedding and inside the MLPs.
        attn_dropout: Dropout on the attention weights.
        similarity_mode: ``'rose'`` scores with the external
            ``PentachoronStabilizer.rose_score_magnitude``; any other value
            uses cosine similarity to the normalized class centroids.

    Raises:
        TypeError: if ``pentachora_list`` is not a list/tuple/tensor, or an
            item in it is not a tensor.
        ValueError: if no pentachora are given, or a pentachoron's last
            dimension differs from ``vocab_dim``.
    """

    def __init__(
        self,
        pentachora_list: list,
        vocab_dim: int = 256,
        img_size: int = 32,
        patch_size: int = 4,
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        similarity_mode: str = 'rose'
    ):
        super().__init__()

        pentachora_list = self._validate_pentachora(pentachora_list, vocab_dim)

        self.num_classes = len(pentachora_list)
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.similarity_mode = similarity_mode
        self.pentachora_dim = vocab_dim

        # Frozen class anchors, one embedding per class.
        self.class_pentachora = nn.ModuleList([
            PentachoraEmbedding(vertices=penta)
            for penta in pentachora_list
        ])

        # Non-overlapping patches: [B, 3, H, W] -> [B, embed_dim, H/p, W/p].
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Learned positional embedding for the CLS token plus every patch.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Map CLS features into the pentachora space when the dims differ.
        if self.pentachora_dim != embed_dim:
            self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
        else:
            self.to_pentachora_dim = nn.Identity()

        # Learnable log logit scale; exp() of the initial value is 1 / 0.07 (~14.3).
        self.temperature = nn.Parameter(torch.ones(1) * np.log(1 / 0.07))

        # Cached [num_classes, vocab_dim] centroids for the cosine path.
        self.register_buffer(
            'all_centroids',
            torch.stack([penta.centroid for penta in self.class_pentachora])
        )
        self.register_buffer(
            'all_centroids_norm',
            F.normalize(self.all_centroids, dim=-1)
        )

        self.init_weights()

    @staticmethod
    def _validate_pentachora(pentachora_list, vocab_dim: int) -> list:
        """Normalize the pentachora argument to a list of tensors and validate it."""
        # A stacked [num_classes, ..., vocab_dim] tensor is split per class.
        # detach() so the per-class views are gradient-free leaves.
        if isinstance(pentachora_list, torch.Tensor):
            pentachora_list = list(pentachora_list.detach().unbind(0))

        # Explicit raises instead of assert: asserts are stripped under -O.
        if not isinstance(pentachora_list, (list, tuple)):
            raise TypeError(f"Expected list, got {type(pentachora_list)}")
        if len(pentachora_list) == 0:
            raise ValueError("Empty pentachora list")

        for i, penta in enumerate(pentachora_list):
            if not isinstance(penta, torch.Tensor):
                raise TypeError(f"Item {i} is not a tensor")
            if penta.dim() == 0 or penta.shape[-1] != vocab_dim:
                raise ValueError(
                    f"Item {i} has shape {tuple(penta.shape)}; "
                    f"its last dimension must equal vocab_dim={vocab_dim}"
                )

        return list(pentachora_list)

    def init_weights(self):
        """Initialize CLS/positional embeddings, Linear layers and LayerNorms."""
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def get_class_centroids(self) -> torch.Tensor:
        """Return the L2-normalized class centroids, ``[num_classes, vocab_dim]``."""
        return self.all_centroids_norm

    def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
        """
        Compute similarities between features and every class pentachoron.

        Args:
            features: ``[batch, pentachora_dim]`` features already projected
                into the pentachora space (output of ``to_pentachora_dim``).

        Returns:
            ``[batch, num_classes]`` similarities.
        """
        if self.similarity_mode == 'rose':
            batch = features.shape[0]
            # Stacked per call so the result always reflects the current
            # (possibly re-loaded) per-class vertex buffers.
            all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])

            # Flatten (sample, class) pairs: row b * num_classes + c pairs
            # sample b with class c on both sides. The feature width is the
            # pentachora dim, which differs from embed_dim when a projection
            # is used.
            flat_features = (
                features.unsqueeze(1)
                .expand(-1, self.num_classes, -1)
                .reshape(-1, features.shape[-1])
            )
            flat_vertices = all_vertices.repeat(batch, 1, 1)

            scores = PentachoronStabilizer.rose_score_magnitude(flat_features, flat_vertices)
            return scores.reshape(batch, -1)

        # Cosine similarity to the normalized class centroids.
        features_norm = F.normalize(features, dim=-1)
        return torch.matmul(features_norm, self.get_class_centroids().T)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Extract the final CLS features from images.

        Args:
            x: ``[batch, 3, H, W]`` images; ``(H // patch_size) * (W // patch_size)``
                must equal ``num_patches`` for the positional embedding to match.

        Returns:
            ``[batch, embed_dim]`` normalized CLS token features.
        """
        B = x.shape[0]

        # [B, embed_dim, h, w] -> [B, h * w, embed_dim]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        return x[:, 0]

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass.

        Returns a dict with:
            - ``logits``: similarities scaled by ``exp(temperature)``
            - ``similarities``: raw similarities to each class pentachoron
            - ``features``: unprojected CLS features (only if ``return_features``)
        """
        features = self.forward_features(x)
        features_proj = self.to_pentachora_dim(features)

        similarities = self.compute_pentachora_similarities(features_proj)
        logits = similarities * self.temperature.exp()

        output = {'logits': logits, 'similarities': similarities}
        if return_features:
            output['features'] = features

        return output
|
|
|
|
| |
if __name__ == "__main__":
    # This module only defines the model; print what a caller must supply.
    # BaselineViT takes a list of per-class vertex tensors, not a batch tensor.
    print("BaselineViT requires:")
    print(" 1. PentachoronStabilizer loaded externally")
    print(" 2. pentachora_list: list of [5, vocab_dim] tensors, one per class")
    print("\nNo random initialization. No fallbacks.")