# COLM / model.py
# Uploaded by edeneldith via huggingface_hub (commit e254270, verified)
"""
COLM Model Components
=====================
Complex Oscillating Language Model — all neural network modules.
Components:
- ComplexRMSNorm: magnitude normalization preserving phase
- ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron
- ComplexMixer: fixed unitary cross-dimension routing
- OscillatingCausalScanner: O(N) causal sequence scanner
- SparseGate: smooth sigmoid voltage-spike gate
- ZeroLinearBlock: scanner + oscillating MLP block
- COLM: full autoregressive model
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
# =============================================================================
# COMPLEX RMSNORM — norm the magnitude, preserve the angle
# =============================================================================
class ComplexRMSNorm(nn.Module):
    """RMSNorm for complex-valued tensors.

    The scale is computed from the mean squared magnitude over the last
    dimension, so multiplying by it normalizes |Z| while leaving every
    element's phase angle untouched. The learnable weight is real-valued
    and rescales magnitudes only.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, Z):
        # Mean squared magnitude over the feature dim, then reciprocal
        # root — identical to real RMSNorm applied to |Z|.
        mean_sq = (Z.real.square() + Z.imag.square()).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return Z * (inv_rms * self.weight)
# =============================================================================
# COMPLEX OSCILLATOR — sin(W⊙Z+B)·tanh(Z), W,B ∈ ℂ
# =============================================================================
def _softcap_imag(z, limit=6.0):
return torch.complex(z.real, limit * torch.tanh(z.imag / limit))
def safe_abs(Z, eps=1e-12):
    """Gradient-safe complex magnitude.

    ``torch.abs`` on a complex tensor is sqrt(re^2 + im^2), and the
    derivative of sqrt blows up at exactly zero. Folding a tiny eps inside
    the sqrt keeps gradients finite when the sparse gate zeroes features
    out, while leaving forward values unchanged to ~6 decimal places.
    """
    squared_mag = Z.real.square() + Z.imag.square()
    return torch.sqrt(squared_mag + eps)
class ComplexOscillator(nn.Module):
    """Native complex oscillating neuron: sin(W*Z + B) * tanh(Z).

    W packs frequency (real part) and phase (imaginary part) into a single
    complex parameter; B is a complex baseline. PyTorch supports complex
    sin() and tanh() natively, so Wirtinger derivatives flow automatically.
    """

    def __init__(self, dim):
        super().__init__()
        # Frequencies start near 1, phases near 0 (small gaussian jitter).
        omega = torch.randn(dim) * 0.1 + 1.0
        phi = torch.randn(dim) * 0.1
        self.W = nn.Parameter(torch.complex(omega, phi))
        # Complex baseline, initialized to zero.
        self.B = nn.Parameter(torch.complex(torch.zeros(dim), torch.zeros(dim)))

    def forward(self, Z):
        # Z is cfloat; Inductor can fuse this into a single kernel.
        # Keep Im(Z) below the first tanh pole at pi/2, then bound the
        # oscillator argument before taking sin.
        Z = _softcap_imag(Z, limit=math.pi / 2 - 0.2)
        arg = _softcap_imag(self.W * Z + self.B, limit=6.0)
        return torch.sin(arg) * torch.tanh(Z)
# =============================================================================
# COMPLEX MIXER — fixed unitary matrix, zero learnable params
# =============================================================================
class ComplexMixer(nn.Module):
    """Zero-parameter cross-dimension routing via a fixed unitary matrix.

    A random complex matrix is QR-orthogonalized once at construction; the
    unitary factor Q is kept as a buffer (nothing learnable), so the
    routing preserves signal energy.

    NOTE: this is O(D^2) per token where the FWHT was O(D log D) — chosen
    for torch.compile compatibility over raw compute efficiency. If
    compile handles FWHT well on your hardware, swap back.
    """

    def __init__(self, dim):
        super().__init__()
        # Random complex matrix -> QR decomposition -> unitary Q.
        real_part = torch.randn(dim, dim)
        imag_part = torch.randn(dim, dim)
        q, _ = torch.linalg.qr(torch.complex(real_part, imag_part))
        self.register_buffer('mix_matrix', q)

    def forward(self, Z):
        # (B, T, D) x (D, D) -> (B, T, D); right-multiply by Q^T.
        return torch.matmul(Z, self.mix_matrix.t())
