# Uploaded by AbstractPhil (commit 1884baf, verified).
# Commit message: added profiling, may slow down
"""
SpectralViT — Pure SpectralCell Transformer
===========================================
No conv backbone. No external attention. Stacked SpectralCells are the model.

Architecture:
    Image → PatchEmbed(4×4) → 64 tokens × embed_dim
          → Cayley hypersphere positional encoding (multi-plane rotations on S^{d-1})
          → SpectralCell × depth
          → LayerNorm → mean pool → classify

Positional encoding on the hypersphere:
    Each position has K learnable rotation angles in K fixed 2D planes.
    Rotation in plane (2k, 2k+1) by angle θ (primes denote the rotated values):
        x'[2k]   = cos(θ) · x[2k] - sin(θ) · x[2k+1]
        x'[2k+1] = sin(θ) · x[2k] + cos(θ) · x[2k+1]
    Composing K plane rotations = rich orthogonal rotation.
    Preserves norms. Operates naturally on S^{d-1}.
    Learnable angles, not fixed sinusoidal.

SpectralCell and cv_of are in namespace from prior cell execution.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ── Cayley Hypersphere Positional Encoding ───────────────────────
class CayleyPositionalEncoding(nn.Module):
    """Multi-plane rotation positional encoding on the hypersphere.

    Each position owns ``K = embed_dim // 2`` learnable angles. Angle ``k`` of
    a position rotates the coordinate pair ``(2k, 2k + 1)`` of the token at
    that position (a Givens rotation). The planes are disjoint, so the
    composed transform is orthogonal and leaves every token's norm unchanged.

    For embed_dim=256: K=128 planes, each position has 128 angles.
    64 positions × 128 angles = 8,192 learnable parameters.

    Args:
        n_positions: maximum number of token positions (rows of the angle table).
        embed_dim: token dimension; must be even so coordinates pair up.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """

    def __init__(self, n_positions, embed_dim):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even for paired rotations")
        self.n_positions = n_positions
        self.embed_dim = embed_dim
        self.n_planes = embed_dim // 2
        # Learnable rotation angles: (n_positions, n_planes).
        # Small init keeps the encoding close to the identity at the start.
        self.angles = nn.Parameter(torch.randn(n_positions, self.n_planes) * 0.02)

    def forward(self, x):
        """Rotate every token by the angles of its position.

        Args:
            x: tensor of shape (B, N, D) with ``D == embed_dim`` and
                ``N <= n_positions``.

        Returns:
            Tensor of shape (B, N, D); each token has the same norm as its input.

        Raises:
            ValueError: if ``D`` differs from ``embed_dim`` or ``N`` exceeds
                ``n_positions``.
        """
        B, N, D = x.shape
        if D != self.embed_dim:
            raise ValueError(
                f"expected last dimension {self.embed_dim}, got {D}"
            )
        # Without this check the slice below is silently truncated and the
        # failure surfaces later as an opaque broadcast error (or, when
        # n_positions == 1, as one rotation wrongly shared by every token).
        if N > self.n_positions:
            raise ValueError(
                f"sequence length {N} exceeds n_positions={self.n_positions}"
            )

        angles = self.angles[:N]            # (N, K)
        cos_a = angles.cos().unsqueeze(0)   # (1, N, K), broadcast over batch
        sin_a = angles.sin().unsqueeze(0)   # (1, N, K)

        # Split into the two coordinates of each rotation plane.
        x_even = x[..., 0::2]               # (B, N, K)
        x_odd = x[..., 1::2]                # (B, N, K)

        # Givens rotation per plane per position.
        x_rot_even = cos_a * x_even - sin_a * x_odd
        x_rot_odd = sin_a * x_even + cos_a * x_odd

        # Stack as (B, N, K, 2) and flatten so the pairs interleave back
        # into their original coordinate slots.
        out = torch.stack([x_rot_even, x_rot_odd], dim=-1)
        return out.reshape(B, N, D)
# ── Patch Embedding ──────────────────────────────────────────────
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and embed each one.

    The projection is a convolution whose kernel size and stride both equal
    ``patch_size``, so each output location sees exactly one patch.
    A 32×32 image with patch_size=4 gives an 8×8 grid = 64 tokens.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: size of each output token.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images (B, C, H, W) to tokens (B, N, embed_dim)."""
        feature_map = self.proj(x)                             # (B, embed_dim, H/ps, W/ps)
        channel_major = torch.flatten(feature_map, start_dim=2)  # (B, embed_dim, N)
        return channel_major.transpose(1, 2)                   # (B, N, embed_dim)
