"""
SpectralViT — Pure SpectralCell Transformer
=============================================

No conv backbone. No external attention. Stacked SpectralCells are the model.

Architecture:
    Image → PatchEmbed(4×4) → 64 tokens × embed_dim
      → Cayley hypersphere positional encoding (multi-plane rotations on S^{d-1})
      → SpectralCell × depth
      → LayerNorm → mean pool → classify

Positional encoding on the hypersphere:
    Each position has K learnable rotation angles in K fixed 2D planes.
    Rotation in plane (2k, 2k+1) by angle θ (primes denote outputs; both
    right-hand sides use the ORIGINAL x values):
        x'[2k]   = cos(θ) · x[2k] - sin(θ) · x[2k+1]
        x'[2k+1] = sin(θ) · x[2k] + cos(θ) · x[2k+1]
    The K planes are disjoint, so composing the K plane rotations gives a
    block-diagonal orthogonal transform. It preserves norms, so it maps
    S^{d-1} onto itself. Learnable angles, not fixed sinusoidal.

    Note: despite the "Cayley" name, the rotations are parameterized directly
    by angles (Givens rotations), not via the Cayley transform.

SpectralCell and cv_of are in namespace from prior cell execution.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ── Cayley Hypersphere Positional Encoding ───────────────────────

class CayleyPositionalEncoding(nn.Module):
    """Multi-plane rotation positional encoding on the hypersphere.

    Each position gets K = embed_dim // 2 learnable rotation angles, applied
    in the K disjoint dimension planes (0, 1), (2, 3), .... Composing these K
    Givens rotations produces an orthogonal transformation that preserves the
    embedding norm.

    For embed_dim=256: K=128 planes, each position has 128 angles.
    64 positions × 128 angles = 8,192 learnable parameters.

    The rotations are norm-preserving, so they act as transformations of the
    hypersphere. (The original design note states that SpectralCell projects
    onto S^{D-1}; that class is defined elsewhere — not verified here.)

    Args:
        n_positions: maximum number of token positions supported.
        embed_dim: token dimension; must be even.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """

    def __init__(self, n_positions, embed_dim):
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even for paired rotations")
        self.n_positions = n_positions
        self.embed_dim = embed_dim
        self.n_planes = embed_dim // 2

        # Learnable rotation angles: (n_positions, n_planes).
        # Initialized small → near-identity rotation at the start of training.
        self.angles = nn.Parameter(torch.randn(n_positions, self.n_planes) * 0.02)

    def forward(self, x):
        """Rotate each token by its position's angles.

        Args:
            x: (B, N, D) tensor with D == embed_dim and N <= n_positions.

        Returns:
            (B, N, D) tensor; each token has the same L2 norm as its input.

        Raises:
            ValueError: if D != embed_dim or N > n_positions (previously these
                surfaced as opaque broadcasting errors).
        """
        B, N, D = x.shape
        if D != self.embed_dim:
            raise ValueError(
                f"expected last dim {self.embed_dim}, got {D}"
            )
        if N > self.n_positions:
            raise ValueError(
                f"sequence length {N} exceeds n_positions={self.n_positions}"
            )

        angles = self.angles[:N]               # (N, K)
        cos_a = angles.cos().unsqueeze(0)      # (1, N, K)
        sin_a = angles.sin().unsqueeze(0)      # (1, N, K)

        # Split into even/odd dimension pairs: plane k = dims (2k, 2k+1).
        x_even = x[:, :, 0::2]                 # (B, N, K)
        x_odd = x[:, :, 1::2]                  # (B, N, K)

        # Givens rotation per plane per position (both use original values).
        x_rot_even = cos_a * x_even - sin_a * x_odd
        x_rot_odd = sin_a * x_even + cos_a * x_odd

        # Interleave back: (B, N, K, 2) → (B, N, D) restores dim order.
        out = torch.stack([x_rot_even, x_rot_odd], dim=-1)
        return out.reshape(B, N, D)


# ── Patch Embedding ──────────────────────────────────────────────

class PatchEmbed(nn.Module):
    """Image → non-overlapping patches → linear projection.

    Implemented as a Conv2d with kernel_size == stride == patch_size.
    32×32 with patch_size=4 → 8×8 = 64 tokens.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, C, H, W) → (B, embed_dim, H/ps, W/ps) → (B, N, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)


# ── SpectralViT ──────────────────────────────────────────────────