# =============================================================================
# O(N) COMPLEX OSCILLATOR CAUSAL SCANNER — replaces O(N²) attention
# =============================================================================
class OscillatingCausalScanner(nn.Module):
    """O(N) causal sequence routing replacing scaled_dot_product_attention.

    Uses ComplexOscillator to generate, per position:
      - gate: complex decay (magnitude = retention, angle = phase rotation)
      - val:  complex value signal
    and accumulates them causally across sequence length T in O(N) via
    cumulative sums in log space. Mathematically related to Linear
    Attention / State Space Models (Mamba, RWKV, Griffin) but powered
    entirely by oscillating neurons.
    """

    def __init__(self, dim, clamp=70.0):
        super().__init__()
        # Bound for the log-space exponentials below (numerical safety).
        self.clamp = clamp
        self.osc_gate = ComplexOscillator(dim)
        self.osc_val = ComplexOscillator(dim)
        self.osc_out = ComplexOscillator(dim)
        # Tame the gate's initial W so first gates aren't too aggressive.
        with torch.no_grad():
            self.osc_gate.W.data = torch.complex(
                torch.empty(dim).uniform_(-0.05, 0.05),
                torch.empty(dim).uniform_(-0.05, 0.05)
            )

    def forward(self, Z):
        # Z: (B, T, D) complex
        gate = self.osc_gate(Z)
        val = self.osc_val(Z)
        # Retention in (0, 1) from the real channel; rotation bounded to
        # (-pi, pi) via tanh on the imaginary channel.
        decay = torch.sigmoid(gate.real)
        phase = math.pi * torch.tanh(gate.imag / math.pi)
        # Build log_gate directly — no torch.polar, no .angle().
        # This avoids the atan2(0,0) NaN gradient when decay → 0.
        log_gate = torch.complex(torch.log(decay.clamp(min=1e-8)), phase)
        # Prefix products of gates become prefix sums in log space.
        cum_log = torch.cumsum(log_gate, dim=1)
        CLAMP = self.clamp
        # exp_cum * cumsum(val * exp_neg) realizes the causal scan
        # H_t = sum_{s<=t} (prod_{s<r<=t} gate_r) * val_s; the two
        # opposing clamps keep both exponentials finite.
        exp_real = cum_log.real.clamp(min=-CLAMP)
        exp_cum = torch.exp(torch.complex(exp_real, cum_log.imag))
        neg_real = (-cum_log.real).clamp(max=CLAMP)
        exp_neg = torch.exp(torch.complex(neg_real, -cum_log.imag))
        H = exp_cum * torch.cumsum(val * exp_neg, dim=1)
        # GRADIENT ECOLOGY: soft magnitude squashing — rescales |H| to
        # tanh(|H|/8), preserving phase with smooth gradients.
        H_mag = safe_abs(H).clamp(min=1e-8)
        H = H * (torch.tanh(H_mag / 8.0) / H_mag)
        return self.osc_out(H)
