"""
SpectralViT — Pure SpectralCell Transformer
=============================================
No conv backbone. No external attention. Stacked SpectralCells are the model.

Architecture:
    Image → PatchEmbed(4×4) → 64 tokens × embed_dim
          → Cayley hypersphere positional encoding (multi-plane rotations on S^{d-1})
          → SpectralCell × depth
          → LayerNorm → mean pool → classify

Positional encoding on the hypersphere:
    Each position has K learnable rotation angles in K fixed 2D planes.
    Rotation in plane (2k, 2k+1) by angle θ (both outputs use the
    pre-rotation values of x):
        x'[2k]   = cos(θ) · x[2k] - sin(θ) · x[2k+1]
        x'[2k+1] = sin(θ) · x[2k] + cos(θ) · x[2k+1]
    Composing K plane rotations = rich orthogonal rotation.
    Preserves norms. Operates naturally on S^{d-1}.
    Learnable angles, not fixed sinusoidal.

SpectralCell and cv_of are expected to be in the namespace from a prior
notebook cell execution; they are not imported by this module.
"""
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
|
|
class CayleyPositionalEncoding(nn.Module):
    """Learnable multi-plane rotation positional encoding.

    Every position owns ``embed_dim // 2`` learnable angles. Angle ``k`` of a
    position rotates the coordinate pair ``(2k, 2k + 1)`` of the token at that
    position. The planes are disjoint, so the combined map is a block-diagonal
    orthogonal transform and each token keeps its norm.

    Despite the class name, the rotation is parameterised directly by angles
    (Givens-style plane rotations via cos/sin), not by a Cayley transform.

    For embed_dim=256 there are 128 planes; with 64 positions that is
    64 × 128 = 8,192 learnable parameters.

    Args:
        n_positions: maximum number of token positions supported.
        embed_dim: token embedding dimension; must be even so that the
            coordinates can be paired into planes.

    Raises:
        ValueError: if ``embed_dim`` is odd.
    """

    def __init__(self, n_positions, embed_dim):
        super().__init__()
        # Explicit raise instead of ``assert``: asserts vanish under ``-O``.
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even for paired rotations")
        self.n_positions = n_positions
        self.embed_dim = embed_dim
        self.n_planes = embed_dim // 2

        # Small random angles: the encoding starts close to the identity.
        self.angles = nn.Parameter(torch.randn(n_positions, self.n_planes) * 0.02)

    def forward(self, x):
        """Rotate every token by the angles of its position.

        Args:
            x: tensor of shape (B, N, D) with ``N <= n_positions`` and
                ``D == embed_dim``.

        Returns:
            Tensor of shape (B, N, D) holding the rotated tokens.

        Raises:
            ValueError: if ``x`` is not 3-D, has more tokens than
                ``n_positions``, or has a last dimension other than
                ``embed_dim``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (B, N, D), got {tuple(x.shape)}"
            )
        B, N, D = x.shape
        # Without these checks the slice below silently truncates and the
        # mismatch surfaces as a broadcast error -- or, when one of the angle
        # dimensions is 1, as a silent broadcast of the wrong angles.
        if N > self.n_positions:
            raise ValueError(
                f"sequence length {N} exceeds n_positions={self.n_positions}"
            )
        if D != self.embed_dim:
            raise ValueError(
                f"last dimension {D} does not match embed_dim={self.embed_dim}"
            )

        angles = self.angles[:N]              # (N, K)
        cos_a = angles.cos().unsqueeze(0)     # (1, N, K), broadcast over batch
        sin_a = angles.sin().unsqueeze(0)

        # Split into the two coordinates of every rotation plane.
        x_even = x[:, :, 0::2]                # (B, N, K)
        x_odd = x[:, :, 1::2]

        # 2-D rotation per plane; both outputs use the un-rotated inputs.
        x_rot_even = cos_a * x_even - sin_a * x_odd
        x_rot_odd = sin_a * x_even + cos_a * x_odd

        # Stack on a trailing axis and flatten to re-interleave as
        # (even_0, odd_0, even_1, odd_1, ...).
        out = torch.stack([x_rot_even, x_rot_odd], dim=-1)
        return out.reshape(B, N, D)
|
|
|
|
| |
|
|
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping patches and project each linearly.

    The projection is a strided convolution whose kernel size and stride both
    equal ``patch_size``, so every output location sees exactly one patch.
    Example: a 32×32 image with patch_size=4 gives an 8×8 grid = 64 tokens.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each (square) patch.
        in_channels: number of input image channels.
        embed_dim: size of each output token.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map images (B, C, H, W) to a token sequence (B, n_tokens, embed_dim)."""
        feature_map = self.proj(x)                      # (B, embed_dim, H/p, W/p)
        channel_major = feature_map.flatten(start_dim=2)  # (B, embed_dim, n_tokens)
        return channel_major.permute(0, 2, 1)
