| """ |
| COLM Model Components |
| ===================== |
| Complex Oscillating Language Model — all neural network modules. |
| |
| Components: |
| - ComplexRMSNorm: magnitude normalization preserving phase |
| - ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron |
| - ComplexMixer: fixed unitary cross-dimension routing |
| - OscillatingCausalScanner: O(N) causal sequence scanner |
| - SparseGate: smooth sigmoid voltage-spike gate |
| - ZeroLinearBlock: scanner + oscillating MLP block |
| - COLM: full autoregressive model |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
|
|
|
|
| |
| |
| |
|
|
class ComplexRMSNorm(nn.Module):
    """RMSNorm adapted for complex tensors.
    Normalizes the magnitude while preserving phase angles.
    Learnable weight is real-valued (scales magnitude)."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, Z):
        # Reciprocal RMS of the complex magnitudes over the feature dim.
        inv_rms = torch.rsqrt((Z.real.square() + Z.imag.square()).mean(-1, keepdim=True) + self.eps)
        return Z * (inv_rms * self.weight)
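
# Quick property sketch: with the default positive weights, normalization only
# rescales magnitudes, so phase angles pass through unchanged:
#
#     norm = ComplexRMSNorm(8)
#     Z = torch.randn(2, 4, 8, dtype=torch.complex64)
#     assert torch.allclose(torch.angle(norm(Z)), torch.angle(Z), atol=1e-5)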


def _softcap_imag(z, limit=6.0):
    """Smoothly bound the imaginary part to (-limit, limit) via tanh,
    leaving the real part untouched."""
    return torch.complex(z.real, limit * torch.tanh(z.imag / limit))


def safe_abs(Z, eps=1e-12):
    """Gradient-safe complex magnitude. torch.abs() on complex is sqrt(re²+im²),
    and sqrt'(0) = inf. Adding eps inside the sqrt prevents inf gradients
    when the sparse gate zeros out features. Forward values are unchanged
    to ~6 decimal places."""
    return torch.sqrt(Z.real.square() + Z.imag.square() + eps)
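
# Illustration (a sketch of the failure mode described above): backprop through
# torch.abs at an exact complex zero yields non-finite gradients, while
# safe_abs stays finite:
#
#     z = torch.zeros(4, dtype=torch.complex64, requires_grad=True)
#     torch.abs(z).sum().backward()     # z.grad is non-finite
#     z2 = torch.zeros(4, dtype=torch.complex64, requires_grad=True)
#     safe_abs(z2).sum().backward()     # z2.grad is finite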


class ComplexOscillator(nn.Module):
    """Native Complex Oscillating Neuron.
    W = ω + iφ (frequency + phase as single complex param)
    B = real_bias + i·imag_bias (complex baseline)

    PyTorch supports complex sin() and tanh() natively.
    Wirtinger derivatives flow through automatically."""

    def __init__(self, dim):
        super().__init__()
        # Frequencies init near 1, phases near 0 (small jitter breaks symmetry).
        omega = torch.randn(dim) * 0.1 + 1.0
        phi = torch.randn(dim) * 0.1
        self.W = nn.Parameter(torch.complex(omega, phi))

        # Complex baseline, initialized at zero.
        self.B = nn.Parameter(torch.complex(torch.zeros(dim), torch.zeros(dim)))

    def forward(self, Z):
        # Complex tanh has poles at ±iπ/2, so keep |Im Z| safely below π/2.
        Z = _softcap_imag(Z, limit=math.pi / 2 - 0.2)
        # |sin(z)| grows like exp(|Im z|)/2, so cap the imaginary part at 6.
        WZ = _softcap_imag(self.W * Z + self.B, limit=6.0)
        return torch.sin(WZ) * torch.tanh(Z)
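
# Shape sketch: the oscillator is purely elementwise, so its per-feature
# parameters broadcast over batch and sequence dims:
#
#     osc = ComplexOscillator(64)
#     Z = torch.randn(2, 16, 64, dtype=torch.complex64)
#     assert osc(Z).shape == Z.shape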


class ComplexMixer(nn.Module):
    """Zero-parameter cross-dimension routing via fixed unitary matrix.
    QR-orthogonalized complex matrix ensures energy preservation.

    NOTE: This is O(D²) per token; the FWHT it replaced was O(D log D).
    Chosen for torch.compile compatibility over raw compute efficiency.
    If compile handles FWHT well on your hardware, swap back."""

    def __init__(self, dim):
        super().__init__()
        # QR of a random complex matrix yields a unitary Q (Q^H Q = I).
        real_part = torch.randn(dim, dim)
        imag_part = torch.randn(dim, dim)
        complex_mat = torch.complex(real_part, imag_part)
        q, _ = torch.linalg.qr(complex_mat)
        self.register_buffer('mix_matrix', q)

    def forward(self, Z):
        # The plain (non-conjugate) transpose of a unitary matrix is also
        # unitary, so this matmul preserves energy.
        return Z @ self.mix_matrix.T
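
# For reference, a minimal FWHT sketch (the O(D log D) alternative mentioned
# in the docstring; not wired in). The normalized Walsh-Hadamard transform is
# orthogonal, so it also preserves energy; it requires a power-of-two dim:
#
#     def fwht(x):
#         d = x.shape[-1]
#         assert d & (d - 1) == 0, "FWHT needs a power-of-two dim"
#         h = 1
#         while h < d:
#             y = x.reshape(*x.shape[:-1], d // (2 * h), 2, h)
#             a, b = y[..., 0, :], y[..., 1, :]
#             x = torch.stack((a + b, a - b), dim=-2).reshape(*x.shape[:-1], d)
#             h *= 2
#         return x / math.sqrt(d)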


class OscillatingCausalScanner(nn.Module):
    """O(N) sequence routing replacing scaled_dot_product_attention.

    Uses ComplexOscillator to generate:
    - gate: complex decay (magnitude=retention, angle=phase rotation)
    - val: complex value signal
    Then accumulates causally across sequence length T in O(N) time.

    This is mathematically related to Linear Attention / State Space Models
    (Mamba, RWKV, Griffin) but powered entirely by oscillating neurons."""

    def __init__(self, dim, clamp=70.0):
        super().__init__()
        self.clamp = clamp
        self.osc_gate = ComplexOscillator(dim)
        self.osc_val = ComplexOscillator(dim)
        self.osc_out = ComplexOscillator(dim)

        # Start the gate oscillator near zero so gate outputs are small and
        # the initial decays sit near sigmoid(0) = 0.5.
        with torch.no_grad():
            self.osc_gate.W.data = torch.complex(
                torch.empty(dim).uniform_(-0.05, 0.05),
                torch.empty(dim).uniform_(-0.05, 0.05)
            )

    def forward(self, Z):
        gate = self.osc_gate(Z)
        val = self.osc_val(Z)

        # Retention in (0, 1) from the real part; phase softcapped to (-π, π).
        decay = torch.sigmoid(gate.real)
        phase = math.pi * torch.tanh(gate.imag / math.pi)

        # Work in the log domain: log g_t = log|g_t| + i·phase_t, so the
        # cumulative product of gates becomes a cumulative sum.
        log_gate = torch.complex(torch.log(decay.clamp(min=1e-8)), phase)
        cum_log = torch.cumsum(log_gate, dim=1)

        # exp_cum[t] = prod_{s<=t} g_s, with the real part clamped so the
        # exponential cannot underflow to hard zero.
        CLAMP = self.clamp
        exp_real = cum_log.real.clamp(min=-CLAMP)
        exp_cum = torch.exp(torch.complex(exp_real, cum_log.imag))

        # exp_neg[t] = 1 / prod_{s<=t} g_s, clamped symmetrically against overflow.
        neg_real = (-cum_log.real).clamp(max=CLAMP)
        exp_neg = torch.exp(torch.complex(neg_real, -cum_log.imag))

        # H[t] = sum_{s<=t} val[s] * prod_{r=s+1..t} g_r, i.e. the linear
        # recurrence H_t = g_t * H_{t-1} + val_t, computed with two cumsums.
        H = exp_cum * torch.cumsum(val * exp_neg, dim=1)

        # Soft magnitude squash: |H| is mapped to tanh(|H|/8), bounding the
        # magnitude below 1 while preserving phase.
        H_mag = safe_abs(H).clamp(min=1e-8)
        H = H * (torch.tanh(H_mag / 8.0) / H_mag)
        return self.osc_out(H)
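
# Equivalence sketch (for testing, using the tensor names from forward above):
# the cumsum form matches the naive causal recurrence H_t = g_t * H_{t-1} + val_t
# with g_t = decay_t * exp(i * phase_t), up to the stability clamps:
#
#     g = torch.exp(log_gate)
#     h = torch.zeros_like(val[:, 0])
#     for t in range(val.shape[1]):
#         h = g[:, t] * h + val[:, t]
#     # h now approximately equals H[:, -1] before the magnitude squash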


class SparseGate(nn.Module):
    """Decoupled spike gate with learnable temperature.
    Uses smooth sigmoid for clean gradients.

    voltage = sigmoid(gate_w * x)
    spike = sigmoid((voltage - threshold) * temperature)
    output = x * spike
    """

    def __init__(self, num_features, threshold_init=0.3):
        super().__init__()
        self.gate_w = nn.Parameter(torch.ones(num_features) * 0.25)
        self.threshold = nn.Parameter(torch.full((num_features,), threshold_init))
        self.temperature = nn.Parameter(torch.ones(num_features) * 10.0)

    def forward(self, x):
        voltage = torch.sigmoid(self.gate_w * x)
        spike = torch.sigmoid((voltage - self.threshold) * self.temperature)
        return x * spike

    @torch.no_grad()
    def get_sparsity(self, x=None):
        """Fraction of gates open (voltage above threshold) for input x."""
        if x is None:
            return 0.0
        voltage = torch.sigmoid(self.gate_w * x)
        return (voltage > self.threshold).float().mean().item()
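
# Usage sketch: the gate is elementwise over the feature dim and broadcasts
# over leading dims:
#
#     gate = SparseGate(64)
#     x = torch.randn(2, 16, 64)
#     y = gate(x)                       # same shape as x
#     frac_open = gate.get_sparsity(x)  # fraction of gates above threshold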


class ZeroLinearBlock(nn.Module):
    """Complete transformer-replacement block.

    Sub-block 1: OscillatingCausalScanner (replaces attention)
    Sub-block 2: ComplexMixer→Oscillator→Mixer→Oscillator (replaces MLP)

    Both sub-blocks use pre-norm residual connections.
    Complex sinc resonance coupling at the end."""

    def __init__(self, layer_idx, cfg):
        super().__init__()
        dim = cfg.n_embd

        self.norm1 = ComplexRMSNorm(dim)
        self.scanner = OscillatingCausalScanner(dim, clamp=cfg.scanner_clamp)

        self.norm2 = ComplexRMSNorm(dim)
        self.mix1 = ComplexMixer(dim)
        self.osc1 = ComplexOscillator(dim)
        self.mix2 = ComplexMixer(dim)
        self.osc2 = ComplexOscillator(dim)
        self.sparse_gate = SparseGate(dim)
        self.last_mlp_mag = None
        self.last_gate_open = None

        # Per-layer complex residual coupling strength.
        alpha_init = cfg.coupling_alpha_init[layer_idx]
        self.coupling_alpha = nn.Parameter(
            torch.complex(torch.tensor(alpha_init), torch.tensor(0.0))
        )
        print(f" Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})")

    def forward(self, Z):
        # Sub-block 1: pre-norm residual around the causal scanner.
        Z_res = Z
        Z_normed = self.norm1(Z)
        Z = Z_res + self.scanner(Z_normed)

        # Sub-block 2: pre-norm oscillating MLP (mix → osc → mix → osc).
        Z_res = Z
        Z_normed = self.norm2(Z)
        Z_mlp = self.mix1(Z_normed)
        Z_mlp = self.osc1(Z_mlp)
        Z_mlp = self.mix2(Z_mlp)
        Z_mlp = self.osc2(Z_mlp)

        # Sparse gate, inlined from SparseGate so the gate is driven by the
        # complex magnitude |Z_mlp| while the spike scales the full complex value.
        mag = safe_abs(Z_mlp)
        self.last_mlp_mag = mag.detach()
        sg = self.sparse_gate
        voltage = torch.sigmoid(sg.gate_w * mag)
        spike = torch.sigmoid((voltage - sg.threshold) * sg.temperature)
        self.last_gate_open = spike.detach()
        Z_mlp = spike * Z_mlp

        # Sinc resonance coupling: torch.sinc(m/π) = sin(m)/m, so each feature
        # is scaled by sin(|Z_mlp|)/|Z_mlp|.
        mag = safe_abs(Z_mlp)
        sinc_coupling = torch.sinc(mag / math.pi) * Z_mlp

        Z = Z_res + self.coupling_alpha * sinc_coupling
        return Z


class COLM(nn.Module):
    """Complex Oscillating Language Model.

    Architecture:
    - Real embedding → linear projection → complex conversion
    - ComplexOscillator initial oscillation
    - N × ZeroLinearBlock (scanner + oscillating MLP)
    - Complex → real concatenation → linear head
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Factorized embedding: thin real embedding projected up to n_embd.
        self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False)
        self.embed_osc = ComplexOscillator(cfg.n_embd)

        # Learned absolute position embedding.
        self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd)

        self.ln_pre = ComplexRMSNorm(cfg.n_embd)
        self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)])
        self.ln_f = ComplexRMSNorm(cfg.n_embd)

        # Head reads the concatenated real and imaginary parts.
        self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, Tseq = idx.size()

        # Real token embedding, projected up.
        x_real = self.embed_up(self.thin_embed(idx))

        # Add learned position embeddings.
        pos = torch.arange(0, Tseq, dtype=torch.long, device=idx.device)
        x_real = x_real + self.position_emb(pos)

        # Lift to complex (zero imaginary part), then apply the initial oscillation.
        Z = torch.complex(x_real, torch.zeros_like(x_real))
        Z = self.embed_osc(Z)

        Z = self.ln_pre(Z)

        for block in self.blocks:
            Z = block(Z)

        Z = self.ln_f(Z)

        # Back to real: concatenate real and imaginary parts for the head.
        x_out = torch.cat([Z.real, Z.imag], dim=-1)
        logits = self.lm_head(x_out)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * Tseq, -1), targets.view(B * Tseq))

        return logits, loss
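
# Smoke-test sketch (`Cfg` below is a hypothetical stand-in for the project's
# real config object; it carries only the attributes this module reads):
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class Cfg:
#         vocab_size: int = 256
#         embed_dim: int = 32
#         n_embd: int = 64
#         block_size: int = 128
#         n_layer: int = 2
#         scanner_clamp: float = 70.0
#         coupling_alpha_init: tuple = (0.1, 0.1)
#
#     model = COLM(Cfg())
#     idx = torch.randint(0, 256, (2, 16))
#     logits, loss = model(idx, targets=idx)
#     assert logits.shape == (2, 16, 256)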