| | """ |
| | PentachoraViT: Vision Transformer with Pentachoron Geometric Structure |
| | Enhanced with Geometric Attention for improved head cohesion and generalization |
| | FIXED: CLS tokens now properly reference and utilize vocabulary embeddings |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | from einops import rearrange, repeat |
| | import math |
| | from typing import Optional, Dict, Tuple, List, Any |
| | from dataclasses import dataclass |
| | import warnings |
| |
|
| | |
| | |
| | |
| |
|
@dataclass
class PentachoraConfig:
    """Hyper-parameters for a PentachoraViT model.

    Field order is part of the interface: the generated ``__init__`` accepts
    these fields positionally in the order declared below.
    """

    # Input / output geometry.
    img_size: int = 32  # side length of the square input image, in pixels
    patch_size: int = 4  # side length of each square patch, in pixels
    num_classes: int = 100

    # Transformer width / depth.
    dim: int = 512
    vocab_dim: Optional[int] = None  # None -> PentachoraViT falls back to ``dim``
    depth: int = 12
    heads: int = 8
    mlp_ratio: float = 4.0

    # Geometric attention is used in blocks with index < this layer.
    use_mesh_attention: bool = True
    preserve_structure_until_layer: int = 6

    # Regularisation.
    dropout_rate: float = 0.1
    drop_path_rate: float = 0.1

    # Loss weights.
    aux_loss_weight: float = 0.3
    geo_loss_weight: float = 0.1

    # Optional geometric vocabulary object used to seed the class pentachora.
    vocab: Optional[Any] = None

    @property
    def num_patches(self) -> int:
        """Number of patch tokens per image (square grid of whole patches)."""
        patches_per_side = self.img_size // self.patch_size
        return patches_per_side * patches_per_side
| |
|
| | |
| | |
| | |
| |
|
def perfect_4simplex(device):
    """Return the five vertices of a regular 4-simplex (pentachoron) in R^4.

    The first four vertices are the even-parity sign patterns of (±1, ±1, ±1)
    lifted to w = -1/sqrt(5); the fifth is the apex (0, 0, 0, 4/sqrt(5)). The
    vertex set is centred at the origin, and after the final halving every edge
    has length sqrt(2).

    Args:
        device: Torch device on which to allocate the tensor.

    Returns:
        float32 tensor of shape [5, 4].
    """
    root5 = math.sqrt(5)
    base_w = -1 / root5
    apex_w = 4 / root5
    sign_patterns = ((1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1))
    rows = [[sx, sy, sz, base_w] for sx, sy, sz in sign_patterns]
    rows.append([0, 0, 0, apex_w])
    return torch.tensor(rows, device=device, dtype=torch.float32) / 2
| |
|
def softmin_over_last(distances, tau):
    """Soft minimum of ``distances`` over the last dimension.

    Computes ``sum_i softmax(-d / tau)_i * d_i``, a differentiable
    approximation of ``min(d)``. It approaches the true minimum as
    ``tau -> 0`` and the arithmetic mean as ``tau -> inf``.

    The previous implementation returned ``softmax(-d / tau).sum(-1)``. That
    is identically 1 because softmax weights sum to one, so every caller
    received a constant and the distances had no effect.

    Args:
        distances: Tensor whose last dimension is reduced.
        tau: Positive softmin temperature.

    Returns:
        Tensor with the last dimension removed.
    """
    weights = F.softmax(-distances / tau, dim=-1)
    return (weights * distances).sum(dim=-1)
| |
|
@dataclass
class GeometricConfig:
    """Hyper-parameters consumed by ``GeometricNavigator``.

    Field order is part of the interface (positional ``__init__`` arguments).
    """

    # Temperature handed to softmin_over_last when scoring vertex distances.
    softmin_tau: float = 0.05
    # Weight of the phase-averaged "ribbon" score; (1 - fuse_alpha) goes to
    # the dispatcher score.
    fuse_alpha: float = 0.7
    # Angles used to mix the D and S vertex sets: cos(p) * D + sin(p) * S.
    phases: Tuple[float, ...] = tuple(k * math.pi / 2 for k in range(4))
    # Std of Gaussian noise added to the initial vertex positions.
    jitter: float = 0.02
    # Constant offset added to the initial S vertices.
    shift: float = 0.25
    # Period (in regions) of the initial rotation-angle schedule for S.
    rotate_cycle: int = 11
    # NOTE(review): the next two fields are not read by any visible code.
    use_phase_variance: bool = False
    geometry_type: str = "pentachoron"
| |
|
class GeometricNavigator(nn.Module):
    """Maps inputs to geometric regions in 4D space.

    Each of ``num_regions`` regions owns two learnable [5, 4] vertex sets:
    ``D``, a jittered copy of the regular 4-simplex, and ``S``, a rotated,
    shifted and jittered copy. ``navigate`` projects inputs to 4D and scores
    them against every region.

    ``D`` and ``S`` are created eagerly in ``__init__`` so they are registered
    parameters from construction time. Optimizers built before the first
    forward pass, ``count_parameters`` and ``load_state_dict`` all see them.
    Previously they were created lazily inside the first ``navigate`` call,
    which is usually after the optimizer was built, so they were never
    trained.
    """

    def __init__(self, input_dim: int, num_regions: int, config: GeometricConfig):
        super().__init__()
        self.input_dim = input_dim
        self.num_regions = num_regions
        self.config = config

        # Linear projection of the input into the 4D navigation space.
        self.to_nav = nn.Linear(input_dim, 4, bias=False)
        # Per-region, per-vertex logits that pull vertices toward their mean.
        self.vertex_w = nn.Parameter(torch.zeros(num_regions, 5))

        self.register_parameter('D', None)
        self.register_parameter('S', None)
        # Build on CPU; ``module.to(device)`` later moves them with the model.
        self._lazy_init_geometry(torch.device('cpu'))

    def _lazy_init_geometry(self, device):
        """Create ``D`` and ``S`` if they do not exist yet (no-op once built).

        Kept under its historical name for compatibility; ``__init__`` now
        calls it, so under normal use it returns immediately from ``navigate``.
        """
        if self.D is not None:
            return

        base = perfect_4simplex(device)

        D = torch.zeros(self.num_regions, 5, 4, device=device)
        S = torch.zeros(self.num_regions, 5, 4, device=device)

        for r in range(self.num_regions):
            D[r] = base + self.config.jitter * torch.randn_like(base)

            # Rotate the simplex in the (x0, x1) plane by a region-dependent
            # angle that cycles every ``rotate_cycle`` regions.
            theta = 0.27 + 0.05 * (r % self.config.rotate_cycle)
            c, s_val = math.cos(theta), math.sin(theta)
            rot = torch.eye(4, device=device)
            rot[0, 0] = c
            rot[0, 1] = -s_val
            rot[1, 0] = s_val
            rot[1, 1] = c
            S[r] = (base @ rot) + self.config.shift
            S[r] += self.config.jitter * torch.randn_like(S[r])

        self.D = nn.Parameter(D)
        self.S = nn.Parameter(S)

    def navigate(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Score each input row against every region.

        Args:
            x: 2-D tensor [N, input_dim]. Rank 2 is required by the
               ``nav_x[:, None, None, :]`` indexing below.

        Returns:
            ``{'scores': [N, num_regions], 'diagnostics': {...}}``. Scores are
            negative soft-min vertex distances, so higher means closer.
        """
        self._lazy_init_geometry(x.device)

        nav_x = self.to_nav(x)                 # [N, 4]
        nav_x_exp = nav_x[:, None, None, :]    # [N, 1, 1, 4]
        D_exp = self.D[None, :, :, :]          # [1, R, 5, 4]

        # Dispatcher score: soft-min distance to the D vertices of each region.
        d_disp = torch.norm(nav_x_exp - D_exp, dim=-1)                # [N, R, 5]
        s_disp = -softmin_over_last(d_disp, self.config.softmin_tau)  # [N, R]

        w = F.softmax(self.vertex_w, dim=1)    # [R, 5]
        w_expanded = w.unsqueeze(-1)           # [R, 5, 1]
        phase_scores = []

        # Ribbon score: distances to D/S blends at each phase angle, with
        # vertices partially contracted toward their region mean by ``w``.
        for phase in self.config.phases:
            Vt = math.cos(phase) * self.D + math.sin(phase) * self.S
            Vt_mean = Vt.mean(dim=1, keepdim=True)
            Vt = (1.0 - w_expanded) * Vt + w_expanded * Vt_mean

            d_ribbon = torch.norm(nav_x_exp - Vt[None, :, :, :], dim=-1)
            phase_scores.append(-softmin_over_last(d_ribbon, self.config.softmin_tau))

        s_ribbon = torch.stack(phase_scores).mean(dim=0)
        scores = self.config.fuse_alpha * s_ribbon + (1 - self.config.fuse_alpha) * s_disp

        diagnostics = {
            'dispatcher_scores': s_disp.detach(),
            'ribbon_scores': s_ribbon.detach()
        }

        return {'scores': scores, 'diagnostics': diagnostics}