|
|
|
|
| |
|
|
class SpectralViT(nn.Module):
    """Vision transformer built purely from stacked SpectralCells.

    Pipeline: PatchEmbed → CayleyPositionalEncoding → ``depth`` pre-norm
    residual SpectralCell blocks → LayerNorm → mean pool → MLP classifier.

    ``SpectralCell`` is not imported by this module; it must already exist in
    the enclosing namespace (e.g. defined by an earlier notebook cell).

    Args:
        img_size: input image size (32 for CIFAR).
        patch_size: patch size (4 → 64 tokens for a 32×32 image).
        in_channels: input channels (3).
        embed_dim: token embedding dimension; must be even (required by the
            positional encoding).
        depth: number of SpectralCell blocks; must be at least 1.
        cell_V: forwarded to SpectralCell as ``V``.
        cell_D: forwarded to SpectralCell as ``D``.
        cell_hidden: forwarded to SpectralCell as ``hidden`` (hidden
            dimension inside each cell).
        cell_depth: forwarded to SpectralCell as ``depth`` (residual MLP depth
            inside each cell).
        n_cross: forwarded to SpectralCell (cross-attention layers per cell).
        n_heads: forwarded to SpectralCell (attention heads in the cell
            cross-attention).
        n_classes: number of classification outputs.
        dropout: dropout probability inside the classifier head.
        max_alpha: forwarded to SpectralCell as ``max_alpha``.

    Raises:
        ValueError: if ``depth`` is smaller than 1.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        cell_V=16,
        cell_D=16,
        cell_hidden=256,
        cell_depth=2,
        n_cross=2,
        n_heads=4,
        n_classes=100,
        dropout=0.1,
        max_alpha=0.2,
    ):
        super().__init__()
        # forward() always runs a dedicated last block and summary() divides
        # by depth, so an empty stack can never work -- fail early and clearly.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.embed_dim = embed_dim
        self.depth = depth
        self.patch_size = patch_size  # kept so summary() reports the real value

        # Image → token sequence.
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        # One set of rotation angles per patch position.
        self.pos_enc = CayleyPositionalEncoding(self.patch_embed.n_patches, embed_dim)

        # The transformer body: a stack of SpectralCells ...
        self.cells = nn.ModuleList([
            SpectralCell(
                token_dim=embed_dim, V=cell_V, D=cell_D,
                hidden=cell_hidden, depth=cell_depth,
                n_cross=n_cross, n_heads=n_heads,
                max_alpha=max_alpha,
            ) for _ in range(depth)
        ])

        # ... each preceded by its own LayerNorm (pre-norm residual blocks).
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(depth)
        ])

        # Classification head.
        self.final_norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: image tensor of shape (B, in_channels, H, W).

        Returns:
            dict with
              * ``'logits'``: classifier output for the mean-pooled tokens;
              * ``'last_cell'``: the dict returned by the last cell's
                ``format`` call (its ``'output'`` entry is the residual
                update that was added to the tokens).
        """
        tokens = self.pos_enc(self.patch_embed(x))

        blocks = list(zip(self.norms, self.cells))

        # All blocks but the last: plain pre-norm residual update.
        for norm, cell in blocks[:-1]:
            tokens = tokens + cell(norm(tokens))

        # Last block goes through ``format`` so the full dict can be returned
        # to the caller. NOTE(review): ``format`` is SpectralCell API defined
        # outside this file -- presumably the dict-returning variant of the
        # cell's forward pass; confirm against SpectralCell.
        last_norm, last_cell = blocks[-1]
        last_cell_out = last_cell.format(last_norm(tokens))
        tokens = tokens + last_cell_out['output']

        pooled = self.final_norm(tokens).mean(dim=1)
        logits = self.classifier(pooled)

        return {
            'logits': logits,
            'last_cell': last_cell_out,
        }

    def get_cross_attn_params(self):
        """Return every parameter whose qualified name contains 'cross_attn'.

        Intended for applying a separate gradient-clipping threshold to the
        cross-attention weights.
        """
        return [p for name, p in self.named_parameters() if 'cross_attn' in name]

    def summary(self):
        """Print a parameter-count breakdown of the model to stdout."""

        def count(params):
            return sum(p.numel() for p in params)

        n_params = count(self.parameters())
        n_embed = count(self.patch_embed.parameters())
        n_pos = count(self.pos_enc.parameters())
        n_cells = count(self.cells.parameters())
        n_norms = count(self.norms.parameters()) + count(self.final_norm.parameters())
        n_head = count(self.classifier.parameters())
        n_cross = count(self.get_cross_attn_params())
        p = self.patch_size

        print("SpectralViT:")
        print(f" Patch embed: {n_embed:,}")
        print(f" Cayley PE: {n_pos:,} ({self.pos_enc.n_planes} rotation planes × {self.pos_enc.n_positions} positions)")
        print(f" Cells ({self.depth}×): {n_cells:,} ({n_cells // self.depth:,} per cell)")
        print(f" LayerNorms: {n_norms:,}")
        print(f" Classifier: {n_head:,}")
        print(f" Cross-attn: {n_cross:,} (clipped at 0.5)")
        print(f" Total: {n_params:,}")
        print(f" Architecture: PatchEmbed({p}×{p}) → CayleyPE → {self.depth}× SpectralCell → pool → classify")