AbstractPhil commited on
Commit
1b902bc
·
verified ·
1 Parent(s): 5522bcd

Create model_v1.py

Browse files
Files changed (1) hide show
  1. model_v1.py +218 -0
model_v1.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SpectralViT — Pure SpectralCell Transformer
3
+ =============================================
4
+ No conv backbone. No external attention. Stacked SpectralCells are the model.
5
+
6
+ Architecture:
7
+ Image → PatchEmbed(4×4) → 64 tokens × embed_dim
8
+ → Cayley hypersphere positional encoding (multi-plane rotations on S^{d-1})
9
+ → SpectralCell × depth
10
+ → LayerNorm → mean pool → classify
11
+
12
+ Positional encoding on the hypersphere:
13
+ Each position has K learnable rotation angles in K fixed 2D planes.
14
+ Rotation in plane (2k, 2k+1) by angle θ:
15
+ x[2k] = cos(θ) · x[2k] - sin(θ) · x[2k+1]
16
+ x[2k+1] = sin(θ) · x[2k] + cos(θ) · x[2k+1]
17
+ Composing K plane rotations = rich orthogonal rotation.
18
+ Preserves norms. Operates naturally on S^{d-1}.
19
+ Learnable angles, not fixed sinusoidal.
20
+
21
+ SpectralCell and cv_of are in namespace from prior cell execution.
22
+ """
23
+
24
+ import math
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+
29
+
30
+ # ── Cayley Hypersphere Positional Encoding ───────────────────────
31
+
32
+ class CayleyPositionalEncoding(nn.Module):
33
+ """Multi-plane rotation positional encoding on the hypersphere.
34
+
35
+ Each position gets K learnable rotation angles applied in K paired
36
+ dimension planes. Composing K Givens rotations produces a rich
37
+ orthogonal transformation that preserves embedding norm.
38
+
39
+ For embed_dim=256: K=128 planes, each position has 128 angles.
40
+ 64 positions × 128 angles = 8,192 learnable parameters.
41
+
42
+ This is geometrically natural — the SpectralCell projects onto S^{D-1},
43
+ and Cayley rotations are the native transformations of the hypersphere.
44
+ """
45
+ def __init__(self, n_positions, embed_dim):
46
+ super().__init__()
47
+ assert embed_dim % 2 == 0, "embed_dim must be even for paired rotations"
48
+ self.n_positions = n_positions
49
+ self.embed_dim = embed_dim
50
+ self.n_planes = embed_dim // 2
51
+
52
+ # Learnable rotation angles: (n_positions, n_planes)
53
+ # Initialize small — near-identity rotation at start
54
+ self.angles = nn.Parameter(torch.randn(n_positions, self.n_planes) * 0.02)
55
+
56
+ def forward(self, x):
57
+ """x: (B, N, D) → (B, N, D) with position-dependent rotation."""
58
+ B, N, D = x.shape
59
+ angles = self.angles[:N] # (N, K)
60
+
61
+ cos_a = angles.cos() # (N, K)
62
+ sin_a = angles.sin() # (N, K)
63
+
64
+ # Split into even/odd dimension pairs
65
+ x_even = x[:, :, 0::2] # (B, N, K)
66
+ x_odd = x[:, :, 1::2] # (B, N, K)
67
+
68
+ # Givens rotation per plane per position
69
+ x_rot_even = cos_a.unsqueeze(0) * x_even - sin_a.unsqueeze(0) * x_odd
70
+ x_rot_odd = sin_a.unsqueeze(0) * x_even + cos_a.unsqueeze(0) * x_odd
71
+
72
+ # Interleave back
73
+ out = torch.stack([x_rot_even, x_rot_odd], dim=-1) # (B, N, K, 2)
74
+ return out.reshape(B, N, D)
75
+
76
+
77
+ # ── Patch Embedding ──────────────────────────────────────────────
78
+
79
+ class PatchEmbed(nn.Module):
80
+ """Image → patches → linear projection.
81
+ 32×32 with patch_size=4 → 8×8 = 64 tokens.
82
+ """
83
+ def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256):
84
+ super().__init__()
85
+ self.n_patches = (img_size // patch_size) ** 2
86
+ self.proj = nn.Conv2d(in_channels, embed_dim,
87
+ kernel_size=patch_size, stride=patch_size)
88
+
89
+ def forward(self, x):
90
+ # x: (B, 3, H, W) → (B, embed_dim, H/ps, W/ps) → (B, N, embed_dim)
91
+ return self.proj(x).flatten(2).transpose(1, 2)
92
+
93
+
94
+ # ── SpectralViT ─────────────────────────────────────────────────
95
+
96
+ class SpectralViT(nn.Module):
97
+ """Pure SpectralCell vision transformer.
98
+
99
+ No conv backbone. No external attention.
100
+ Stacked SpectralCells with Cayley hypersphere positional encoding.
101
+
102
+ Args:
103
+ img_size: input image size (32 for CIFAR)
104
+ patch_size: patch size (4 → 64 tokens)
105
+ in_channels: input channels (3)
106
+ embed_dim: token embedding dimension
107
+ depth: number of SpectralCell blocks
108
+ cell_V: V parameter for SpectralCell
109
+ cell_D: D parameter for SpectralCell
110
+ cell_hidden: hidden dimension inside each cell
111
+ cell_depth: residual MLP depth inside each cell
112
+ n_cross: cross-attention layers per cell
113
+ n_heads: attention heads in cell cross-attention
114
+ n_classes: classification output
115
+ dropout: classifier dropout
116
+ """
117
+ def __init__(
118
+ self,
119
+ img_size=32,
120
+ patch_size=4,
121
+ in_channels=3,
122
+ embed_dim=256,
123
+ depth=6,
124
+ cell_V=16,
125
+ cell_D=16,
126
+ cell_hidden=256,
127
+ cell_depth=2,
128
+ n_cross=2,
129
+ n_heads=4,
130
+ n_classes=100,
131
+ dropout=0.1,
132
+ ):
133
+ super().__init__()
134
+ self.embed_dim = embed_dim
135
+ self.depth = depth
136
+ n_patches = (img_size // patch_size) ** 2
137
+
138
+ # Patch embedding
139
+ self.patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
140
+
141
+ # Cayley hypersphere positional encoding
142
+ self.pos_enc = CayleyPositionalEncoding(n_patches, embed_dim)
143
+
144
+ # Stacked SpectralCells — the entire backbone
145
+ self.cells = nn.ModuleList([
146
+ SpectralCell(
147
+ token_dim=embed_dim, V=cell_V, D=cell_D,
148
+ hidden=cell_hidden, depth=cell_depth,
149
+ n_cross=n_cross, n_heads=n_heads,
150
+ max_alpha=0.2,
151
+ ) for _ in range(depth)
152
+ ])
153
+
154
+ # Pre-norm before each cell
155
+ self.norms = nn.ModuleList([
156
+ nn.LayerNorm(embed_dim) for _ in range(depth)
157
+ ])
158
+
159
+ # Final norm + classifier
160
+ self.final_norm = nn.LayerNorm(embed_dim)
161
+ self.classifier = nn.Sequential(
162
+ nn.Linear(embed_dim, embed_dim),
163
+ nn.GELU(),
164
+ nn.Dropout(dropout),
165
+ nn.Linear(embed_dim, n_classes),
166
+ )
167
+
168
+ def forward(self, x):
169
+ """x: (B, 3, H, W) → dict with logits and cell outputs."""
170
+ # Patch embed → positional encoding
171
+ tokens = self.patch_embed(x) # (B, N, embed_dim)
172
+ tokens = self.pos_enc(tokens) # rotated on hypersphere
173
+
174
+ # Stacked SpectralCells with residual connections
175
+ cell_outputs = []
176
+ for i, (cell, norm) in enumerate(zip(self.cells, self.norms)):
177
+ normed = norm(tokens)
178
+ cell_out = cell.format(normed)
179
+ tokens = tokens + cell_out['output'] # residual
180
+ cell_outputs.append(cell_out)
181
+
182
+ # Pool + classify
183
+ tokens = self.final_norm(tokens)
184
+ pooled = tokens.mean(dim=1) # (B, embed_dim)
185
+ logits = self.classifier(pooled)
186
+
187
+ return {
188
+ 'logits': logits,
189
+ 'cell_outputs': cell_outputs,
190
+ 'tokens': tokens,
191
+ }
192
+
193
+ def get_cross_attn_params(self):
194
+ """Cross-attention params for separate grad clipping."""
195
+ params = []
196
+ for name, p in self.named_parameters():
197
+ if 'cross_attn' in name:
198
+ params.append(p)
199
+ return params
200
+
201
+ def summary(self):
202
+ n_params = sum(p.numel() for p in self.parameters())
203
+ n_embed = sum(p.numel() for p in self.patch_embed.parameters())
204
+ n_pos = sum(p.numel() for p in self.pos_enc.parameters())
205
+ n_cells = sum(p.numel() for p in self.cells.parameters())
206
+ n_norms = sum(p.numel() for p in self.norms.parameters()) + sum(p.numel() for p in self.final_norm.parameters())
207
+ n_head = sum(p.numel() for p in self.classifier.parameters())
208
+ n_cross = sum(p.numel() for p in self.get_cross_attn_params())
209
+
210
+ print(f"SpectralViT:")
211
+ print(f" Patch embed: {n_embed:,}")
212
+ print(f" Cayley PE: {n_pos:,} ({self.pos_enc.n_planes} rotation planes × {self.pos_enc.n_positions} positions)")
213
+ print(f" Cells ({self.depth}×): {n_cells:,} ({n_cells // self.depth:,} per cell)")
214
+ print(f" LayerNorms: {n_norms:,}")
215
+ print(f" Classifier: {n_head:,}")
216
+ print(f" Cross-attn: {n_cross:,} (clipped at 0.5)")
217
+ print(f" Total: {n_params:,}")
218
+ print(f" Architecture: PatchEmbed(4×4) → CayleyPE → {self.depth}× SpectralCell → pool → classify")