| |
|
class GeometricAttention(nn.Module):
    """Multi-head geometric attention with Q-K alignment.

    Each head has its own query and key ``GeometricNavigator``. The attention
    logits are the scaled dot product of the navigators' per-region score
    vectors rather than of the raw q/k. Values are mixed with the resulting
    softmax weights.
    """

    def __init__(self, dim: int, num_heads: int = 8, num_regions: Optional[int] = None,
                 config: Optional[GeometricConfig] = None, dropout: float = 0.0):
        super().__init__()
        # forward() reshapes [B, T, dim] into [B, T, num_heads, dim // num_heads]
        # and back. Reject an indivisible width here instead of failing later
        # with an opaque reshape error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of num_heads ({num_heads})"
            )

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        if num_regions is None:
            num_regions = min(self.head_dim, 16)
        if config is None:
            config = GeometricConfig()

        self.config = config
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.q_navigators = nn.ModuleList([
            GeometricNavigator(self.head_dim, num_regions, config)
            for _ in range(num_heads)
        ])
        self.k_navigators = nn.ModuleList([
            GeometricNavigator(self.head_dim, num_regions, config)
            for _ in range(num_heads)
        ])

        self.out_proj = nn.Linear(dim, dim)
        # Shared between the attention weights and the projected output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None,
                return_diagnostics: bool = False) -> Tuple[torch.Tensor, Optional[Dict]]:
        """Apply geometric attention.

        Args:
            x: [B, T, dim] token sequence.
            mask: Optional mask where 0 marks disallowed positions. It is
                unsqueezed at dim 1 before broadcasting against the [B, T, T]
                logits, so presumably shape [B, T] (key mask); confirm with
                callers.
            return_diagnostics: Also return per-head navigator diagnostics.

        Returns:
            ``(output [B, T, dim], diagnostics or None)``.
        """
        B, T, D = x.shape

        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        outputs = []
        all_diagnostics = [] if return_diagnostics else None

        for h in range(self.num_heads):
            # Navigators take 2-D input, so flatten batch and time.
            q_h_flat = q[:, h].reshape(B * T, self.head_dim)
            k_h_flat = k[:, h].reshape(B * T, self.head_dim)

            q_nav = self.q_navigators[h].navigate(q_h_flat)
            k_nav = self.k_navigators[h].navigate(k_h_flat)

            q_scores = q_nav['scores'].reshape(B, T, -1)
            k_scores = k_nav['scores'].reshape(B, T, -1)

            attn = torch.bmm(q_scores, k_scores.transpose(1, 2))
            attn = attn / math.sqrt(q_scores.size(-1))

            if mask is not None:
                attn = attn.masked_fill(mask.unsqueeze(1) == 0, -1e9)

            attn = F.softmax(attn, dim=-1)
            attn = self.dropout(attn)

            out = torch.bmm(attn, v[:, h])
            outputs.append(out)

            if return_diagnostics:
                all_diagnostics.append({'q': q_nav['diagnostics'], 'k': k_nav['diagnostics']})

        # [B, H, T, head_dim] -> [B, T, H, head_dim] -> [B, T, dim]
        output = torch.stack(outputs, dim=1).transpose(1, 2).reshape(B, T, D)
        output = self.out_proj(output)
        output = self.dropout(output)

        if return_diagnostics:
            return output, {'head_diagnostics': all_diagnostics}
        return output, None