# ── SpectralViT ─────────────────────────────────────────────────
class SpectralViT(nn.Module):
    """Pure SpectralCell vision transformer.

    No conv backbone. No external attention.
    Stacked SpectralCells with Cayley hypersphere positional encoding.

    ``SpectralCell`` is not defined in this file; it must already be in the
    module namespace when this class is instantiated.

    Args:
        img_size: input image size (32 for CIFAR)
        patch_size: patch size (4 → 64 tokens for a 32×32 image)
        in_channels: input channels (3)
        embed_dim: token embedding dimension
        depth: number of SpectralCell blocks (must be >= 1)
        cell_V: V parameter for SpectralCell
        cell_D: D parameter for SpectralCell
        cell_hidden: hidden dimension inside each cell
        cell_depth: residual MLP depth inside each cell
        n_cross: cross-attention layers per cell
        n_heads: attention heads in cell cross-attention
        n_classes: classification output
        dropout: classifier dropout

    Raises:
        ValueError: if ``depth`` is less than 1.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        cell_V=16,
        cell_D=16,
        cell_hidden=256,
        cell_depth=2,
        n_cross=2,
        n_heads=4,
        n_classes=100,
        dropout=0.1,
    ):
        super().__init__()
        # forward() always runs a "last" cell and summary() divides by depth,
        # so an empty stack can never work; fail at construction instead.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.embed_dim = embed_dim
        self.depth = depth
        # Kept so summary() can report the real patch size.
        self.patch_size = patch_size
        n_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        # Cayley hypersphere positional encoding
        self.pos_enc = CayleyPositionalEncoding(n_patches, embed_dim)

        # Stacked SpectralCells: the entire backbone
        self.cells = nn.ModuleList([
            SpectralCell(
                token_dim=embed_dim, V=cell_V, D=cell_D,
                hidden=cell_hidden, depth=cell_depth,
                n_cross=n_cross, n_heads=n_heads,
                max_alpha=0.2,
            ) for _ in range(depth)
        ])

        # Pre-norm before each cell
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(depth)
        ])

        # Final norm + classifier
        self.final_norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image tensor of shape (B, C, H, W).

        Returns:
            dict with ``'logits'`` (classifier output for the mean-pooled
            tokens) and ``'last_cell'`` (whatever the final cell's
            ``.format()`` returned; it must contain an ``'output'`` entry).
        """
        # Patch embed → positional encoding
        tokens = self.patch_embed(x)        # (B, N, embed_dim)
        tokens = self.pos_enc(tokens)       # position-dependent rotation

        # Cells 0..depth-2: plain call, residual add of the returned tensor.
        for i in range(self.depth - 1):
            normed = self.norms[i](tokens)
            tokens = tokens + self.cells[i](normed)

        # Last cell: use .format() so its full output dict can be returned
        # to the caller (used for CV measurement).
        normed = self.norms[-1](tokens)
        last_cell_out = self.cells[-1].format(normed)
        tokens = tokens + last_cell_out['output']

        # Pool + classify
        tokens = self.final_norm(tokens)
        pooled = tokens.mean(dim=1)         # (B, embed_dim)
        logits = self.classifier(pooled)

        return {
            'logits': logits,
            'last_cell': last_cell_out,
        }

    def get_cross_attn_params(self):
        """Return every parameter whose qualified name contains 'cross_attn'.

        Intended for applying a separate gradient-clipping threshold.
        """
        return [p for name, p in self.named_parameters() if 'cross_attn' in name]

    def summary(self):
        """Print a parameter-count breakdown of the model."""

        def count(params):
            return sum(p.numel() for p in params)

        n_params = count(self.parameters())
        n_embed = count(self.patch_embed.parameters())
        n_pos = count(self.pos_enc.parameters())
        n_cells = count(self.cells.parameters())
        n_norms = count(self.norms.parameters()) + count(self.final_norm.parameters())
        n_head = count(self.classifier.parameters())
        n_cross = count(self.get_cross_attn_params())
        ps = self.patch_size

        print("SpectralViT:")
        print(f" Patch embed: {n_embed:,}")
        print(f" Cayley PE: {n_pos:,} ({self.pos_enc.n_planes} rotation planes × {self.pos_enc.n_positions} positions)")
        print(f" Cells ({self.depth}×): {n_cells:,} ({n_cells // self.depth:,} per cell)")
        print(f" LayerNorms: {n_norms:,}")
        print(f" Classifier: {n_head:,}")
        print(f" Cross-attn: {n_cross:,} (clipped at 0.5)")
        print(f" Total: {n_params:,}")
        print(f" Architecture: PatchEmbed({ps}×{ps}) → CayleyPE → {self.depth}× SpectralCell → pool → classify")