"""
SpectralViT β€” Pure SpectralCell Transformer
=============================================
No conv backbone. No external attention. Stacked SpectralCells are the model.

Architecture:
    Image β†’ PatchEmbed(4Γ—4) β†’ 64 tokens Γ— embed_dim
    β†’ Cayley hypersphere positional encoding (multi-plane rotations on S^{d-1})
    β†’ SpectralCell Γ— depth
    β†’ LayerNorm β†’ mean pool β†’ classify

Positional encoding on the hypersphere:
    Each position has K learnable rotation angles in K fixed 2D planes.
    Rotation in plane (2k, 2k+1) by angle ΞΈ:
        x[2k]   = cos(ΞΈ) Β· x[2k] - sin(ΞΈ) Β· x[2k+1]
        x[2k+1] = sin(ΞΈ) Β· x[2k] + cos(ΞΈ) Β· x[2k+1]
    Composing K plane rotations = rich orthogonal rotation.
    Preserves norms. Operates naturally on S^{d-1}.
    Learnable angles, not fixed sinusoidal.

SpectralCell and cv_of are in namespace from prior cell execution.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ── Cayley Hypersphere Positional Encoding ───────────────────────

class CayleyPositionalEncoding(nn.Module):
    """Multi-plane rotation positional encoding on the hypersphere.

    Each position gets K learnable rotation angles applied in K disjoint
    dimension planes (2k, 2k+1). Composing these K Givens rotations gives a
    block-diagonal orthogonal transformation that preserves embedding norm.

    For embed_dim=256: K=128 planes, so each position has 128 angles.
    64 positions × 128 angles = 8,192 learnable parameters.

    This is geometrically natural: the SpectralCell projects onto S^{D-1},
    and rotations map the hypersphere onto itself. The angles parametrize
    Givens rotations directly via cos/sin rather than through a Cayley
    transform.
    """
    def __init__(self, n_positions, embed_dim):
        super().__init__()
        assert embed_dim % 2 == 0, "embed_dim must be even for paired rotations"
        self.n_positions = n_positions
        self.embed_dim = embed_dim
        self.n_planes = embed_dim // 2

        # Learnable rotation angles: (n_positions, n_planes)
        # Initialize small: near-identity rotation at start
        self.angles = nn.Parameter(torch.randn(n_positions, self.n_planes) * 0.02)

    def forward(self, x):
        """x: (B, N, D) β†’ (B, N, D) with position-dependent rotation."""
        B, N, D = x.shape
        angles = self.angles[:N]  # (N, K)

        cos_a = angles.cos()  # (N, K)
        sin_a = angles.sin()  # (N, K)

        # Split into even/odd dimension pairs
        x_even = x[:, :, 0::2]   # (B, N, K)
        x_odd = x[:, :, 1::2]    # (B, N, K)

        # Givens rotation per plane per position
        x_rot_even = cos_a.unsqueeze(0) * x_even - sin_a.unsqueeze(0) * x_odd
        x_rot_odd = sin_a.unsqueeze(0) * x_even + cos_a.unsqueeze(0) * x_odd

        # Interleave back
        out = torch.stack([x_rot_even, x_rot_odd], dim=-1)  # (B, N, K, 2)
        return out.reshape(B, N, D)
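

# Pure-Python reference for the per-plane rotation described in the module
# docstring, useful for spot-checking CayleyPositionalEncoding on one token.
# This is a hypothetical helper, not part of the original model. It rotates
# plane (2k, 2k+1) by angles[k], the same even/odd pairing that forward()
# builds with x[..., 0::2] and x[..., 1::2], and it reads both entries before
# writing so each update uses the pre-rotation values. Because the planes are
# disjoint, the Euclidean norm of the token is unchanged. For example,
# _givens_rotate_reference([1.0, 0.0], [math.pi / 2]) is approximately [0.0, 1.0].
def _givens_rotate_reference(vec, angles):
    """Hypothetical reference: rotate a flat list `vec` (length 2K) by K per-plane angles."""
    if len(vec) != 2 * len(angles):
        raise ValueError("vec must have exactly two entries per angle")
    out = list(vec)
    for k, theta in enumerate(angles):
        c, s = math.cos(theta), math.sin(theta)
        a, b = vec[2 * k], vec[2 * k + 1]  # pre-rotation values of plane k
        out[2 * k] = c * a - s * b
        out[2 * k + 1] = s * a + c * b
    return out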


# ── Patch Embedding ──────────────────────────────────────────────

class PatchEmbed(nn.Module):
    """Image β†’ patches β†’ linear projection.
    32Γ—32 with patch_size=4 β†’ 8Γ—8 = 64 tokens.
    """
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (B, 3, H, W) → (B, embed_dim, H/ps, W/ps) → (B, N, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
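

# Pure-Python sketch of what PatchEmbed computes, assuming an image given as
# nested lists indexed [channel][row][col]; this is a hypothetical helper for
# illustration only. A Conv2d with kernel_size == stride == patch_size applies
# the same linear map to every non-overlapping patch: each token is W @ patch + b,
# with the patch flattened in (channel, dy, dx) order to match the Conv2d weight
# layout (out, in, kh, kw). Tokens come out in row-major patch order, the order
# produced by flatten(2). A 3×32×32 input with patch_size=4 gives 8 * 8 = 64
# patches, each of length 3 * 4 * 4 = 48.
def _patchify_reference(img, patch_size):
    """Hypothetical reference: flattened patches of `img` in row-major token order."""
    n_rows = len(img[0]) // patch_size
    n_cols = len(img[0][0]) // patch_size
    patches = []
    for i in range(n_rows):
        for j in range(n_cols):
            # (channel, dy, dx) flattening order, matching Conv2d weights
            patch = [
                img[c][i * patch_size + dy][j * patch_size + dx]
                for c in range(len(img))
                for dy in range(patch_size)
                for dx in range(patch_size)
            ]
            patches.append(patch)
    return patches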


# ── SpectralViT ─────────────────────────────────────────────────

class SpectralViT(nn.Module):
    """Pure SpectralCell vision transformer.

    No conv backbone. No external attention.
    Stacked SpectralCells with Cayley hypersphere positional encoding.

    Args:
        img_size:    input image size (32 for CIFAR)
        patch_size:  patch size (4 → 64 tokens for a 32×32 image)
        in_channels: input channels (3)
        embed_dim:   token embedding dimension
        depth:       number of SpectralCell blocks
        cell_V:      V parameter for SpectralCell
        cell_D:      D parameter for every SpectralCell; non-CM cells process
                     it in D/2 slices of size 2 (slice_d=2)
        cell_hidden: hidden dimension inside each cell
        cell_depth:  residual MLP depth inside each cell
        n_cross:     cross-attention layers per cell
        n_heads:     attention heads in cell cross-attention
        cm_every:    CM enabled on every cm_every-th cell and on the last
                     cell (0 = never)
        n_classes:   number of output classes
        dropout:     classifier dropout
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=3,
        embed_dim=256,
        depth=6,
        cell_V=16,
        cell_D=16,
        cell_hidden=256,
        cell_depth=2,
        n_cross=2,
        n_heads=4,
        cm_every=3,
        n_classes=100,
        dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.depth = depth
        self.cm_every = cm_every
        n_patches = (img_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)

        # Cayley hypersphere positional encoding
        self.pos_enc = CayleyPositionalEncoding(n_patches, embed_dim)

        # Mixed cell stack:
        #   CM cells (every cm_every-th cell, plus the last): primary, full SVD, CM enabled
        #   Non-CM cells: degenerate, slice_d=2 (8× Triton D=2 kernels for D=16), no CM
        self.cm_cell_indices = []
        cells = []
        for i in range(depth):
            is_cm_cell = cm_every > 0 and (i % cm_every == cm_every - 1 or i == depth - 1)
            if is_cm_cell:
                self.cm_cell_indices.append(i)
            cells.append(SpectralCell(
                token_dim=embed_dim, V=cell_V, D=cell_D,
                hidden=cell_hidden, depth=cell_depth,
                n_cross=n_cross, n_heads=n_heads,
                max_alpha=0.2,
                cm_enabled=is_cm_cell,
                cm_points=5, cm_samples=200, cm_min=1e-16,
                slice_d=0 if is_cm_cell else 2,  # primary=full SVD, conduit=sliced D=2
            ))
        self.cells = nn.ModuleList(cells)

        # Pre-norm before each cell
        self.norms = nn.ModuleList([
            nn.LayerNorm(embed_dim) for _ in range(depth)
        ])

        # Final norm + classifier
        self.final_norm = nn.LayerNorm(embed_dim)
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """x: (B, 3, H, W) β†’ dict with logits and last cell output."""
        # Patch embed β†’ positional encoding
        tokens = self.patch_embed(x)          # (B, N, embed_dim)
        tokens = self.pos_enc(tokens)          # norm-preserving rotation per position

        # Cells 0..depth-2: just the output tensor (no dict bloat)
        for i in range(self.depth - 1):
            normed = self.norms[i](tokens)
            tokens = tokens + self.cells[i](normed)  # .forward() → tensor only

        # Last cell: full .format() for CV measurement
        normed = self.norms[-1](tokens)
        last_cell_out = self.cells[-1].format(normed)
        tokens = tokens + last_cell_out['output']

        # Pool + classify
        tokens = self.final_norm(tokens)
        pooled = tokens.mean(dim=1)               # (B, embed_dim)
        logits = self.classifier(pooled)

        return {
            'logits': logits,
            'last_cell': last_cell_out,
        }

    def get_cross_attn_params(self):
        """Cross-attention params for separate grad clipping."""
        params = []
        for name, p in self.named_parameters():
            if 'cross_attn' in name:
                params.append(p)
        return params

    def summary(self):
        n_params = sum(p.numel() for p in self.parameters())
        n_embed = sum(p.numel() for p in self.patch_embed.parameters())
        n_pos = sum(p.numel() for p in self.pos_enc.parameters())
        n_cells = sum(p.numel() for p in self.cells.parameters())
        n_norms = sum(p.numel() for p in self.norms.parameters()) + sum(p.numel() for p in self.final_norm.parameters())
        n_head = sum(p.numel() for p in self.classifier.parameters())
        n_cross = sum(p.numel() for p in self.get_cross_attn_params())

        n_cm = len(self.cm_cell_indices)
        cell0 = self.cells[0]

        print("SpectralViT:")
        print(f"  Patch embed:  {n_embed:,}")
        print(f"  Cayley PE:    {n_pos:,} ({self.pos_enc.n_planes} rotation planes Γ— {self.pos_enc.n_positions} positions)")
        print(f"  Cells ({self.depth}Γ—):  {n_cells:,}  ({n_cells // self.depth:,} per cell)")
        print(f"    D={cell0.D}, V={cell0.V}, hidden={cell0.hidden}")
        n_degen = sum(1 for c in self.cells if c.slice_d > 0)
        n_full = self.depth - n_degen
        print(f"    CM cells: {self.cm_cell_indices} ({n_cm} primary)")
        print(f"    SVD split: {n_full} full + {n_degen} sliced (D=2 Γ— {cell0.D // 2})")
        print(f"  LayerNorms:   {n_norms:,}")
        print(f"  Classifier:   {n_head:,}")
        print(f"  Cross-attn:   {n_cross:,} (clipped at 0.5)")
        print(f"  Total:        {n_params:,}")
        print(f"  Architecture: PatchEmbed β†’ CayleyPE β†’ {self.depth}Γ— SpectralCell (CM every {self.cm_every}) β†’ pool β†’ classify")