penta-vit-experiments / vit_zana_v3.py
AbstractPhil's picture
Added theta experimental head
7465a5c verified
raw
history blame
15.3 kB
"""
Baseline Vision Transformer with Frozen Pentachora Embeddings
Now with optional theta rotation head for better classification
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
import math
from typing import Optional, Tuple, Dict, Any
class PentachoraEmbedding(nn.Module):
    """
    A single frozen pentachora embedding (5 vertices in geometric space).

    The vertices and every quantity derived from them (normalized vertices,
    centroid, normalized centroid, theta bases) are registered as buffers:
    they are saved in the state dict and follow ``module.to(...)``, but they
    are never trained.

    Args:
        vertices: Tensor whose last dimension is the embedding dimension.
            The methods below treat it as ``[n_vertices, embed_dim]``.
    """

    def __init__(self, vertices: torch.Tensor):
        super().__init__()
        self.embed_dim = vertices.shape[-1]
        device = self._target_device(vertices.device)

        with torch.no_grad():
            # Detached private copy, so later edits to the caller's tensor
            # cannot leak into the frozen embedding.
            frozen = vertices.detach().clone().contiguous().to(device)
            self.register_buffer('vertices', frozen)

            # Precompute normalized versions and centroid
            self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
            self.register_buffer('centroid', self.vertices.mean(dim=0))
            self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))

            # Orthonormal directions used by compute_theta_features
            self.register_buffer('theta_bases', self._compute_theta_bases().contiguous())

    @staticmethod
    def _target_device(fallback: torch.device):
        """Return the device the buffers should live on.

        ``get_default_device`` is neither defined nor imported in this module.
        When the surrounding environment provides it, it is honoured;
        otherwise the buffers stay on the device of the supplied vertices
        instead of failing with ``NameError``.
        """
        try:
            return get_default_device()
        except NameError:
            return fallback

    def _compute_theta_bases(self) -> torch.Tensor:
        """Compute orthogonal bases from vertices for theta rotation.

        Returns the leading right-singular vectors of ``vertices`` as columns,
        at most 5 of them: ``[embed_dim, n_components]``.
        """
        # torch.svd is deprecated. The reduced torch.linalg.svd returns Vh, so
        # transpose (and conjugate, a no-op for real input) to recover V.
        _, _, vh = torch.linalg.svd(self.vertices, full_matrices=False)
        v = vh.transpose(-2, -1).conj()
        n_components = min(5, self.embed_dim)
        return v[:, :n_components]  # [embed_dim, n_components]

    def get_vertices(self) -> torch.Tensor:
        """Return the frozen vertex buffer."""
        return self.vertices

    def get_centroid(self) -> torch.Tensor:
        """Return the (un-normalized) mean of the vertices."""
        return self.centroid

    def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
        """Score ``features`` against this pentachora via the rose score.

        A 1-D ``features`` tensor is treated as a batch of one. The vertices
        are given a leading batch dimension and expanded to the batch size
        before delegating to ``PentachoronStabilizer.rose_score_magnitude``.

        NOTE(review): ``PentachoronStabilizer`` is not defined or imported in
        this module; it must be provided by the surrounding environment.
        """
        verts = self.vertices.unsqueeze(0)
        if features.dim() == 1:
            features = features.unsqueeze(0)
        batch = features.shape[0]
        if batch > 1:
            verts = verts.expand(batch, -1, -1)
        return PentachoronStabilizer.rose_score_magnitude(features, verts)

    def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
        """Similarity between ``features`` and this pentachora.

        Args:
            features: Tensor whose last dimension is ``embed_dim``.
            mode: ``'rose'`` delegates to :meth:`compute_rose_score`;
                ``'centroid'`` is the cosine similarity to the centroid;
                any other value is treated as ``'max'``, the largest cosine
                similarity over the vertices.
        """
        if mode == 'rose':
            return self.compute_rose_score(features)

        features_norm = F.normalize(features, dim=-1)
        if mode == 'centroid':
            return torch.matmul(features_norm, self.centroid_norm)

        # mode == 'max' (also the fallback for unrecognised modes)
        sims = torch.matmul(features_norm, self.vertices_norm.T)
        return sims.max(dim=-1)[0]

    def compute_theta_features(self, features: torch.Tensor) -> torch.Tensor:
        """
        Project features to theta space defined by this pentachora.
        Returns angular features for feedforward classification.

        The result has ``2 * n_components`` columns (sin block followed by
        cos block) and stays on the same device as ``features``, so it can be
        combined directly with other tensors computed from ``features``.
        """
        # Project onto pentachora bases
        projections = torch.matmul(features, self.theta_bases)

        # Angle of each projection relative to the centroid's projection.
        # The epsilon is carried over from the original formulation.
        centroid_proj = torch.matmul(self.centroid.unsqueeze(0), self.theta_bases)
        angles = torch.atan2(projections, centroid_proj + 1e-8)

        # Return sin/cos encoding
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
class ThetaHead(nn.Module):
    """
    Theta-based classification head using angular representations.
    Replaces similarity matching with learned feedforward.

    Args:
        embed_dim: Size of the incoming feature vectors.
        num_classes: Number of output logits.
        n_pentachora: How many pentachora (taken from the front of the list
            passed to ``forward``) contribute geometric theta features.
        hidden_dim: Width of the hidden layer in the ``to_theta`` MLP.
        dropout: Dropout probability used in both sub-networks.
    """

    # Each pentachora gives 10 theta features (5 sin + 5 cos)
    FEATURES_PER_PENTACHORA = 10

    def __init__(
        self,
        embed_dim: int,
        num_classes: int,
        n_pentachora: int = 10,  # Use subset of pentachora for theta
        hidden_dim: int = 256,
        dropout: float = 0.1
    ):
        super().__init__()
        self.n_pentachora = n_pentachora
        self.embed_dim = embed_dim

        theta_dim = n_pentachora * self.FEATURES_PER_PENTACHORA

        # Project to theta space
        self.to_theta = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, theta_dim)
        )

        # Classify from theta
        self.classifier = nn.Sequential(
            nn.LayerNorm(theta_dim),
            nn.Dropout(dropout),
            nn.Linear(theta_dim, num_classes)
        )

        # Learnable temperature, initialised to 0.1 and applied in forward()
        # as ``logits / exp(temperature)``.
        self.temperature = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, features: torch.Tensor, pentachora_list: nn.ModuleList) -> Dict[str, torch.Tensor]:
        """
        Classify using theta rotation.

        Args:
            features: [batch, embed_dim] CLS features
            pentachora_list: List of PentachoraEmbedding modules. Only the
                first ``n_pentachora`` entries are used; if fewer (or none)
                are supplied, the missing slots are zero-filled.

        Returns:
            Dict with ``'logits'`` and ``'theta_features'`` (the geometric
            theta features plus 0.1 times the MLP projection).
        """
        # Get theta features from first n pentachora; zip stops at whichever
        # of the two runs out first.
        theta_parts = [
            penta.compute_theta_features(features)
            for _, penta in zip(range(self.n_pentachora), pentachora_list)
        ]

        # If we have fewer pentachora than expected, pad with zeros. The
        # padding copies dtype/device from the computed theta features (or
        # from `features` when there are none) so the concatenation never
        # mixes dtypes.
        n_missing = self.n_pentachora - len(theta_parts)
        if n_missing > 0 or not theta_parts:
            like = theta_parts[0] if theta_parts else features
            pad_size = n_missing * self.FEATURES_PER_PENTACHORA
            theta_parts.append(like.new_zeros(features.shape[0], pad_size))

        # Concatenate all theta features: [batch, n_pentachora * 10]
        theta_concat = torch.cat(theta_parts, dim=-1)

        # Project through MLP
        theta_proj = self.to_theta(features)

        # Combine with geometric theta (residual connection)
        theta_combined = theta_concat + 0.1 * theta_proj

        # Classify
        logits = self.classifier(theta_combined) / self.temperature.exp()

        return {
            'logits': logits,
            'theta_features': theta_combined
        }
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then an MLP, each wrapped in a residual."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0
    ):
        super().__init__()
        # Width of the MLP's hidden layer, as a multiple of the model width
        hidden_width = int(dim * mlp_ratio)

        # Attention sub-layer
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )

        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(dim)
        feed_forward = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP residual updates to ``x`` and return the result."""
        normed = self.norm1(x)
        # MultiheadAttention returns (output, weights); only the output is used
        attended = self.attn(normed, normed, normed)[0]
        after_attn = x + attended
        return after_attn + self.mlp(self.norm2(after_attn))
