File size: 12,542 Bytes
e254270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
"""
COLM Model Components
=====================
Complex Oscillating Language Model — all neural network modules.

Components:
  - ComplexRMSNorm: magnitude normalization preserving phase
  - ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron
  - ComplexMixer: fixed unitary cross-dimension routing
  - OscillatingCausalScanner: O(N) causal sequence scanner
  - SparseGate: smooth sigmoid voltage-spike gate
  - ZeroLinearBlock: scanner + oscillating MLP block
  - COLM: full autoregressive model
"""

import math
import torch
import torch.nn as nn
from torch.nn import functional as F


# =============================================================================
# COMPLEX RMSNORM — norm the magnitude, preserve the angle
# =============================================================================

class ComplexRMSNorm(nn.Module):
    """RMSNorm for complex-valued activations.

    The RMS is taken over |Z|² = re² + im² across the feature dimension,
    so scaling by its reciprocal normalizes magnitudes while leaving every
    phase angle untouched. The learnable per-feature gain is real-valued:
    multiplying a complex number by a real scalar changes only its magnitude.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Real gain, one per feature — initialized to the identity scaling.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, Z):
        # Mean squared magnitude over the last (feature) dimension.
        power = (Z.real.square() + Z.imag.square()).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(power + self.eps)
        return Z * (inv_rms * self.weight)


# =============================================================================
# COMPLEX OSCILLATOR — sin(W⊙Z+B)·tanh(Z), W,B ∈ ℂ
# =============================================================================

def _softcap_imag(z, limit=6.0):
    return torch.complex(z.real, limit * torch.tanh(z.imag / limit))


def safe_abs(Z, eps=1e-12):
    """Gradient-safe complex magnitude.

    torch.abs() on a complex tensor is sqrt(re² + im²), whose derivative
    blows up at 0. Adding eps inside the sqrt keeps the gradient finite
    when gating zeroes out features; forward values are unchanged to
    roughly six decimal places."""
    squared_mag = Z.real * Z.real + Z.imag * Z.imag
    return torch.sqrt(squared_mag + eps)


class ComplexOscillator(nn.Module):
    """Native complex oscillating neuron: sin(W⊙Z + B) · tanh(Z).

    W packs frequency and phase into a single complex parameter
    (real part = ω, imag part = φ); B is a complex baseline.
    PyTorch supports sin() and tanh() on complex tensors natively,
    so Wirtinger derivatives flow through without extra work.
    """

    def __init__(self, dim):
        super().__init__()
        # Frequencies start near 1, phases near 0, both lightly jittered.
        freq = torch.randn(dim) * 0.1 + 1.0
        phase = torchch_phase = torch.randn(dim) * 0.1
        self.W = nn.Parameter(torch.complex(freq, phase))
        # Complex baseline initialized at the origin.
        self.B = nn.Parameter(torch.complex(torch.zeros(dim), torch.zeros(dim)))

    def forward(self, Z):
        # Z is cfloat; these elementwise ops can be fused by Inductor.
        # Cap Im(Z) just below π/2 — tanh has its first pole at iπ/2.
        Z = _softcap_imag(Z, limit=math.pi/2 - 0.2)
        # Cap Im(W·Z + B) so |sin| (which grows like e^|Im|) stays bounded.
        arg = _softcap_imag(self.W * Z + self.B, limit=6.0)
        return torch.sin(arg) * torch.tanh(Z)


# =============================================================================
# COMPLEX MIXER — fixed unitary matrix, zero learnable params
# =============================================================================

class ComplexMixer(nn.Module):
    """Zero-parameter cross-dimension routing via a frozen unitary matrix.

    A random complex matrix is QR-factorized once at construction; the Q
    factor is unitary, so applying it (or its transpose, which is also
    unitary) preserves total signal energy ‖Z‖². Stored as a buffer: it
    follows .to()/.cuda() moves but receives no gradient updates.

    NOTE: This is O(D²) per token — the FWHT was O(D log D).
    Chosen for torch.compile compatibility over raw compute efficiency.
    If compile handles FWHT well on your hardware, swap back.
    """

    def __init__(self, dim):
        super().__init__()
        # Random complex matrix → QR → unitary Q.
        rnd = torch.complex(torch.randn(dim, dim), torch.randn(dim, dim))
        unitary, _ = torch.linalg.qr(rnd)
        self.register_buffer('mix_matrix', unitary)

    def forward(self, Z):
        # (B, T, D) @ (D, D) -> (B, T, D); Q.T is unitary as well, so the
        # routing is energy-preserving.
        return Z @ self.mix_matrix.T


# =============================================================================
# O(N) COMPLEX OSCILLATOR CAUSAL SCANNER — replaces O(N²) attention
# =============================================================================

class OscillatingCausalScanner(nn.Module):
    """O(N) causal sequence mixing — drop-in for O(N²) attention.

    Two oscillators emit per-step signals from the input:
      - gate → complex decay (sigmoid of real part = retention,
        tanh-squashed imaginary part = per-step phase rotation)
      - val  → complex value stream to accumulate
    The causal recurrence h_t = g_t·h_{t-1} + v_t is evaluated in closed
    form via cumulative sums in log space, giving O(N) over the sequence.

    Mathematically adjacent to Linear Attention / State Space Models
    (Mamba, RWKV, Griffin), but powered entirely by oscillating neurons.
    """

    def __init__(self, dim, clamp=70.0):
        super().__init__()
        self.clamp = clamp
        self.osc_gate = ComplexOscillator(dim)
        self.osc_val = ComplexOscillator(dim)
        self.osc_out = ComplexOscillator(dim)

        # Shrink the gate oscillator's W at init so the first gates are gentle.
        with torch.no_grad():
            self.osc_gate.W.data = torch.complex(
                torch.empty(dim).uniform_(-0.05, 0.05),
                torch.empty(dim).uniform_(-0.05, 0.05)
            )

    def forward(self, Z):
        # Z: (B, T, D) complex
        gate = self.osc_gate(Z)
        val = self.osc_val(Z)

        decay = torch.sigmoid(gate.real)
        phase = math.pi * torch.tanh(gate.imag / math.pi)

        # Assemble log(gate) directly from (log|g|, arg g). Bypassing
        # torch.polar/.angle() avoids the atan2(0,0) NaN gradient as decay → 0.
        log_gate = torch.complex(torch.log(decay.clamp(min=1e-8)), phase)
        cum_log = torch.cumsum(log_gate, dim=1)

        # exp(+cum_log) and exp(-cum_log), each with the real part clamped
        # so neither factor under/overflows on long sequences.
        bound = self.clamp
        exp_cum = torch.exp(torch.complex(cum_log.real.clamp(min=-bound), cum_log.imag))
        exp_neg = torch.exp(torch.complex((-cum_log.real).clamp(max=bound), -cum_log.imag))

        H = exp_cum * torch.cumsum(val * exp_neg, dim=1)

        # GRADIENT ECOLOGY: softly bound |H| with tanh while preserving phase —
        # H is rescaled by tanh(|H|/8)/|H|, a smooth magnitude-only squash.
        H_mag = safe_abs(H).clamp(min=1e-8)
        H = H * (torch.tanh(H_mag / 8.0) / H_mag)
        return self.osc_out(H)


# =============================================================================
# SMOOTH SPARSE GATE — proper sigmoid
# =============================================================================

class SparseGate(nn.Module):
    """Smooth voltage-spike gate with learnable per-feature temperature.

    voltage = sigmoid(gate_w * x)
    spike = sigmoid((voltage - threshold) * temperature)
    output = x * spike

    Both stages are sigmoids, so gradients stay smooth everywhere; a high
    temperature sharpens the spike toward a hard threshold.
    """

    def __init__(self, num_features, threshold_init=0.3):
        super().__init__()
        self.gate_w = nn.Parameter(torch.full((num_features,), 0.25))
        self.threshold = nn.Parameter(torch.full((num_features,), threshold_init))
        self.temperature = nn.Parameter(torch.full((num_features,), 10.0))

    def forward(self, x):
        voltage = torch.sigmoid(self.gate_w * x)
        gate_open = torch.sigmoid(self.temperature * (voltage - self.threshold))
        return x * gate_open

    @torch.no_grad()
    def get_sparsity(self, x=None):
        # Fraction of features whose voltage exceeds threshold; 0.0 when
        # called without an input.
        if x is None:
            return 0.0
        voltage = torch.sigmoid(self.gate_w * x)
        return (voltage > self.threshold).float().mean().item()


# =============================================================================
# ZERO-LINEAR BLOCK — scanner + complex mixer/oscillator MLP
# =============================================================================

class ZeroLinearBlock(nn.Module):
    """One complete transformer-replacement layer.

    Sub-block 1: OscillatingCausalScanner (fills attention's slot).
    Sub-block 2: ComplexMixer→Oscillator→Mixer→Oscillator (fills the MLP's
    slot), followed by a voltage-spike sparsity gate and a complex sinc
    resonance coupling back into the residual stream.

    Both sub-blocks are pre-norm residual connections.
    """

    def __init__(self, layer_idx, cfg):
        super().__init__()
        dim = cfg.n_embd

        self.norm1 = ComplexRMSNorm(dim)
        self.scanner = OscillatingCausalScanner(dim, clamp=cfg.scanner_clamp)

        self.norm2 = ComplexRMSNorm(dim)
        self.mix1 = ComplexMixer(dim)
        self.osc1 = ComplexOscillator(dim)
        self.mix2 = ComplexMixer(dim)
        self.osc2 = ComplexOscillator(dim)
        self.sparse_gate = SparseGate(dim)
        # Diagnostics captured (detached) on every forward pass.
        self.last_mlp_mag = None
        self.last_gate_open = None

        # Per-layer complex residual coupling strength (imag part starts at 0).
        alpha_init = cfg.coupling_alpha_init[layer_idx]
        self.coupling_alpha = nn.Parameter(
            torch.complex(torch.tensor(alpha_init), torch.tensor(0.0))
        )
        print(f"  Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})")

    def forward(self, Z):
        # --- Sub-block 1: O(N) causal scanner (attention replacement) ---
        Z = Z + self.scanner(self.norm1(Z))

        # --- Sub-block 2: oscillating zero-linear "MLP" ---
        residual = Z
        h = self.osc2(self.mix2(self.osc1(self.mix1(self.norm2(Z)))))

        # Voltage spike gate: the spike is decided by |h| but applied to the
        # full complex signal. The gate math is inlined (instead of calling
        # self.sparse_gate(...)) so the intermediate spike can be logged.
        mag = safe_abs(h)
        self.last_mlp_mag = mag.detach()
        gate = self.sparse_gate
        voltage = torch.sigmoid(gate.gate_w * mag)
        spike = torch.sigmoid((voltage - gate.threshold) * gate.temperature)
        self.last_gate_open = spike.detach()
        h = spike * h

        # Complex sinc resonance coupling back into the residual stream.
        h = torch.sinc(safe_abs(h) / math.pi) * h
        return residual + self.coupling_alpha * h


# =============================================================================
# COLM — Complex Oscillating Language Model
# =============================================================================

class COLM(nn.Module):
    """Complex Oscillating Language Model — autoregressive LM.

    Pipeline:
      token ids → thin real embedding → bias-free linear up-projection
      → + learned position embedding → lift to complex (imag = 0)
      → initial ComplexOscillator → pre-norm
      → n_layer × ZeroLinearBlock → final norm
      → concat(real, imag) → linear head → logits
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Embedding path: narrow real embedding widened by a bias-free linear.
        self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False)
        # First oscillation, applied immediately after the complex lift.
        self.embed_osc = ComplexOscillator(cfg.n_embd)

        # Learned absolute positions, real-valued, added to the real features.
        self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd)

        self.ln_pre = ComplexRMSNorm(cfg.n_embd)
        self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)])
        self.ln_f = ComplexRMSNorm(cfg.n_embd)

        # Head consumes [real ‖ imag] so no complex information is dropped.
        self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Same normal init for every Linear and Embedding weight.
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """idx: (B, T) int64 token ids; targets: optional (B, T) labels.
        Returns (logits, loss); loss is None when targets is None."""
        batch, seq_len = idx.size()

        # Real embedding path: (B, T) → (B, T, n_embd).
        feats = self.embed_up(self.thin_embed(idx))

        # Add (real) position embeddings.
        positions = torch.arange(0, seq_len, dtype=torch.long, device=idx.device)
        feats = feats + self.position_emb(positions)

        # Lift to complex — imaginary channel starts at zero — then oscillate.
        Z = torch.complex(feats, torch.zeros_like(feats))
        Z = self.embed_osc(Z)
        Z = self.ln_pre(Z)

        for block in self.blocks:
            Z = block(Z)
        Z = self.ln_f(Z)

        # Keep both channels for the classifier head: (B, T, 2*n_embd).
        logits = self.lm_head(torch.cat([Z.real, Z.imag], dim=-1))

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(batch * seq_len, -1), targets.view(batch * seq_len))

        return logits, loss