edeneldith committed on
Commit
e254270
·
verified ·
1 Parent(s): 8f1a194

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +337 -0
model.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ COLM Model Components
3
+ =====================
4
+ Complex Oscillating Language Model — all neural network modules.
5
+
6
+ Components:
7
+ - ComplexRMSNorm: magnitude normalization preserving phase
8
+ - ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron
9
+ - ComplexMixer: fixed unitary cross-dimension routing
10
+ - OscillatingCausalScanner: O(N) causal sequence scanner
11
+ - SparseGate: smooth sigmoid voltage-spike gate
12
+ - ZeroLinearBlock: scanner + oscillating MLP block
13
+ - COLM: full autoregressive model
14
+ """
15
+
16
+ import math
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+
21
+
22
+ # =============================================================================
23
+ # COMPLEX RMSNORM — norm the magnitude, preserve the angle
24
+ # =============================================================================
25
+
26
class ComplexRMSNorm(nn.Module):
    """RMS normalization for complex-valued tensors.

    Each element is multiplied by a real-valued factor: the reciprocal
    root-mean-square magnitude taken over the last dimension, times a
    learnable real per-feature weight. Because the factor is real, only
    magnitudes are rescaled; phase angles are preserved (up to a sign flip
    should a learned weight become negative).

    Args:
        dim: size of the last (feature) dimension.
        eps: constant added to the mean squared magnitude before the
            reciprocal square root, for numerical stability.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, Z):
        # |z|^2 computed from the real/imag parts directly.
        power = Z.real.square() + Z.imag.square()
        # Reciprocal RMS over the feature axis, kept broadcastable against Z.
        inv_rms = torch.rsqrt(power.mean(-1, keepdim=True) + self.eps)
        return Z * (inv_rms * self.weight)
39
+
40
+
41
+ # =============================================================================
42
+ # COMPLEX OSCILLATOR — sin(W⊙Z+B)·tanh(Z), W,B ∈ ℂ
43
+ # =============================================================================
44
+
45
+ def _softcap_imag(z, limit=6.0):
46
+ return torch.complex(z.real, limit * torch.tanh(z.imag / limit))
47
+
48
+
49
def safe_abs(Z, eps=1e-12):
    """Complex magnitude with a finite gradient at zero.

    ``torch.abs`` on a complex tensor is ``sqrt(re² + im²)``, and the
    derivative of ``sqrt`` at 0 is infinite. Adding ``eps`` under the square
    root keeps the gradient finite when features are exactly zero (e.g. after
    the sparse gate closes), while changing forward values only negligibly.
    """
    squared_mag = Z.real.square() + Z.imag.square()
    return torch.sqrt(squared_mag + eps)
55
+
56
+
57
class ComplexOscillator(nn.Module):
    """Element-wise complex oscillating neuron: ``sin(W ⊙ Z + B) · tanh(Z)``.

    Parameters (both complex, one entry per feature):
        W: real part acts as a frequency (ω), imaginary part as a phase (φ).
           Initialised to roughly ``1 + 0i`` with std-0.1 Gaussian noise on
           both parts.
        B: complex bias, initialised to zero.

    ``torch.sin`` and ``torch.tanh`` accept complex tensors, so autograd
    differentiates the whole expression without special handling.
    """

    def __init__(self, dim):
        super().__init__()
        # Draw order (frequency first, then phase) is kept so that seeded
        # initialisation stays reproducible.
        freq = torch.randn(dim) * 0.1 + 1.0
        phase = torch.randn(dim) * 0.1
        self.W = nn.Parameter(torch.complex(freq, phase))

        # Complex baseline, starts at 0 + 0i.
        zero_re = torch.zeros(dim)
        zero_im = torch.zeros(dim)
        self.B = nn.Parameter(torch.complex(zero_re, zero_im))

    def forward(self, Z):
        # Z is expected to be a complex tensor whose last dim matches `dim`.
        # tanh has its first pole at imag = π/2; keep the imaginary part
        # safely below it.
        envelope_in = _softcap_imag(Z, limit=math.pi / 2 - 0.2)
        # |sin(a + ib)| grows like cosh(b), so bound the imaginary part of
        # the sine argument as well.
        carrier_in = _softcap_imag(self.W * envelope_in + self.B, limit=6.0)
        return torch.sin(carrier_in) * torch.tanh(envelope_in)