class BaselineViT(nn.Module):
    """
    Vision Transformer with optional theta-based classification.
    Can switch between similarity-based and theta-based heads.

    Args:
        pentachora_list: Non-empty list (or tuple) of vertex tensors, one per
            class. The number of entries defines the number of classes.
        vocab_dim: Dimension of the pentachora space. The last dimension of
            every vertex tensor must equal this value.
        img_size: Input image side length used to size the position embedding.
        patch_size: Side length of each square patch.
        embed_dim: Transformer width.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        dropout: Dropout used after position embedding, in the MLPs and in
            the theta head.
        attn_dropout: Dropout inside the attention layers.
        similarity_mode: ``'rose'`` uses the rose score; any other value uses
            cosine similarity to the class centroids.
        use_theta_head: Use the theta head for logits instead of similarities.
        theta_n_pentachora: How many pentachora feed the theta head.
        theta_hidden_dim: Hidden width of the theta head MLP.

    Raises:
        TypeError: If ``pentachora_list`` is not a list or tuple.
        ValueError: If ``pentachora_list`` is empty or an entry's last
            dimension differs from ``vocab_dim``.
    """

    def __init__(
        self,
        pentachora_list: list,
        vocab_dim: int = 256,
        img_size: int = 32,
        patch_size: int = 4,
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        similarity_mode: str = 'rose',
        use_theta_head: bool = True,  # Toggle theta head
        theta_n_pentachora: int = 2,  # How many pentachora for theta
        theta_hidden_dim: int = 256  # Hidden dim for theta MLP
    ):
        super().__init__()

        # Explicit checks rather than `assert`, which disappears under -O.
        if not isinstance(pentachora_list, (list, tuple)):
            raise TypeError(
                f"pentachora_list must be a list or tuple, got {type(pentachora_list).__name__}"
            )
        if len(pentachora_list) == 0:
            raise ValueError("pentachora_list must contain at least one pentachoron")
        # A dimension mismatch would otherwise only surface as a shape error
        # deep inside forward().
        for idx, penta in enumerate(pentachora_list):
            if penta.shape[-1] != vocab_dim:
                raise ValueError(
                    f"pentachora_list[{idx}] has last dimension {penta.shape[-1]}, "
                    f"expected vocab_dim={vocab_dim}"
                )

        self.num_classes = len(pentachora_list)
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.similarity_mode = similarity_mode
        self.pentachora_dim = vocab_dim
        self.use_theta_head = use_theta_head

        # Create pentachora embeddings
        self.class_pentachora = nn.ModuleList([
            PentachoraEmbedding(vertices=penta)
            for penta in pentachora_list
        ])

        # Patch embedding
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Position embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout
            )
            for _ in range(depth)
        ])

        # Final norm
        self.norm = nn.LayerNorm(embed_dim)

        # Project to pentachora dimension if needed
        if self.pentachora_dim != embed_dim:
            self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
        else:
            self.to_pentachora_dim = nn.Identity()

        # Classification heads
        if use_theta_head:
            # Theta-based classification
            self.theta_head = ThetaHead(
                embed_dim=self.pentachora_dim,
                num_classes=self.num_classes,
                n_pentachora=theta_n_pentachora,
                hidden_dim=theta_hidden_dim,
                dropout=dropout
            )
        else:
            # Original: Similarity-based classification. The temperature is
            # stored in log space and exponentiated in forward().
            self.theta_head = None
            self.temperature = nn.Parameter(torch.ones(1) * np.log(1 / 0.07))
            self.register_buffer(
                'all_centroids',
                torch.stack([penta.centroid for penta in self.class_pentachora])
            )
            self.register_buffer(
                'all_centroids_norm',
                F.normalize(self.all_centroids, dim=-1)
            )

        self.init_weights()

    def init_weights(self):
        """Truncated-normal init for tokens and Linear weights; reset LayerNorm to identity."""
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def get_class_centroids(self) -> torch.Tensor:
        """Return the normalized class centroids, one row per class."""
        if self.use_theta_head:
            # The stacked buffer only exists in similarity mode, so rebuild
            # it from the per-class embeddings for compatibility.
            return torch.stack([penta.centroid_norm for penta in self.class_pentachora])
        return self.all_centroids_norm

    def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
        """Score every feature vector against every class pentachoron.

        Args:
            features: [batch, pentachora_dim] projected CLS features.

        Returns:
            [batch, num_classes] scores: rose scores when
            ``similarity_mode == 'rose'``, cosine similarity to the class
            centroids otherwise.

        NOTE(review): the rose branch relies on ``PentachoronStabilizer``,
        which is not defined or imported in this module.
        """
        if self.similarity_mode == 'rose':
            batch = features.shape[0]
            all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
            # Flatten (batch, class) pairs: row b * num_classes + c pairs
            # feature b with the vertices of class c.
            features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
            return PentachoronStabilizer.rose_score_magnitude(
                features_exp.reshape(-1, self.pentachora_dim),
                all_vertices.repeat(batch, 1, 1)
            ).reshape(batch, -1)

        centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
        features_norm = F.normalize(features, dim=-1)
        return torch.matmul(features_norm, centroids.T)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the transformer backbone and return the normalized CLS token, [batch, embed_dim]."""
        B = x.shape[0]

        # Patch embedding: conv output flattened to a token sequence
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Add CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final norm
        x = self.norm(x)

        # Return CLS token
        return x[:, 0]

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass with optional theta head.

        Args:
            x: Batch of images accepted by the patch-embedding convolution.
            return_features: Also return the un-projected CLS features.

        Returns:
            Dict with ``'logits'`` and ``'similarities'``; additionally
            ``'theta_features'`` when the theta head is active and
            ``'features'`` when ``return_features`` is true.
        """
        features = self.forward_features(x)
        output = {}

        # Project to pentachora dimension
        features_proj = self.to_pentachora_dim(features)

        if self.use_theta_head:
            # Use theta-based classification
            theta_output = self.theta_head(features_proj, self.class_pentachora)
            output['logits'] = theta_output['logits']
            output['theta_features'] = theta_output['theta_features']

            # Still compute similarities for analysis; they do not take part
            # in the loss, so no gradient is tracked.
            with torch.no_grad():
                similarities = self.compute_pentachora_similarities(features_proj)
            output['similarities'] = similarities
        else:
            # Original: Use similarity-based classification
            similarities = self.compute_pentachora_similarities(features_proj)
            logits = similarities * self.temperature.exp()
            output['logits'] = logits
            output['similarities'] = similarities

        if return_features:
            output['features'] = features

        return output
