|
|
""" |
|
|
Baseline Vision Transformer with frozen pentachora class embeddings.

Optionally classifies with a theta (angular rotation) head instead of
pentachora similarity matching.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from einops import rearrange |
|
|
import math |
|
|
from typing import Optional, Tuple, Dict, Any |
|
|
|
|
|
|
|
|
class PentachoraEmbedding(nn.Module):
    """
    A single frozen pentachoron embedding (a set of vertices in a D-dim space).

    Every tensor is registered as a buffer, so it follows the module through
    ``.to()`` / ``state_dict()`` but is never trained. Besides similarity
    scoring, the class exposes a "theta" projection: features are projected onto
    an orthonormal basis of the vertices' row space and encoded as sin/cos of
    component-wise angles relative to the centroid's projection.

    Args:
        vertices: Vertex tensor, presumably ``[5, D]`` (NOTE(review): the
            centroid is the mean over dim 0 and the theta basis keeps at most
            5 components, so 5 rows is the expected layout - confirm).
    """

    def __init__(self, vertices: torch.Tensor):
        super().__init__()

        self.embed_dim = vertices.shape[-1]

        # Private, detached copy placed on the project-wide default device
        # (`get_default_device` is defined elsewhere in the project).
        self.register_buffer(
            'vertices',
            vertices.cpu().contiguous().detach().clone().to(get_default_device())
        )
        self.vertices.requires_grad = False

        with torch.no_grad():
            # Cached derived quantities; all frozen.
            self.register_buffer('vertices_norm', F.normalize(self.vertices, dim=-1))
            self.register_buffer('centroid', self.vertices.mean(dim=0))
            self.register_buffer('centroid_norm', F.normalize(self.centroid, dim=-1))

            # Computed from the on-device vertices, so it already lives on the
            # same device; contiguous() materialises the transposed slice.
            self.register_buffer('theta_bases', self._compute_theta_bases().detach().contiguous())

    def _compute_theta_bases(self) -> torch.Tensor:
        """Return an orthonormal basis (as columns) spanning the vertices' rows.

        Uses the right singular vectors of ``vertices``, keeping at most 5.
        For ``[N, D]`` vertices the result is ``[D, min(5, D, N)]``.
        """
        # torch.svd is deprecated; linalg.svd returns V^H, so take the adjoint.
        _, _, vh = torch.linalg.svd(self.vertices, full_matrices=False)
        n_components = min(5, self.embed_dim)
        return vh.mH[:, :n_components]

    def get_vertices(self) -> torch.Tensor:
        """Return the frozen vertex buffer."""
        return self.vertices

    def get_centroid(self) -> torch.Tensor:
        """Return the (unnormalised) mean of the vertices."""
        return self.centroid

    def compute_rose_score(self, features: torch.Tensor) -> torch.Tensor:
        """Score features against this pentachoron's vertices.

        Delegates to ``PentachoronStabilizer.rose_score_magnitude`` (defined
        elsewhere), broadcasting the vertices across the batch. A 1-D
        ``features`` tensor is treated as a batch of one.
        """
        verts = self.vertices.unsqueeze(0)
        if features.dim() == 1:
            features = features.unsqueeze(0)

        B = features.shape[0]
        if B > 1:
            verts = verts.expand(B, -1, -1)

        return PentachoronStabilizer.rose_score_magnitude(features, verts)

    def compute_similarity(self, features: torch.Tensor, mode: str = 'centroid') -> torch.Tensor:
        """Similarity between features and this pentachoron.

        Modes:
            'rose': ``compute_rose_score(features)``.
            'centroid': cosine similarity to the centroid.
            anything else: maximum cosine similarity over the vertices.
        """
        if mode == 'rose':
            return self.compute_rose_score(features)

        features_norm = F.normalize(features, dim=-1)

        if mode == 'centroid':
            return torch.matmul(features_norm, self.centroid_norm)
        else:
            sims = torch.matmul(features_norm, self.vertices_norm.T)
            return sims.max(dim=-1)[0]

    def compute_theta_features(self, features: torch.Tensor) -> torch.Tensor:
        """
        Project features to the theta space defined by this pentachoron.

        Args:
            features: ``[batch, D]`` features on the same device as this module.

        Returns:
            ``[batch, 2 * K]`` tensor ``[sin(angles), cos(angles)]`` with
            ``K = theta_bases.shape[1]`` (10 values when ``D >= 5``), for
            feed-forward classification. The result stays on the device of the
            inputs instead of being forced onto the default device, so a module
            moved with ``.to(...)`` keeps producing features its consumers can use.
        """
        projections = torch.matmul(features, self.theta_bases)

        centroid_proj = torch.matmul(self.centroid.unsqueeze(0), self.theta_bases)
        # Component-wise angle of the feature projection against the centroid's
        # projection (epsilon kept from the original formulation).
        angles = torch.atan2(projections, centroid_proj + 1e-8)

        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
|
|
|
|
|
|
|
class ThetaHead(nn.Module):
    """
    Theta-based classification head using angular representations.

    Replaces similarity matching with a learned feed-forward classifier. The
    theta features of the first ``n_pentachora`` pentachora are concatenated,
    zero-padded when fewer are available, and combined with a small learned
    projection of the input features before classification.
    """

    # Expected theta width contributed by each pentachoron (sin and cos of
    # 5 angle components).
    _THETA_WIDTH = 10

    def __init__(
        self,
        embed_dim: int,
        num_classes: int,
        n_pentachora: int = 10,
        hidden_dim: int = 256,
        dropout: float = 0.1
    ):
        super().__init__()

        self.n_pentachora = n_pentachora
        self.embed_dim = embed_dim

        theta_dim = n_pentachora * self._THETA_WIDTH

        # Learned feature -> theta-space projection, added as a small residual.
        self.to_theta = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, theta_dim)
        )

        self.classifier = nn.Sequential(
            nn.LayerNorm(theta_dim),
            nn.Dropout(dropout),
            nn.Linear(theta_dim, num_classes)
        )

        # Stored in log space; logits are divided by exp(temperature).
        self.temperature = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, features: torch.Tensor, pentachora_list: nn.ModuleList) -> Dict[str, torch.Tensor]:
        """
        Classify using theta rotation.

        Args:
            features: [batch, embed_dim] CLS features.
            pentachora_list: Modules exposing ``compute_theta_features``. Only
                the first ``n_pentachora`` are used. A shorter (or empty) list
                is allowed: missing slots are zero-filled.

        Returns:
            Dict with ``'logits'`` ([batch, num_classes]) and
            ``'theta_features'`` (the combined classifier input,
            [batch, n_pentachora * 10]).
        """
        n_used = min(self.n_pentachora, len(pentachora_list))
        theta_parts = [
            pentachora_list[i].compute_theta_features(features)
            for i in range(n_used)
        ]

        # Zero-fill slots for missing pentachora. Appending the padding before
        # concatenating also covers an empty list (torch.cat([]) would raise).
        # new_zeros keeps the features' dtype/device, so half-precision models
        # are not silently promoted to float32.
        pad_size = (self.n_pentachora - n_used) * self._THETA_WIDTH
        if pad_size > 0:
            theta_parts.append(features.new_zeros(features.shape[0], pad_size))
        theta_concat = torch.cat(theta_parts, dim=-1)

        theta_proj = self.to_theta(features)

        # Fixed geometric features plus a small learned correction.
        theta_combined = theta_concat + 0.1 * theta_proj

        # exp() keeps the divisor positive.
        logits = self.classifier(theta_combined) / self.temperature.exp()

        return {
            'logits': logits,
            'theta_features': theta_combined
        }
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Standard pre-norm transformer block: multi-head self-attention and an MLP.

    Computes ``x + attn(LN(x))`` followed by ``x + mlp(LN(x))``. Inputs are
    ``[batch, seq, dim]`` because the attention layer uses ``batch_first=True``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim,
            num_heads,
            dropout=attn_dropout,
            batch_first=True
        )

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.norm1(x)
        # The attention weights are discarded, so don't ask for them:
        # need_weights=False skips computing/averaging the weight matrix and
        # lets PyTorch use its fused attention path when eligible.
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
|
|
|
|
|
|
|
|
class BaselineViT(nn.Module):
    """
    Vision Transformer with optional theta-based classification.

    Each class is represented by one frozen pentachoron. Two heads are
    available:

    * theta head (``use_theta_head=True``): logits come from a learned
      ``ThetaHead``. Pentachora similarities are still computed (without
      gradients) and returned under ``'similarities'`` for monitoring.
    * similarity head (``use_theta_head=False``): logits are the pentachora
      similarities scaled by ``exp(self.temperature)``.

    Args:
        pentachora_list: One vertex tensor per class; the number of classes is
            ``len(pentachora_list)``. The vertex dimension must match
            ``vocab_dim``, because CLS features are projected to ``vocab_dim``
            before being compared with the pentachora.
        vocab_dim: Dimension of the pentachora space.
        img_size, patch_size: Square input size and patch size (3-channel input).
        embed_dim, depth, num_heads, mlp_ratio: Transformer configuration.
        dropout, attn_dropout: Dropout rates.
        similarity_mode: ``'rose'`` for rose-score similarities; any other value
            uses cosine similarity to the class centroids.
        use_theta_head: Select the theta head (see above).
        theta_n_pentachora, theta_hidden_dim: ThetaHead configuration.

    Raises:
        TypeError: ``pentachora_list`` is not a list.
        ValueError: ``pentachora_list`` is empty.
    """

    def __init__(
        self,
        pentachora_list: list,
        vocab_dim: int = 256,
        img_size: int = 32,
        patch_size: int = 4,
        embed_dim: int = 512,
        depth: int = 12,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        similarity_mode: str = 'rose',
        use_theta_head: bool = True,
        theta_n_pentachora: int = 2,
        theta_hidden_dim: int = 256
    ):
        super().__init__()

        # Explicit validation: `assert` statements are stripped under `python -O`.
        if not isinstance(pentachora_list, list):
            raise TypeError(
                f"pentachora_list must be a list of vertex tensors, "
                f"got {type(pentachora_list).__name__}"
            )
        if not pentachora_list:
            raise ValueError("pentachora_list must contain at least one class pentachoron")

        self.num_classes = len(pentachora_list)
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2
        self.similarity_mode = similarity_mode
        self.pentachora_dim = vocab_dim
        self.use_theta_head = use_theta_head

        # One frozen pentachoron per class.
        self.class_pentachora = nn.ModuleList([
            PentachoraEmbedding(vertices=penta)
            for penta in pentachora_list
        ])

        # Non-overlapping patch_size x patch_size patches -> embed_dim tokens.
        self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

        # Learned CLS token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Learned positional embeddings for CLS + patch tokens.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=attn_dropout
            )
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Map CLS features into the pentachora space when the widths differ.
        if self.pentachora_dim != embed_dim:
            self.to_pentachora_dim = nn.Linear(embed_dim, self.pentachora_dim)
        else:
            self.to_pentachora_dim = nn.Identity()

        if use_theta_head:
            # Learned angular classification head.
            self.theta_head = ThetaHead(
                embed_dim=self.pentachora_dim,
                num_classes=self.num_classes,
                n_pentachora=theta_n_pentachora,
                hidden_dim=theta_hidden_dim,
                dropout=dropout
            )
        else:
            # Similarity-based head: learned logit scale in log space, init 1/0.07.
            self.theta_head = None
            self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))

        # Stacked class centroids (raw and L2-normalised).
        self.register_buffer(
            'all_centroids',
            torch.stack([penta.centroid for penta in self.class_pentachora])
        )
        self.register_buffer(
            'all_centroids_norm',
            F.normalize(self.all_centroids, dim=-1)
        )

        self.init_weights()

    def init_weights(self):
        """Initialise the embeddings, Linear layers and LayerNorms.

        The CLS token, positional embeddings and Linear weights use a truncated
        normal (std 0.02), with zero Linear biases. LayerNorms get weight 1 and
        bias 0. Other modules (e.g. the patch Conv2d) keep PyTorch's defaults.
        """
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def get_class_centroids(self) -> torch.Tensor:
        """Return the L2-normalised centroid of every class, [num_classes, vocab_dim]."""
        if self.use_theta_head:
            # Restacked from the pentachora modules.
            centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
            return centroids
        else:
            return self.all_centroids_norm

    def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
        """
        Score projected features against every class pentachoron.

        Args:
            features: [batch, vocab_dim] features (output of ``to_pentachora_dim``).

        Returns:
            [batch, num_classes] scores. In ``'rose'`` mode these come from
            ``PentachoronStabilizer.rose_score_magnitude``. Otherwise they are
            the cosine similarities to each class centroid.
        """
        if self.similarity_mode == 'rose':
            all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora])
            # Pair every sample with every class. Both sides are flattened in
            # batch-major order (row b * num_classes + c) so the rows line up.
            features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1)
            return PentachoronStabilizer.rose_score_magnitude(
                features_exp.reshape(-1, self.pentachora_dim),
                all_vertices.repeat(features.shape[0], 1, 1)
            ).reshape(features.shape[0], -1)
        else:
            centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora])
            features_norm = F.normalize(features, dim=-1)
            return torch.matmul(features_norm, centroids.T)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the transformer trunk and return the normalised CLS token.

        Args:
            x: [batch, 3, img_size, img_size] images.

        Returns:
            [batch, embed_dim] CLS features.
        """
        B = x.shape[0]

        # [B, C, H', W'] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        # Prepend the CLS token.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        return x[:, 0]

    def forward(self, x: torch.Tensor, return_features: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass with optional theta head.

        Returns a dict with:
            'logits': [batch, num_classes] classification logits.
            'similarities': [batch, num_classes] pentachora similarities (no
                gradients when the theta head is active).
            'theta_features': combined theta features (theta head only).
            'features': CLS features before projection (only if ``return_features``).
        """
        features = self.forward_features(x)
        output = {}

        features_proj = self.to_pentachora_dim(features)

        if self.use_theta_head:
            theta_output = self.theta_head(features_proj, self.class_pentachora)
            output['logits'] = theta_output['logits']
            output['theta_features'] = theta_output['theta_features']

            # Similarities are reported for monitoring only; no gradient flows.
            with torch.no_grad():
                similarities = self.compute_pentachora_similarities(features_proj)
            output['similarities'] = similarities
        else:
            similarities = self.compute_pentachora_similarities(features_proj)
            logits = similarities * self.temperature.exp()

            output['logits'] = logits
            output['similarities'] = similarities

        if return_features:
            output['features'] = features

        return output
|
|
|
|
|
|
|
|
|
|
|
def enable_theta_head(
    model: BaselineViT,
    n_pentachora: int = 10,
    hidden_dim: int = 256,
    dropout: float = 0.1,
):
    """
    Convert an existing similarity-based model to use theta head.

    This modifies the model in-place. The new head is placed on the model's
    device and follows the model's train/eval mode. After initialisation it is
    cast to the model's floating dtype.

    Args:
        model: Model to convert. Returned unchanged if it already uses a theta head.
        n_pentachora: Number of class pentachora the head reads theta features from.
        hidden_dim: Hidden width of the head's feature projection.
        dropout: Dropout rate inside the new head.

    Returns:
        The same ``model`` object.
    """
    if model.use_theta_head:
        print("Model already using theta head")
        return model

    print(f"Converting to theta head with {n_pentachora} pentachora...")

    ref_param = next(model.parameters())

    model.theta_head = ThetaHead(
        embed_dim=model.pentachora_dim,
        num_classes=model.num_classes,
        n_pentachora=n_pentachora,
        hidden_dim=hidden_dim,
        dropout=dropout
    ).to(ref_param.device)

    model.use_theta_head = True

    # Same Linear initialisation as BaselineViT.init_weights.
    for m in model.theta_head.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    # Cast after initialisation, so the init runs in float32 exactly as before.
    # A half/bfloat16 model then receives a head with matching weights.
    if ref_param.dtype.is_floating_point:
        model.theta_head.to(dtype=ref_param.dtype)

    # New modules start in training mode. Follow the host model so an eval-mode
    # model does not silently get active dropout in its new head.
    model.theta_head.train(model.training)

    print("✓ Theta head enabled")
    return model
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Print a short usage banner when the module is run as a script.
    _usage_lines = (
        "BaselineViT with optional theta head",
        "Use 'use_theta_head=True' to enable theta classification",
        "Or call enable_theta_head() on existing model",
    )
    for _line in _usage_lines:
        print(_line)