80
+
81
+
82
+ # =============================================================================
83
+ # COMPLEX MIXER — fixed unitary matrix, zero learnable params
84
+ # =============================================================================
85
+
86
class ComplexMixer(nn.Module):
    """Parameter-free mixing across the feature dimension.

    A random complex matrix is QR-decomposed once at construction and the
    unitary factor Q is stored as a (non-trainable) buffer. Multiplying by a
    unitary matrix preserves the squared-magnitude sum of each feature vector.

    NOTE: this is O(D²) per token, whereas the FWHT it replaced was
    O(D log D). The dense matrix was chosen for torch.compile compatibility
    rather than raw compute efficiency; swap back if compile handles FWHT
    well on your hardware.
    """

    def __init__(self, dim):
        super().__init__()
        # Real part is drawn before the imaginary part (keeps seeded init
        # reproducible).
        re = torch.randn(dim, dim)
        im = torch.randn(dim, dim)
        unitary, _upper = torch.linalg.qr(torch.complex(re, im))
        self.register_buffer('mix_matrix', unitary)

    def forward(self, Z):
        # (..., D) @ (D, D) -> (..., D); plain transpose (no conjugate) of Q,
        # which is itself unitary.
        return torch.matmul(Z, self.mix_matrix.t())
106
+
107
+
108
+ # =============================================================================
109
+ # O(N) COMPLEX OSCILLATOR CAUSAL SCANNER — replaces O(N²) attention
110
+ # =============================================================================
111
+
112
class OscillatingCausalScanner(nn.Module):
    """Causal sequence mixer built from cumulative sums (linear in T).

    Two ComplexOscillators turn the input into
      - a gate whose real part sets per-step retention (via sigmoid) and
        whose imaginary part sets a per-step phase rotation, and
      - a complex value signal.
    The output at step t combines the values of steps <= t, each weighted by
    the product of the gates that follow it. Ignoring the clamps, this equals
    the recurrence ``H_t = g_t * H_{t-1} + v_t``, computed in closed form as
    ``P_t * cumsum(v / P)`` with ``P_t = prod_{k<=t} g_k``.

    Related in spirit to linear attention / state-space models
    (Mamba, RWKV, Griffin), but driven entirely by oscillating neurons.

    Args:
        dim: feature dimension.
        clamp: bound applied to the real part of the cumulative log-gate
            before exponentiation, to limit under/overflow.
    """

    def __init__(self, dim, clamp=70.0):
        super().__init__()
        self.clamp = clamp
        self.osc_gate = ComplexOscillator(dim)
        self.osc_val = ComplexOscillator(dim)
        self.osc_out = ComplexOscillator(dim)

        # Shrink the gate oscillator's W so the initial gates are mild.
        # Real part is drawn first, then imaginary.
        with torch.no_grad():
            small_re = torch.empty(dim).uniform_(-0.05, 0.05)
            small_im = torch.empty(dim).uniform_(-0.05, 0.05)
            self.osc_gate.W.data = torch.complex(small_re, small_im)

    def forward(self, Z):
        # Z: complex, scanned along dim=1 (expected layout (B, T, D)).
        gate = self.osc_gate(Z)
        val = self.osc_val(Z)

        retention = torch.sigmoid(gate.real)                   # in (0, 1)
        rotation = math.pi * torch.tanh(gate.imag / math.pi)   # in (-π, π)

        # Assemble log(gate) from its parts instead of torch.polar/.angle():
        # avoids the atan2(0, 0) NaN gradient when retention → 0.
        log_gate = torch.complex(torch.log(retention.clamp(min=1e-8)), rotation)

        running_log = torch.cumsum(log_gate, dim=1)
        log_mag = running_log.real
        angle = running_log.imag
        bound = self.clamp

        # P_t and 1 / P_t, with the log-magnitude clamped before exp.
        prefix = torch.exp(torch.complex(log_mag.clamp(min=-bound), angle))
        inv_prefix = torch.exp(torch.complex((-log_mag).clamp(max=bound), -angle))

        H = prefix * torch.cumsum(val * inv_prefix, dim=1)

        # Soft magnitude squash: |H| -> tanh(|H| / 8), phase untouched.
        H_mag = safe_abs(H).clamp(min=1e-8)
        H = H * (torch.tanh(H_mag / 8.0) / H_mag)
        return self.osc_out(H)