# Helper function to convert existing model to theta
def enable_theta_head(model: "BaselineViT", n_pentachora: int = 10, hidden_dim: int = 256):
    """
    Convert an existing similarity-based model to use theta head.
    This modifies the model in-place.

    Args:
        model: Model to convert. Returned unchanged if it already uses the
            theta head.
        n_pentachora: Number of pentachora the new head draws theta features
            from.
        hidden_dim: Hidden width of the new head's MLP.

    Returns:
        The same ``model`` object.

    The new head's parameters are freshly created, so any optimizer built
    before this call does not know about them.
    """
    if model.use_theta_head:
        print("Model already using theta head")
        return model

    print(f"Converting to theta head with {n_pentachora} pentachora...")

    # Create theta head
    theta_head = ThetaHead(
        embed_dim=model.pentachora_dim,
        num_classes=model.num_classes,
        n_pentachora=n_pentachora,
        hidden_dim=hidden_dim,
        dropout=0.1
    )

    # Initialize new parameters before the head is attached, so a failure
    # here cannot leave the model half-converted.
    for m in theta_head.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # Match both device and dtype of the existing weights; matching only the
    # device breaks models that were cast to half/double precision.
    reference = next(model.parameters())
    model.theta_head = theta_head.to(device=reference.device, dtype=reference.dtype)

    # Set flag
    model.use_theta_head = True

    print("✓ Theta head enabled")
    return model
if __name__ == "__main__":
    # Running the module directly only prints short usage hints.
    for hint in (
        "BaselineViT with optional theta head",
        "Use 'use_theta_head=True' to enable theta classification",
        "Or call enable_theta_head() on existing model",
    ):
        print(hint)