class SpectralViT(nn.Module):
    """Pure SpectralCell vision transformer.

    No conv backbone. No external attention.
    Stacked SpectralCells with Cayley hypersphere positional encoding.

    Args:
        img_size: input image size (32 for CIFAR)
        patch_size: patch size (4 → 64 tokens)
        in_channels: input channels (3)
        embed_dim: token embedding dimension
        depth: number of SpectralCell blocks (must be >= 1)
        cell_V: V parameter for SpectralCell
        cell_D: D parameter for SpectralCell
        cell_hidden: hidden dimension inside each cell
        cell_depth: residual MLP depth inside each cell
        n_cross: cross-attention layers per cell
        n_heads: attention heads in cell cross-attention
        n_classes: classification output
        dropout: classifier dropout
        cell_max_alpha: ``max_alpha`` passed to every SpectralCell
            (default 0.2, the previously hard-coded value)

    Raises:
        ValueError: if ``depth < 1`` (forward/summary require a last cell).
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        cell_V=16,
        cell_D=16,
        cell_hidden=256,
        cell_depth=2,
        n_cross=2,
        n_heads=4,
        n_classes=100,
        dropout=0.1,
        cell_max_alpha=0.2,
    ):
        super().__init__()
        # forward() indexes self.norms[-1] / self.cells[-1] and summary()
        # divides by depth, so an empty stack can never work — fail early.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.embed_dim = embed_dim
        self.depth = depth
        self.patch_size = patch_size
        n_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        # Cayley hypersphere positional encoding
        self.pos_enc = CayleyPositionalEncoding(n_patches, embed_dim)

        # Stacked SpectralCells — the entire backbone.
        # SpectralCell is expected to be defined elsewhere (prior cell).
        self.cells = nn.ModuleList([
            SpectralCell(
                token_dim=embed_dim,
                V=cell_V,
                D=cell_D,
                hidden=cell_hidden,
                depth=cell_depth,
                n_cross=n_cross,
                n_heads=n_heads,
                max_alpha=cell_max_alpha,
            )
            for _ in range(depth)
        ])

        # Pre-norm before each cell
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(depth)
        ])

        # Final norm + classifier
        self.final_norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """x: (B, C, H, W) → dict with 'logits' and 'last_cell' output dict."""
        # Patch embed → positional encoding
        tokens = self.patch_embed(x)     # (B, N, embed_dim)
        tokens = self.pos_enc(tokens)    # rotated per position (norm-preserving)

        # Cells 0..depth-2: pre-norm residual; calling the cell returns the
        # output tensor only (no dict bloat).
        for i in range(self.depth - 1):
            normed = self.norms[i](tokens)
            tokens = tokens + self.cells[i](normed)

        # Last cell: full .format() result (a dict with an 'output' entry),
        # kept for CV measurement by the caller.
        normed = self.norms[-1](tokens)
        last_cell_out = self.cells[-1].format(normed)
        tokens = tokens + last_cell_out['output']

        # Pool + classify
        tokens = self.final_norm(tokens)
        pooled = tokens.mean(dim=1)      # (B, embed_dim)
        logits = self.classifier(pooled)

        return {
            'logits': logits,
            'last_cell': last_cell_out,
        }

    def get_cross_attn_params(self):
        """Return parameters whose qualified name contains 'cross_attn'.

        Used for separate gradient clipping of cross-attention weights.
        """
        return [p for name, p in self.named_parameters() if 'cross_attn' in name]

    def summary(self):
        """Print a parameter-count breakdown per component."""
        n_params = sum(p.numel() for p in self.parameters())
        n_embed = sum(p.numel() for p in self.patch_embed.parameters())
        n_pos = sum(p.numel() for p in self.pos_enc.parameters())
        n_cells = sum(p.numel() for p in self.cells.parameters())
        n_norms = (sum(p.numel() for p in self.norms.parameters())
                   + sum(p.numel() for p in self.final_norm.parameters()))
        n_head = sum(p.numel() for p in self.classifier.parameters())
        n_cross = sum(p.numel() for p in self.get_cross_attn_params())
        ps = self.patch_size

        print(f"SpectralViT:")
        print(f"  Patch embed:     {n_embed:,}")
        print(f"  Cayley PE:       {n_pos:,}  ({self.pos_enc.n_planes} rotation planes × {self.pos_enc.n_positions} positions)")
        print(f"  Cells ({self.depth}×):     {n_cells:,}  ({n_cells // self.depth:,} per cell)")
        print(f"  LayerNorms:      {n_norms:,}")
        print(f"  Classifier:      {n_head:,}")
        print(f"  Cross-attn:      {n_cross:,}  (clipped at 0.5)")
        print(f"  Total:           {n_params:,}")
        # Report the actual patch size instead of a hard-coded 4×4.
        print(f"  Architecture: PatchEmbed({ps}×{ps}) → CayleyPE → {self.depth}× SpectralCell → pool → classify")