""" COLM Model Components ===================== Complex Oscillating Language Model — all neural network modules. Components: - ComplexRMSNorm: magnitude normalization preserving phase - ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron - ComplexMixer: fixed unitary cross-dimension routing - OscillatingCausalScanner: O(N) causal sequence scanner - SparseGate: smooth sigmoid voltage-spike gate - ZeroLinearBlock: scanner + oscillating MLP block - COLM: full autoregressive model """ import math import torch import torch.nn as nn from torch.nn import functional as F # ============================================================================= # COMPLEX RMSNORM — norm the magnitude, preserve the angle # ============================================================================= class ComplexRMSNorm(nn.Module): """RMSNorm adapted for complex tensors. Normalizes the magnitude while preserving phase angles. Learnable weight is real-valued (scales magnitude).""" def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, Z): rms = torch.rsqrt((Z.real.square() + Z.imag.square()).mean(-1, keepdim=True) + self.eps) return Z * (rms * self.weight) # ============================================================================= # COMPLEX OSCILLATOR — sin(W⊙Z+B)·tanh(Z), W,B ∈ ℂ # ============================================================================= def _softcap_imag(z, limit=6.0): return torch.complex(z.real, limit * torch.tanh(z.imag / limit)) def safe_abs(Z, eps=1e-12): """Gradient-safe complex magnitude. torch.abs() on complex is sqrt(re²+im²), and sqrt'(0) = inf. Adding eps inside the sqrt prevents inf gradients when the sparse gate zeros out features. Forward values are unchanged to ~6 decimal places.""" return torch.sqrt(Z.real.square() + Z.imag.square() + eps) class ComplexOscillator(nn.Module): """Native Complex Oscillating Neuron. W = ω + iφ (frequency + phase as single complex param) B = real_bias + i·imag_bias (complex baseline) PyTorch supports complex sin() and tanh() natively. Wirtinger derivatives flow through automatically.""" def __init__(self, dim): super().__init__() # W: real part = frequency (ω), imag part = phase (φ) omega = torch.randn(dim) * 0.1 + 1.0 phi = torch.randn(dim) * 0.1 self.W = nn.Parameter(torch.complex(omega, phi)) # B: complex baseline self.B = nn.Parameter(torch.complex(torch.zeros(dim), torch.zeros(dim))) def forward(self, Z): # Z is cfloat. Inductor can fuse this into a single kernel. Z = _softcap_imag(Z, limit=math.pi/2 - 0.2) # stays below first pole at π/2 WZ = _softcap_imag(self.W * Z + self.B, limit=6.0) return torch.sin(WZ) * torch.tanh(Z) # ============================================================================= # COMPLEX MIXER — fixed unitary matrix, zero learnable params # ============================================================================= class ComplexMixer(nn.Module): """Zero-parameter cross-dimension routing via fixed unitary matrix. QR-orthogonalized complex matrix ensures energy preservation. NOTE: This is O(D²) per token — the FWHT was O(D log D). Chosen for torch.compile compatibility over raw compute efficiency. If compile handles FWHT well on your hardware, swap back.""" def __init__(self, dim): super().__init__() # Random complex matrix → QR decomposition → unitary Q real_part = torch.randn(dim, dim) imag_part = torch.randn(dim, dim) complex_mat = torch.complex(real_part, imag_part) q, _ = torch.linalg.qr(complex_mat) self.register_buffer('mix_matrix', q) def forward(self, Z): # Z: (B, T, D) @ (D, D) -> (B, T, D) return Z @ self.mix_matrix.T # ============================================================================= # O(N) COMPLEX OSCILLATOR CAUSAL SCANNER — replaces O(N²) attention # ============================================================================= class OscillatingCausalScanner(nn.Module): """O(N) sequence routing replacing scaled_dot_product_attention. Uses ComplexOscillator to generate: - gate: complex decay (magnitude=retention, angle=phase rotation) - val: complex value signal Then accumulates causally across sequence length T in O(N) time. This is mathematically related to Linear Attention / State Space Models (Mamba, RWKV, Griffin) but powered entirely by oscillating neurons.""" def __init__(self, dim, clamp=70.0): super().__init__() self.clamp = clamp self.osc_gate = ComplexOscillator(dim) self.osc_val = ComplexOscillator(dim) self.osc_out = ComplexOscillator(dim) # Tame the gate's initial W so first gates aren't too aggressive with torch.no_grad(): self.osc_gate.W.data = torch.complex( torch.empty(dim).uniform_(-0.05, 0.05), torch.empty(dim).uniform_(-0.05, 0.05) ) def forward(self, Z): # Z: (B, T, D) complex gate = self.osc_gate(Z) val = self.osc_val(Z) decay = torch.sigmoid(gate.real) phase = math.pi * torch.tanh(gate.imag / math.pi) # Build log_gate directly — no torch.polar, no .angle() # This avoids the atan2(0,0) NaN gradient when decay → 0 log_gate = torch.complex(torch.log(decay.clamp(min=1e-8)), phase) cum_log = torch.cumsum(log_gate, dim=1) CLAMP = self.clamp exp_real = cum_log.real.clamp(min=-CLAMP) exp_cum = torch.exp(torch.complex(exp_real, cum_log.imag)) neg_real = (-cum_log.real).clamp(max=CLAMP) exp_neg = torch.exp(torch.complex(neg_real, -cum_log.imag)) H = exp_cum * torch.cumsum(val * exp_neg, dim=1) # GRADIENT ECOLOGY: soft magnitude channel (preserves phase, smooth gradients) H_mag = safe_abs(H).clamp(min=1e-8) H = H * (torch.tanh(H_mag / 8.0) / H_mag) return self.osc_out(H) # ============================================================================= # SMOOTH SPARSE GATE — proper sigmoid # ============================================================================= class SparseGate(nn.Module): """Decoupled spike gate with learnable temperature. Uses smooth sigmoid for clean gradients. voltage = sigmoid(gate_w * x) spike = sigmoid((voltage - threshold) * temperature) output = x * spike """ def __init__(self, num_features, threshold_init=0.3): super().__init__() self.gate_w = nn.Parameter(torch.ones(num_features) * 0.25) self.threshold = nn.Parameter(torch.full((num_features,), threshold_init)) self.temperature = nn.Parameter(torch.ones(num_features) * 10.0) def forward(self, x): voltage = torch.sigmoid(self.gate_w * x) spike = torch.sigmoid((voltage - self.threshold) * self.temperature) return x * spike @torch.no_grad() def get_sparsity(self, x=None): if x is None: return 0.0 voltage = torch.sigmoid(self.gate_w * x) return (voltage > self.threshold).float().mean().item() # ============================================================================= # ZERO-LINEAR BLOCK — scanner + complex mixer/oscillator MLP # ============================================================================= class ZeroLinearBlock(nn.Module): """Complete transformer-replacement block. Sub-block 1: OscillatingCausalScanner (replaces attention) Sub-block 2: ComplexMixer→Oscillator→Mixer→Oscillator (replaces MLP) Both sub-blocks use pre-norm residual connections. Complex sinc resonance coupling at the end.""" def __init__(self, layer_idx, cfg): super().__init__() dim = cfg.n_embd self.norm1 = ComplexRMSNorm(dim) self.scanner = OscillatingCausalScanner(dim, clamp=cfg.scanner_clamp) self.norm2 = ComplexRMSNorm(dim) self.mix1 = ComplexMixer(dim) self.osc1 = ComplexOscillator(dim) self.mix2 = ComplexMixer(dim) self.osc2 = ComplexOscillator(dim) self.sparse_gate = SparseGate(dim) self.last_mlp_mag = None self.last_gate_open = None alpha_init = cfg.coupling_alpha_init[layer_idx] self.coupling_alpha = nn.Parameter( torch.complex(torch.tensor(alpha_init), torch.tensor(0.0)) ) print(f" Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})") def forward(self, Z): # Sub-block 1: O(N) Causal Scanner (replaces attention) Z_res = Z Z_normed = self.norm1(Z) Z = Z_res + self.scanner(Z_normed) # Sub-block 2: Oscillating Zero-Linear "MLP" Z_res = Z Z_normed = self.norm2(Z) Z_mlp = self.mix1(Z_normed) Z_mlp = self.osc1(Z_mlp) Z_mlp = self.mix2(Z_mlp) Z_mlp = self.osc2(Z_mlp) # Voltage spike gate — feature-level sparsity mag = safe_abs(Z_mlp) self.last_mlp_mag = mag.detach() # Compute spike directly for clean logging sg = self.sparse_gate voltage = torch.sigmoid(sg.gate_w * mag) spike = torch.sigmoid((voltage - sg.threshold) * sg.temperature) self.last_gate_open = spike.detach() Z_mlp = spike * Z_mlp # gate on spike, apply to full complex # Complex sinc resonance coupling mag = safe_abs(Z_mlp) sinc_coupling = torch.sinc(mag / math.pi) * Z_mlp Z = Z_res + self.coupling_alpha * sinc_coupling return Z # ============================================================================= # COLM — Complex Oscillating Language Model # ============================================================================= class COLM(nn.Module): """Complex Oscillating Language Model. Architecture: - Real embedding → linear projection → complex conversion - ComplexOscillator initial oscillation - N × ZeroLinearBlock (scanner + oscillating MLP) - Complex → real concatenation → linear head """ def __init__(self, cfg): super().__init__() self.cfg = cfg # Embedding: real tokens → thin embed → linear up → convert to complex self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim) self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False) # Initial oscillation in real space before complex conversion self.embed_osc = ComplexOscillator(cfg.n_embd) # Position embedding (real-valued, added to real part) self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd) self.ln_pre = ComplexRMSNorm(cfg.n_embd) self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)]) self.ln_f = ComplexRMSNorm(cfg.n_embd) # Output head: preserve full complex information by concatenating real + imag self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False) self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, (nn.Linear, nn.Embedding)): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx, targets=None): B, Tseq = idx.size() # Real embedding path x_real = self.embed_up(self.thin_embed(idx)) # (B, T, n_embd) real # Add position embeddings (real) pos = torch.arange(0, Tseq, dtype=torch.long, device=idx.device) x_real = x_real + self.position_emb(pos) # Convert to complex: real part = features, imag part = 0 initially Z = torch.complex(x_real, torch.zeros_like(x_real)) # Initial complex oscillation Z = self.embed_osc(Z) Z = self.ln_pre(Z) for block in self.blocks: Z = block(Z) Z = self.ln_f(Z) # Preserve both real and imaginary channels for the classifier head x_out = torch.cat([Z.real, Z.imag], dim=-1) # (B, T, 2*n_embd) logits = self.lm_head(x_out) loss = None if targets is not None: loss = F.cross_entropy(logits.view(B * Tseq, -1), targets.view(B * Tseq)) return logits, loss