# =============================================================================
# SMOOTH SPARSE GATE — proper sigmoid
# =============================================================================
class SparseGate(nn.Module):
"""Decoupled spike gate with learnable temperature.
Uses smooth sigmoid for clean gradients.
voltage = sigmoid(gate_w * x)
spike = sigmoid((voltage - threshold) * temperature)
output = x * spike
"""
def __init__(self, num_features, threshold_init=0.3):
super().__init__()
self.gate_w = nn.Parameter(torch.ones(num_features) * 0.25)
self.threshold = nn.Parameter(torch.full((num_features,), threshold_init))
self.temperature = nn.Parameter(torch.ones(num_features) * 10.0)
def forward(self, x):
voltage = torch.sigmoid(self.gate_w * x)
spike = torch.sigmoid((voltage - self.threshold) * self.temperature)
return x * spike
@torch.no_grad()
def get_sparsity(self, x=None):
if x is None:
return 0.0
voltage = torch.sigmoid(self.gate_w * x)
return (voltage > self.threshold).float().mean().item()
# =============================================================================
# ZERO-LINEAR BLOCK — scanner + complex mixer/oscillator MLP
# =============================================================================
class ZeroLinearBlock(nn.Module):
    """Complete transformer-replacement block.

    Sub-block 1: OscillatingCausalScanner (stands in for attention).
    Sub-block 2: ComplexMixer → Oscillator → Mixer → Oscillator (stands in
    for the MLP), followed by a feature-level sparse spike gate and a
    complex sinc resonance coupling back onto the residual stream.
    Both sub-blocks use pre-norm residual connections.
    """

    def __init__(self, layer_idx, cfg):
        super().__init__()
        dim = cfg.n_embd
        self.norm1 = ComplexRMSNorm(dim)
        self.scanner = OscillatingCausalScanner(dim, clamp=cfg.scanner_clamp)
        self.norm2 = ComplexRMSNorm(dim)
        self.mix1 = ComplexMixer(dim)
        self.osc1 = ComplexOscillator(dim)
        self.mix2 = ComplexMixer(dim)
        self.osc2 = ComplexOscillator(dim)
        self.sparse_gate = SparseGate(dim)
        # Detached telemetry captured on every forward pass (for logging).
        self.last_mlp_mag = None
        self.last_gate_open = None
        alpha_init = cfg.coupling_alpha_init[layer_idx]
        self.coupling_alpha = nn.Parameter(
            torch.complex(torch.tensor(alpha_init), torch.tensor(0.0))
        )
        print(f" Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})")

    def forward(self, Z):
        # Sub-block 1: O(N) causal scanner on the pre-normed stream.
        Z = Z + self.scanner(self.norm1(Z))
        # Sub-block 2: oscillating zero-linear "MLP".
        residual = Z
        h = self.osc2(self.mix2(self.osc1(self.mix1(self.norm2(Z)))))
        # Voltage spike gate — feature-level sparsity. The spike is computed
        # inline from the gate's parameters (not via sparse_gate.forward) so
        # it can be stashed for logging.
        sg = self.sparse_gate
        mag = safe_abs(h)
        self.last_mlp_mag = mag.detach()
        spike = torch.sigmoid((torch.sigmoid(sg.gate_w * mag) - sg.threshold) * sg.temperature)
        self.last_gate_open = spike.detach()
        h = spike * h  # gate from the magnitude, applied to the full complex value
        # Complex sinc resonance coupling back onto the residual.
        coupled = torch.sinc(safe_abs(h) / math.pi) * h
        return residual + self.coupling_alpha * coupled
# =============================================================================
# COLM — Complex Oscillating Language Model
# =============================================================================
class COLM(nn.Module):
    """Complex Oscillating Language Model.

    Pipeline:
      - real token embedding → linear up-projection → complex conversion
      - ComplexOscillator initial oscillation
      - N × ZeroLinearBlock (scanner + oscillating MLP)
      - concat(real, imag) → linear LM head
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Embedding: real tokens → thin embed → linear up → complex.
        self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False)
        # Initial oscillation applied right after the complex conversion.
        self.embed_osc = ComplexOscillator(cfg.n_embd)
        # Learned absolute positions (real-valued, added before conversion).
        self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.ln_pre = ComplexRMSNorm(cfg.n_embd)
        self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)])
        self.ln_f = ComplexRMSNorm(cfg.n_embd)
        # Head consumes [real ; imag] so no complex information is dropped.
        self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, Tseq = idx.size()
        # Real embedding path plus learned absolute positions.
        pos = torch.arange(0, Tseq, dtype=torch.long, device=idx.device)
        x_real = self.embed_up(self.thin_embed(idx)) + self.position_emb(pos)
        # Lift to complex (imag part starts at 0), oscillate, pre-norm.
        Z = self.ln_pre(self.embed_osc(torch.complex(x_real, torch.zeros_like(x_real))))
        for block in self.blocks:
            Z = block(Z)
        Z = self.ln_f(Z)
        # Preserve both channels for the classifier head: (B, T, 2*n_embd).
        logits = self.lm_head(torch.cat([Z.real, Z.imag], dim=-1))
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * Tseq, -1), targets.view(B * Tseq))
        return logits, loss