164
+
165
+
166
+ # =============================================================================
167
+ # SMOOTH SPARSE GATE — proper sigmoid
168
+ # =============================================================================
169
+
170
class SparseGate(nn.Module):
    """Smooth per-feature spike gate with learnable threshold and temperature.

        voltage = sigmoid(gate_w * x)
        spike   = sigmoid((voltage - threshold) * temperature)
        output  = x * spike

    Both stages are sigmoids, so the gate is differentiable everywhere.

    Args:
        num_features: size of the feature dimension the parameters broadcast
            over.
        threshold_init: initial value of every threshold entry.
    """

    def __init__(self, num_features, threshold_init=0.3):
        super().__init__()
        shape = (num_features,)
        self.gate_w = nn.Parameter(torch.full(shape, 0.25))
        self.threshold = nn.Parameter(torch.full(shape, threshold_init))
        self.temperature = nn.Parameter(torch.full(shape, 10.0))

    def forward(self, x):
        voltage = torch.sigmoid(self.gate_w * x)
        return x * torch.sigmoid((voltage - self.threshold) * self.temperature)

    @torch.no_grad()
    def get_sparsity(self, x=None):
        """Return the fraction of entries whose voltage exceeds the threshold.

        NOTE: despite the name, this is the share of gates that are *above*
        threshold (a hard cut, not the soft spike). Returns 0.0 when no input
        is supplied.
        """
        if x is not None:
            above = torch.sigmoid(self.gate_w * x) > self.threshold
            return above.float().mean().item()
        return 0.0
196
+
197
+
198
+ # =============================================================================
199
+ # ZERO-LINEAR BLOCK — scanner + complex mixer/oscillator MLP
200
+ # =============================================================================
201
+
202
class ZeroLinearBlock(nn.Module):
    """One residual block of the model (stands in for a transformer layer).

    Sub-block 1: ComplexRMSNorm -> OscillatingCausalScanner, added back to
                 the input (takes the place of attention).
    Sub-block 2: ComplexRMSNorm -> Mixer -> Oscillator -> Mixer -> Oscillator
                 -> sparse spike gate -> sinc coupling, scaled by a learnable
                 complex ``coupling_alpha`` and added back (takes the place
                 of the MLP).

    Reads ``cfg.n_embd``, ``cfg.scanner_clamp`` and
    ``cfg.coupling_alpha_init[layer_idx]``. After every forward pass,
    ``last_mlp_mag`` and ``last_gate_open`` hold detached tensors (pre-gate
    magnitude and gate activation) for logging.
    """

    def __init__(self, layer_idx, cfg):
        super().__init__()
        width = cfg.n_embd

        # Submodules are created in a fixed order so seeded init and
        # parameter ordering stay stable.
        self.norm1 = ComplexRMSNorm(width)
        self.scanner = OscillatingCausalScanner(width, clamp=cfg.scanner_clamp)

        self.norm2 = ComplexRMSNorm(width)
        self.mix1 = ComplexMixer(width)
        self.osc1 = ComplexOscillator(width)
        self.mix2 = ComplexMixer(width)
        self.osc2 = ComplexOscillator(width)
        self.sparse_gate = SparseGate(width)

        # Diagnostics, overwritten on each forward call.
        self.last_mlp_mag = None
        self.last_gate_open = None

        alpha_init = cfg.coupling_alpha_init[layer_idx]
        alpha_re = torch.tensor(alpha_init)
        alpha_im = torch.tensor(0.0)
        self.coupling_alpha = nn.Parameter(torch.complex(alpha_re, alpha_im))
        print(f" Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})")

    def forward(self, Z):
        # --- Sub-block 1: pre-norm causal scanner with residual ---
        Z = Z + self.scanner(self.norm1(Z))

        # --- Sub-block 2: pre-norm oscillating "MLP" ---
        branch = self.norm2(Z)
        for stage in (self.mix1, self.osc1, self.mix2, self.osc2):
            branch = stage(branch)

        # Spike gate driven by the magnitude, applied to the full complex
        # value. Computed inline (rather than via SparseGate.forward) so the
        # spike itself can be recorded.
        gate = self.sparse_gate
        branch_mag = safe_abs(branch)
        self.last_mlp_mag = branch_mag.detach()
        voltage = torch.sigmoid(gate.gate_w * branch_mag)
        spike = torch.sigmoid((voltage - gate.threshold) * gate.temperature)
        self.last_gate_open = spike.detach()
        branch = spike * branch

        # Sinc coupling: torch.sinc(m / π) == sin(m) / m, a real factor that
        # depends only on the (post-gate) magnitude.
        gated_mag = safe_abs(branch)
        coupled = torch.sinc(gated_mag / math.pi) * branch

        return Z + self.coupling_alpha * coupled