| |
|
| | |
| | |
| | |
| |
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample.

    In training mode, each sample along dim 0 has its whole input zeroed with
    probability ``drop_prob``. Surviving samples are scaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob == 0``, the input is returned untouched.
    """

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.:
            # Every path is dropped. The rescale below would divide by zero and
            # produce NaN (0 * inf), so return exact zeros instead.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
| |
|
| | |
| | |
| | |
| |
|
class HierarchicalPentachoronCLS(nn.Module):
    """
    Hierarchical CLS tokens derived from per-class pentachora.

    Every class owns five vocabulary-space vectors in ``class_pentachora``
    ([num_classes, 5, vocab_dim]). ``forward`` maps five such vectors into
    model space as vertex CLS tokens and pools them, with learned softmax
    weights, into one global CLS token. CLS tokens therefore reference the
    vocabulary embeddings directly.
    """

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100):
        super().__init__()
        self.dim, self.vocab_dim, self.num_classes = dim, vocab_dim, num_classes

        # Learnable per-class pentachora in vocabulary space (small random init).
        self.class_pentachora = nn.Parameter(torch.randn(num_classes, 5, vocab_dim) * 0.02)

        # Bridge vocabulary space to model space only when the widths differ.
        self.vocab_to_model = nn.Identity() if vocab_dim == dim else nn.Linear(vocab_dim, dim)

        # Pooling logits over the five vertices (uniform after softmax at init).
        self.vertex_weights = nn.Parameter(torch.ones(5) / 5)

        # Learned bias added to the pooled global token.
        self.global_offset = nn.Parameter(torch.zeros(1, 1, dim))

        self.vertex_norm = nn.LayerNorm(dim)
        self.global_norm = nn.LayerNorm(dim)

    def forward(self, batch_size: int, class_indices: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Build the CLS tokens for a batch.

        Args:
            batch_size: Number of samples B.
            class_indices: Optional [B] class ids. They are used only when their
                length equals ``batch_size``; otherwise the mean pentachoron
                over all classes is used for every sample.

        Returns:
            global_cls: [B, 1, D] pooled global CLS tokens.
            vertex_cls: [B, 5, D] vertex CLS tokens.
        """
        use_classes = class_indices is not None and class_indices.shape[0] == batch_size
        if use_classes:
            vocab_vertices = self.class_pentachora[class_indices]
        else:
            shared = self.class_pentachora.mean(dim=0, keepdim=True)
            vocab_vertices = shared.expand(batch_size, -1, -1)

        vertex_cls = self.vertex_norm(self.vocab_to_model(vocab_vertices))

        pool = F.softmax(self.vertex_weights, dim=0)
        pooled = torch.einsum('bvd,v->bd', vertex_cls, pool).unsqueeze(1)
        global_cls = self.global_norm(pooled + self.global_offset)

        return global_cls, vertex_cls

    def get_class_prototypes(self) -> torch.Tensor:
        """
        Pool each class's pentachoron into one model-space prototype.

        Returns:
            prototypes: [num_classes, dim]. Unlike the CLS tokens, no
            LayerNorm is applied.
        """
        pool = F.softmax(self.vertex_weights, dim=0)
        projected = self.vocab_to_model(self.class_pentachora)
        return torch.einsum('cvd,v->cd', projected, pool)