264
+
265
+
266
+ # =============================================================================
267
+ # COLM — Complex Oscillating Language Model
268
+ # =============================================================================
269
+
270
class COLM(nn.Module):
    """Complex Oscillating Language Model (autoregressive).

    Pipeline:
      1. Token ids -> thin real embedding -> bias-free linear up-projection,
         plus a learned real position embedding.
      2. Promotion to complex (imaginary part zero) followed by an initial
         ComplexOscillator and a ComplexRMSNorm.
      3. ``cfg.n_layer`` ZeroLinearBlocks (scanner + oscillating MLP).
      4. Final ComplexRMSNorm, then real and imaginary parts are
         concatenated and fed to a bias-free linear head.

    Reads from ``cfg``: ``vocab_size``, ``embed_dim``, ``n_embd``,
    ``block_size``, ``n_layer`` (the blocks read further fields).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Token path: low-dimensional embedding, projected up to model width.
        self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False)
        # Initial oscillation; forward() applies it AFTER the real features
        # have been promoted to a complex tensor.
        self.embed_osc = ComplexOscillator(cfg.n_embd)

        # Position embedding (real-valued, added to the real features).
        self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd)

        self.ln_pre = ComplexRMSNorm(cfg.n_embd)
        self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)])
        self.ln_f = ComplexRMSNorm(cfg.n_embd)

        # Output head sees real and imaginary channels side by side, so no
        # complex information is discarded.
        self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise Linear / Embedding weights to N(0, 0.02²)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer class ids with B*T elements; when given,
                the mean cross-entropy over all positions is returned.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size)
            and ``loss`` is ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: if T exceeds the number of learned positions.
        """
        B, Tseq = idx.size()

        # Fail early with a readable message instead of an out-of-range
        # embedding lookup (which surfaces as a device-side assert on CUDA).
        max_positions = self.position_emb.num_embeddings
        if Tseq > max_positions:
            raise ValueError(
                f"Sequence length {Tseq} exceeds block_size {max_positions}"
            )

        # Real embedding path: (B, T, n_embd).
        x_real = self.embed_up(self.thin_embed(idx))

        # Add learned position embeddings (broadcast over the batch).
        pos = torch.arange(0, Tseq, dtype=torch.long, device=idx.device)
        x_real = x_real + self.position_emb(pos)

        # Promote to complex: real part = features, imaginary part = 0.
        Z = torch.complex(x_real, torch.zeros_like(x_real))

        # Initial complex oscillation, then normalise.
        Z = self.embed_osc(Z)
        Z = self.ln_pre(Z)

        for block in self.blocks:
            Z = block(Z)

        Z = self.ln_f(Z)

        # Keep both real and imaginary channels for the classifier head.
        x_out = torch.cat([Z.real, Z.imag], dim=-1)  # (B, T, 2 * n_embd)
        logits = self.lm_head(x_out)

        loss = None
        if targets is not None:
            # reshape (not view): targets sliced from a larger buffer are
            # often non-contiguous, and view() would raise on them.
            loss = F.cross_entropy(
                logits.reshape(B * Tseq, -1),
                targets.reshape(B * Tseq),
            )

        return logits, loss