| |
|
| | |
| | |
| | |
| |
|
class GeometricProjection(nn.Module):
    """
    Project patches onto pentachoron geometry.

    Patches are layer-normed, mapped to vocabulary space and L2-normalised.
    For each of the five vertices, a separate linear map is applied and the
    cosine similarity with that vertex of every class is taken. The five
    similarity maps are divided by a learnable temperature and averaged.
    """

    # Lower bound applied to the learnable temperature in ``forward``. The
    # temperature is a divisor, so an unconstrained value drifting to 0 (or
    # below) would give inf/NaN (or sign-flipped) logits. Averaged cosine
    # similarities lie in [-1, 1], so 1e-2 caps the logit magnitude at 100.
    min_temperature = 1e-2

    def __init__(self, dim: int, vocab_dim: int, num_classes: int = 100, dropout: float = 0.1):
        super().__init__()
        self.dim = dim
        self.vocab_dim = vocab_dim
        self.num_classes = num_classes

        # Model space -> vocabulary space.
        self.to_vocab_space = nn.Linear(dim, vocab_dim)

        # One vocabulary-space map per pentachoron vertex.
        self.vertex_projections = nn.ModuleList([
            nn.Linear(vocab_dim, vocab_dim, bias=False) for _ in range(5)
        ])

        self.temperature = nn.Parameter(torch.ones(1))

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, patches: torch.Tensor, pentachora: torch.Tensor) -> torch.Tensor:
        """
        Compute alignment between patches and class pentachora.

        Args:
            patches: [B, N, D] patch embeddings in model dimension.
            pentachora: [C, 5, vocab_dim] class pentachora in vocabulary dimension.

        Returns:
            [B, N, C] alignment scores.
        """
        patches = self.norm(patches)
        patches_vocab = F.normalize(self.to_vocab_space(patches), dim=-1)

        temperature = self.temperature.clamp(min=self.min_temperature)

        alignments = []
        for v in range(5):
            patches_v = F.normalize(self.vertex_projections[v](patches_vocab), dim=-1)
            vertex_v = F.normalize(pentachora[:, v, :], dim=-1)
            # Cosine similarity of every patch with vertex v of every class.
            alignments.append(torch.matmul(patches_v, vertex_v.T) / temperature)

        alignments = torch.stack(alignments, dim=-1).mean(dim=-1)

        return self.dropout(alignments)
| |
|
| | |
| | |
| | |
| |
|
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    they are falsy (None or 0).
    """

    def __init__(self, in_features: int, hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None, dropout: float = 0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
| |
|
| | |
| | |
| | |
| |
|
class PentachoronViTBlock(nn.Module):
    """Pre-norm transformer block.

    The block uses ``GeometricAttention`` when ``use_mesh`` is True and
    standard ``nn.MultiheadAttention`` otherwise. Both the attention and MLP
    residual branches are wrapped in stochastic depth when
    ``drop_path > 0``.
    """

    def __init__(self, dim: int, heads: int = 8, mlp_ratio: float = 4.0,
                 use_mesh: bool = True, dropout: float = 0., attn_dropout: float = 0.,
                 drop_path: float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        if not use_mesh:
            # Plain softmax attention (batch-first [B, T, dim] layout).
            self.attn = nn.MultiheadAttention(dim, heads, dropout=attn_dropout, batch_first=True)
        else:
            self.attn = GeometricAttention(
                dim=dim,
                num_heads=heads,
                num_regions=min(dim // heads, 16),
                config=GeometricConfig(),
                dropout=attn_dropout
            )

        self.use_mesh = use_mesh
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.LayerNorm(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_width, dropout=dropout)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor, preserve_structure: bool = True) -> torch.Tensor:
        """Apply attention and MLP residual branches.

        ``preserve_structure`` is accepted for API compatibility but is not
        used by this block.
        """
        normed = self.norm1(x)
        # GeometricAttention takes one input; nn.MultiheadAttention needs q, k, v.
        attn_args = (normed,) if self.use_mesh else (normed, normed, normed)
        attn_out, _ = self.attn(*attn_args)
        x = x + self.drop_path1(attn_out)
        return x + self.drop_path2(self.mlp(self.norm2(x)))
| |
|
| | |
| | |
| | |
| |
|
class PatchEmbed(nn.Module):
    """2D image to patch-token embedding.

    A stride-``patch_size`` convolution patchifies the image and projects each
    patch in one step. The resulting spatial grid is flattened row-major into
    a token sequence and layer-normed.
    """

    def __init__(self, img_size: int = 32, patch_size: int = 4,
                 in_chans: int = 3, embed_dim: int = 512):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        grid_side = img_size // patch_size
        self.num_patches = grid_side * grid_side

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, H', W'] -> [B, C, H'*W'] -> [B, H'*W', C]. This is the same
        # layout as einops 'b c h w -> b (h w) c' (rows outer, columns inner).
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        return self.norm(tokens)
| |
|
| | |
| | |
| | |
| |
|
class PentachoraViT(nn.Module):
    """
    Vision Transformer with pentachoron-based hierarchical CLS tokens
    and geometric vocabulary integration.

    The transformer input is ``[global_cls (1), vertex_cls (5), patches]``.
    Classification logits are the scaled cosine similarity between the final
    global CLS token and class prototypes pooled from
    ``cls_tokens.class_pentachora``. CLS tokens reference the vocabulary
    embeddings.

    Label-leakage note: earlier versions built the CLS tokens from the
    *target* class's pentachoron whenever ``targets`` were passed in training
    mode. The logits compare the global CLS against prototypes derived from
    those same pentachora, so the label was effectively fed in as input, and
    training loss said nothing about evaluation, where no targets exist.
    Class-conditioned CLS tokens are now opt-in via
    ``self.class_conditioned_cls`` (default ``False``).
    """

    def __init__(self, config: Optional[PentachoraConfig] = None, **kwargs):
        super().__init__()

        # An explicit config wins; otherwise build one from keyword overrides.
        cfg = config if config is not None else PentachoraConfig(**kwargs)

        self.config = cfg
        self.num_classes = cfg.num_classes
        self.dim = cfg.dim
        self.depth = cfg.depth
        self.preserve_structure_until_layer = cfg.preserve_structure_until_layer

        # Vocabulary width: config value, then a non-None ``vocab_dim`` kwarg,
        # then the model width.
        if cfg.vocab_dim is not None:
            self.vocab_dim = cfg.vocab_dim
        elif kwargs.get('vocab_dim') is not None:
            self.vocab_dim = kwargs['vocab_dim']
        else:
            self.vocab_dim = cfg.dim

        self.patch_embed = PatchEmbed(
            cfg.img_size, cfg.patch_size, 3, cfg.dim
        )
        num_patches = self.patch_embed.num_patches

        # Positional embedding covers patch tokens only (not the 6 CLS tokens).
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.dim) * 0.02)
        self.pos_drop = nn.Dropout(cfg.dropout_rate)

        self.cls_tokens = HierarchicalPentachoronCLS(cfg.dim, self.vocab_dim, cfg.num_classes)

        self.geometric_proj = GeometricProjection(cfg.dim, self.vocab_dim, cfg.num_classes, cfg.dropout_rate)

        # Linearly increasing stochastic-depth rate across blocks.
        dpr = [x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)]

        self.blocks = nn.ModuleList([
            PentachoronViTBlock(
                dim=cfg.dim,
                heads=cfg.heads,
                mlp_ratio=cfg.mlp_ratio,
                use_mesh=(cfg.use_mesh_attention and i < cfg.preserve_structure_until_layer),
                dropout=cfg.dropout_rate,
                attn_dropout=cfg.dropout_rate,
                drop_path=dpr[i]
            )
            for i in range(cfg.depth)
        ])

        self.norm = nn.LayerNorm(cfg.dim)

        # Logits come from cosine similarity to class prototypes; the linear
        # head exists only when the prototype classifier is disabled.
        self.use_prototype_classifier = True
        if self.use_prototype_classifier:
            self.head = None
        else:
            self.head = nn.Linear(cfg.dim, cfg.num_classes)

        # Auxiliary classifier over the five flattened vertex CLS tokens.
        self.head_aux = nn.Linear(cfg.dim * 5, cfg.num_classes)

        # Opt-in: condition CLS tokens on the ground-truth class during
        # training. Off by default because it leaks the label into the input.
        self.class_conditioned_cls = False

        self.apply(self._init_weights)

        # Vocabulary initialisation must run *after* the generic weight init;
        # otherwise ``_init_weights`` overwrites the layers it builds.
        if cfg.vocab is not None:
            self._init_from_vocab(cfg.vocab)

    def _init_weights(self, m: nn.Module):
        """Initialize model weights (trunc-normal Linear, unit LayerNorm, Kaiming conv)."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _init_from_vocab(self, vocab):
        """Best-effort: initialize class pentachora from a geometric vocabulary.

        Calls ``vocab.encode_batch(class_names, generate=True)`` with the
        CIFAR-100 class names and stacks the result. The stacked array must
        have shape [num_classes, 5, vocab_dim]. On success, the
        vocabulary-dependent layers are rebuilt for the detected
        ``vocab_dim``. On any failure, a message is printed and the random
        initialisation is kept.

        All fallible steps run before any module is modified, so a failure
        never leaves the model half-converted.
        """
        print("Initializing pentachora from vocabulary...")

        if not hasattr(vocab, 'encode_batch'):
            print("Vocabulary provided but encode_batch method not found, using random initialization")
            return

        try:
            class_names = self._get_cifar100_classes()
            pentachora_list = vocab.encode_batch(class_names[:self.num_classes], generate=True)
            pentachora = np.stack(pentachora_list, axis=0)

            actual_vocab_dim = int(pentachora.shape[-1])

            print(f"Encoded pentachora shape: {pentachora.shape}")
            print(f"Detected vocabulary dimension: {actual_vocab_dim}")

            if (pentachora.ndim != 3
                    or pentachora.shape[0] != self.num_classes
                    or pentachora.shape[1] != 5):
                print(f"Invalid shape: expected ({self.num_classes}, 5, ?), got {pentachora.shape}")
                print("Using random initialization")
                return

            class_pentachora = nn.Parameter(torch.tensor(pentachora, dtype=torch.float32))
        except Exception as e:  # third-party vocabulary backend: stay best-effort
            print(f"Error initializing from vocabulary: {e}")
            print("Using random initialization")
            return

        self.vocab_dim = actual_vocab_dim
        self.cls_tokens.vocab_dim = actual_vocab_dim
        self.geometric_proj.vocab_dim = actual_vocab_dim

        self.cls_tokens.class_pentachora = class_pentachora

        if actual_vocab_dim != self.dim:
            self.cls_tokens.vocab_to_model = nn.Linear(actual_vocab_dim, self.dim)
        else:
            self.cls_tokens.vocab_to_model = nn.Identity()

        self.geometric_proj.to_vocab_space = nn.Linear(self.dim, actual_vocab_dim)
        self.geometric_proj.vertex_projections = nn.ModuleList([
            nn.Linear(actual_vocab_dim, actual_vocab_dim, bias=False) for _ in range(5)
        ])

        # Xavier weights; zero biases, matching ``_init_weights``'s bias convention.
        nn.init.xavier_uniform_(self.geometric_proj.to_vocab_space.weight)
        nn.init.zeros_(self.geometric_proj.to_vocab_space.bias)
        for proj in self.geometric_proj.vertex_projections:
            nn.init.xavier_uniform_(proj.weight)
        if isinstance(self.cls_tokens.vocab_to_model, nn.Linear):
            nn.init.xavier_uniform_(self.cls_tokens.vocab_to_model.weight)
            nn.init.zeros_(self.cls_tokens.vocab_to_model.bias)

        print(f"✓ Successfully initialized {self.num_classes} class pentachora from vocabulary")
        print(f" Vocabulary dimension: {actual_vocab_dim}")
        print(f" Model internal dimension: {self.dim}")
        print(f" CLS tokens now reference vocabulary embeddings")

    def _get_cifar100_classes(self):
        """Get CIFAR-100 fine class names (alphabetical)."""
        return [
            'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
            'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
            'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
            'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
            'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
            'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
            'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse',
            'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear',
            'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine',
            'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose',
            'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake',
            'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
            'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout',
            'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
        ]

    def forward_features(self, x: torch.Tensor, class_indices: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Extract features from input.

        Args:
            x: Input images [B, 3, H, W].
            class_indices: Optional class indices for class-aware CLS tokens [B].

        Returns:
            Dict with 'global_cls' [B, D], 'vertex_cls' [B, 5, D] and
            'patches' [B, N, D].
        """
        B = x.shape[0]

        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        global_cls, vertex_cls = self.cls_tokens(B, class_indices)

        # Token layout: index 0 = global CLS, 1..5 = vertex CLS, 6.. = patches.
        x = torch.cat([global_cls, vertex_cls, x], dim=1)

        for i, block in enumerate(self.blocks):
            preserve = i < self.preserve_structure_until_layer
            x = block(x, preserve_structure=preserve)

        x = self.norm(x)

        return {
            'global_cls': x[:, 0],
            'vertex_cls': x[:, 1:6],
            'patches': x[:, 6:]
        }

    def forward(self, x: torch.Tensor, targets: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            x: Input images [B, 3, H, W].
            targets: Optional labels [B]. They only affect the CLS tokens when
                ``self.class_conditioned_cls`` is True and the model is in
                training mode (see the class docstring on label leakage).
        """
        class_indices = None
        if self.class_conditioned_cls and self.training and targets is not None:
            class_indices = targets

        features = self.forward_features(x, class_indices)

        if self.use_prototype_classifier:
            prototypes = F.normalize(self.cls_tokens.get_class_prototypes(), dim=-1)
            global_cls_norm = F.normalize(features['global_cls'], dim=-1)
            # Cosine similarity scaled by a fixed factor of 20.
            logits = torch.matmul(global_cls_norm, prototypes.T) * 20.0
        else:
            logits = self.head(features['global_cls'])

        B = features['vertex_cls'].shape[0]
        vertex_flat = features['vertex_cls'].reshape(B, -1)
        aux_logits = self.head_aux(vertex_flat)

        geometric_alignments = self.geometric_proj(
            features['patches'],
            self.cls_tokens.class_pentachora
        )

        return {
            'logits': logits,
            'aux_logits': aux_logits,
            'geometric_alignments': geometric_alignments,
            'vertex_cls': features['vertex_cls'],
            'global_cls': features['global_cls'],
            'patches': features['patches']
        }
| |
|
| | |
| | |
| | |
| |
|
class PentachoraLoss(nn.Module):
    """Weighted sum of cross-entropy losses over the model's output heads.

    loss = CE(logits)
           + aux_weight * CE(aux_logits)
           + geo_weight * CE(geometric_alignments averaged over patches)

    An optional term is skipped when its key is absent from ``outputs`` or its
    weight is not positive.
    """

    def __init__(self, aux_weight: float = 0.3, geo_weight: float = 0.1,
                 smoothing: float = 0.0):
        super().__init__()
        self.aux_weight = aux_weight
        self.geo_weight = geo_weight
        self.criterion = nn.CrossEntropyLoss(label_smoothing=smoothing)

    def forward(self, outputs: Dict[str, torch.Tensor], targets: torch.Tensor) -> torch.Tensor:
        """Compute the combined loss for ``outputs`` against class ``targets``."""
        total = self.criterion(outputs['logits'], targets)

        # (output key, weight, dim to average over before CE, or None)
        optional_terms = (
            ('aux_logits', self.aux_weight, None),
            ('geometric_alignments', self.geo_weight, 1),
        )
        for key, weight, reduce_dim in optional_terms:
            if key in outputs and weight > 0:
                term_logits = outputs[key]
                if reduce_dim is not None:
                    # [B, N, C] -> [B, C]: average patch-level alignments.
                    term_logits = term_logits.mean(dim=reduce_dim)
                total = total + weight * self.criterion(term_logits, targets)

        return total
| |
|
| | |
| | |
| | |
| |
|
# Preset configurations for each named variant. Every preset uses
# mlp_ratio=4.0. Table columns:
# (dim, depth, heads, preserve_structure_until_layer, dropout_rate, drop_path_rate)
MODEL_CONFIGS = {
    name: PentachoraConfig(
        dim=dim,
        depth=depth,
        heads=heads,
        mlp_ratio=4.0,
        preserve_structure_until_layer=preserve_until,
        dropout_rate=dropout_rate,
        drop_path_rate=drop_path_rate,
    )
    for name, (dim, depth, heads, preserve_until, dropout_rate, drop_path_rate) in {
        'pentachora_spark': (100, 5, 4, 1, 0.0, 0.0),
        'pentachora_tiny': (384, 12, 6, 6, 0.1, 0.1),
        'pentachora_small': (512, 12, 8, 6, 0.1, 0.1),
        'pentachora_base': (768, 12, 12, 8, 0.1, 0.2),
        'pentachora_large': (1024, 24, 16, 12, 0.1, 0.3),
    }.items()
}
| |
|
def create_pentachora_vit(variant: str = 'pentachora_small',
                          pretrained: bool = False,
                          **kwargs) -> PentachoraViT:
    """
    Create a PentachoraViT model from a named preset.

    Args:
        variant: Key of ``MODEL_CONFIGS``.
        pretrained: Whether to load pretrained weights. None are available
            yet, so this only emits a warning.
        **kwargs: Attribute overrides applied to a copy of the preset config
            (including ``vocab_dim`` and ``vocab``).

    Returns:
        PentachoraViT model.

    Raises:
        ValueError: If ``variant`` is not a known preset.
    """
    import copy

    if variant not in MODEL_CONFIGS:
        raise ValueError(f"Unknown variant: {variant}. Choose from {list(MODEL_CONFIGS.keys())}")

    # Work on a shallow copy. Previously setattr mutated the shared preset, so
    # overrides from one call (e.g. vocab_dim, img_size, vocab) silently leaked
    # into every later model built from the same variant.
    config = copy.copy(MODEL_CONFIGS[variant])

    for key, value in kwargs.items():
        setattr(config, key, value)

    model = PentachoraViT(config)

    if pretrained:
        warnings.warn("Pretrained weights not available yet")

    return model
| |
|
| | |
def pentachora_vit_spark(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_spark' preset, the smallest variant.

    Keyword arguments override fields of the preset config.
    """
    return create_pentachora_vit(variant='pentachora_spark', pretrained=pretrained, **kwargs)
| |
|
def pentachora_vit_tiny(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_tiny' preset.

    Keyword arguments override fields of the preset config.
    """
    return create_pentachora_vit(variant='pentachora_tiny', pretrained=pretrained, **kwargs)
| |
|
def pentachora_vit_small(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_small' preset.

    Keyword arguments override fields of the preset config.
    """
    return create_pentachora_vit(variant='pentachora_small', pretrained=pretrained, **kwargs)
| |
|
def pentachora_vit_base(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_base' preset.

    Keyword arguments override fields of the preset config.
    """
    return create_pentachora_vit(variant='pentachora_base', pretrained=pretrained, **kwargs)
| |
|
def pentachora_vit_large(pretrained: bool = False, **kwargs) -> PentachoraViT:
    """Build the 'pentachora_large' preset, the largest variant.

    Keyword arguments override fields of the preset config.
    """
    return create_pentachora_vit(variant='pentachora_large', pretrained=pretrained, **kwargs)
| |
|
| | |
| | |
| | |
| |
|
def get_parameter_groups(model: nn.Module,
                         weight_decay: float = 0.05) -> List[Dict[str, Any]]:
    """
    Split trainable parameters into a weight-decay group and a no-decay group.

    A trainable parameter is exempt from weight decay when either holds:
      * it has at most one dimension. This covers biases, LayerNorm scales,
        and scalar/vector parameters such as ``GeometricProjection.temperature``
        and ``HierarchicalPentachoronCLS.vertex_weights``;
      * its qualified name contains 'bias', 'norm' or 'LayerNorm'.

    Decaying ``temperature`` is actively harmful: it divides the alignment
    scores, so pulling it toward zero inflates those logits.

    Args:
        model: Any module (typically a PentachoraViT).
        weight_decay: Weight decay applied to the decayed group.

    Returns:
        ``[{'params': decayed, 'weight_decay': weight_decay},
        {'params': exempt, 'weight_decay': 0.0}]``. Frozen parameters are
        omitted.
    """
    no_decay_keywords = ('bias', 'norm', 'LayerNorm')

    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if param.ndim <= 1 or any(keyword in name for keyword in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
| |
|
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Count model parameter elements.

    Returns:
        Dict with 'total', 'trainable' (``requires_grad``) and 'non_trainable'
        element counts.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return {'total': total, 'trainable': trainable, 'non_trainable': total - trainable}
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def extract_features(model: nn.Module,
                     images: torch.Tensor,
                     feature_type: str = 'global_cls') -> torch.Tensor:
    """
    Extract features from images using the model, in eval mode and without grad.

    The model's previous train/eval mode is restored afterwards. Previously
    the model was left in eval mode, which silently disabled dropout and
    drop-path for any training that followed.

    Args:
        model: A module with ``forward_features(images) -> dict``, e.g.
            PentachoraViT.
        images: Input images [B, 3, H, W].
        feature_type: Key of the feature to return:
            - 'global_cls': Global CLS token
            - 'vertex_cls': Vertex CLS tokens
            - 'patches': Patch embeddings

    Returns:
        The requested feature tensor.

    Raises:
        ValueError: If ``feature_type`` is not produced by ``forward_features``.
            Previously an unknown key silently returned 'global_cls'.
    """
    was_training = model.training
    model.eval()
    try:
        features = model.forward_features(images)
    finally:
        model.train(was_training)

    try:
        return features[feature_type]
    except KeyError:
        raise ValueError(
            f"Unknown feature_type {feature_type!r}; expected one of {sorted(features)}"
        ) from None
| |
|
| | |
| | |
| | |
| |
|
def test_model():
    """Smoke-test the three smallest variants.

    For each variant this builds the model, counts parameters, runs a
    forward pass, computes the loss and extracts features, printing a short
    report. Any failure surfaces as an exception.
    """
    rule = "=" * 50
    print("Testing PentachoraViT Model with Geometric Attention")
    print(rule)

    for variant in ('pentachora_spark', 'pentachora_tiny', 'pentachora_small'):
        print(f"\nTesting {variant}:")

        model = create_pentachora_vit(
            variant=variant,
            img_size=32,
            patch_size=4,
            num_classes=100,
            vocab_dim=64,
        )

        counts = count_parameters(model)
        print(f" Total parameters: {counts['total']:,}")
        print(f" Trainable parameters: {counts['trainable']:,}")

        images = torch.randn(2, 3, 32, 32)
        outputs = model(images)

        print(f" Output shapes:")
        shape_report = (
            ("Logits", 'logits'),
            ("Aux logits", 'aux_logits'),
            ("Geometric alignments", 'geometric_alignments'),
        )
        for label, key in shape_report:
            print(f" {label}: {outputs[key].shape}")

        labels = torch.randint(0, 100, (2,))
        loss = PentachoraLoss()(outputs, labels)
        print(f" Loss: {loss.item():.4f}")

        feats = extract_features(model, images, 'global_cls')
        print(f" Extracted features shape: {feats.shape}")

    print("\n" + rule)
    print("All tests passed!")
| |
|
if __name__ == "__main__":
    # Run the smoke tests, then show a minimal single training-step example.
    test_model()

    print("\nExample: Creating model for training with 64-dim vocabulary")
    spark_overrides = dict(
        img_size=32,
        patch_size=4,
        num_classes=100,
        vocab_dim=64,
        dropout_rate=0.0,
        drop_path_rate=0.0,
    )
    model = pentachora_vit_spark(**spark_overrides)

    param_groups = get_parameter_groups(model, weight_decay=0.05)
    print(f"Number of parameter groups: {len(param_groups)}")

    images = torch.randn(4, 3, 32, 32)
    targets = torch.randint(0, 100, (4,))

    outputs = model(images)
    loss = PentachoraLoss(aux_weight=0.3, geo_weight=0.1)(outputs, targets)

    print(f"Training loss: {loss.item():.4f}")
    print("\nModel ready for training with geometric attention!")