"""
COLM Model Components
=====================
Complex Oscillating Language Model — all neural network modules.
Components:
- ComplexRMSNorm: magnitude normalization preserving phase
- ComplexOscillator: sin(W⊙Z+B)·tanh(Z) oscillating neuron
- ComplexMixer: fixed unitary cross-dimension routing
- OscillatingCausalScanner: O(N) causal sequence scanner
- SparseGate: smooth sigmoid voltage-spike gate
- ZeroLinearBlock: scanner + oscillating MLP block
- COLM: full autoregressive model
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
# =============================================================================
# COMPLEX RMSNORM — norm the magnitude, preserve the angle
# =============================================================================
class ComplexRMSNorm(nn.Module):
    """RMSNorm for complex-valued activations.

    Divides out the last-dimension root-mean-square of |Z|, rescaling
    magnitudes while leaving every element's phase angle intact. The
    learnable per-feature gain is real, so it also only scales magnitude.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # keeps rsqrt finite when all magnitudes vanish
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, Z):
        # |Z|^2 = re^2 + im^2; mean over features, then reciprocal sqrt.
        power = Z.real.square() + Z.imag.square()
        inv_rms = torch.rsqrt(power.mean(dim=-1, keepdim=True) + self.eps)
        return Z * (self.weight * inv_rms)
# =============================================================================
# COMPLEX OSCILLATOR — sin(W⊙Z+B)·tanh(Z), W,B ∈ ℂ
# =============================================================================
def _softcap_imag(z, limit=6.0):
return torch.complex(z.real, limit * torch.tanh(z.imag / limit))
def safe_abs(Z, eps=1e-12):
    """Complex magnitude with a finite gradient at the origin.

    torch.abs on a complex tensor is sqrt(re² + im²), and sqrt'(0) = inf.
    Folding eps inside the sqrt keeps gradients bounded when the sparse
    gate drives features to exactly zero, while changing forward values
    only around the 6th decimal place.
    """
    power = Z.real * Z.real + Z.imag * Z.imag
    return torch.sqrt(power + eps)
class ComplexOscillator(nn.Module):
    """Oscillating neuron acting natively on complex tensors.

    Computes sin(W*Z + B) * tanh(Z) with complex W and B:
      - W packs frequency (real part, init ~1.0) and phase (imag part, ~0)
        into one complex parameter.
      - B is a learnable complex baseline, initialized to zero.
    PyTorch supports complex sin() and tanh() natively, so Wirtinger
    gradients flow without manual bookkeeping. Imaginary parts are
    soft-capped before each transcendental: tanh has poles at ±i·π/2,
    and sin grows exponentially along the imaginary axis.
    """

    def __init__(self, dim):
        super().__init__()
        # Frequency near 1, phase near 0, both lightly perturbed.
        freq = torch.randn(dim) * 0.1 + 1.0
        phase = torch.randn(dim) * 0.1
        self.W = nn.Parameter(torch.complex(freq, phase))
        zeros = torch.zeros(dim)
        self.B = nn.Parameter(torch.complex(zeros, zeros.clone()))

    def forward(self, Z):
        # tanh(Z) has its first pole at imag = π/2; keep a 0.2 margin below it.
        Z = _softcap_imag(Z, limit=math.pi / 2 - 0.2)
        # sin's argument: cap the imaginary part at ±6 to bound exp growth.
        arg = _softcap_imag(self.B + self.W * Z, limit=6.0)
        return torch.tanh(Z) * torch.sin(arg)
# =============================================================================
# COMPLEX MIXER — fixed unitary matrix, zero learnable params
# =============================================================================
class ComplexMixer(nn.Module):
    """Parameter-free cross-feature routing through a frozen unitary map.

    A random complex matrix is QR-factorized once at construction; the
    unitary Q factor is stored as a buffer (never trained) and applied to
    every token, so the mix preserves total signal energy.
    NOTE: this costs O(D²) per token versus O(D log D) for an FWHT; the
    dense matmul was chosen for torch.compile friendliness over raw
    compute efficiency. Swap the FWHT back if compile handles it well
    on your hardware.
    """

    def __init__(self, dim):
        super().__init__()
        mat = torch.complex(torch.randn(dim, dim), torch.randn(dim, dim))
        unitary, _ = torch.linalg.qr(mat)
        self.register_buffer('mix_matrix', unitary)

    def forward(self, Z):
        # (B, T, D) @ (D, D) -> (B, T, D); transpose of a unitary is unitary.
        return Z @ self.mix_matrix.T
# =============================================================================
# O(N) COMPLEX OSCILLATOR CAUSAL SCANNER — replaces O(N²) attention
# =============================================================================
class OscillatingCausalScanner(nn.Module):
    """O(N) sequence routing replacing scaled_dot_product_attention.
    Uses ComplexOscillator to generate:
      - gate: complex decay (magnitude=retention, angle=phase rotation)
      - val: complex value signal
    Then accumulates causally across sequence length T in O(N) time.
    This is mathematically related to Linear Attention / State Space Models
    (Mamba, RWKV, Griffin) but powered entirely by oscillating neurons."""

    def __init__(self, dim, clamp=70.0):
        # clamp bounds the real exponent fed to exp() below so float32
        # never overflows (exp(70) ≈ 2.5e30, still finite).
        super().__init__()
        self.clamp = clamp
        self.osc_gate = ComplexOscillator(dim)  # produces the decay/phase gate
        self.osc_val = ComplexOscillator(dim)   # produces the value stream
        self.osc_out = ComplexOscillator(dim)   # output nonlinearity
        # Tame the gate's initial W so first gates aren't too aggressive
        with torch.no_grad():
            self.osc_gate.W.data = torch.complex(
                torch.empty(dim).uniform_(-0.05, 0.05),
                torch.empty(dim).uniform_(-0.05, 0.05)
            )

    def forward(self, Z):
        # Z: (B, T, D) complex
        gate = self.osc_gate(Z)
        val = self.osc_val(Z)
        # Map the gate to a stable complex multiplier: retention magnitude
        # in (0,1) via sigmoid, phase soft-bounded to (-π, π) via scaled tanh.
        decay = torch.sigmoid(gate.real)
        phase = math.pi * torch.tanh(gate.imag / math.pi)
        # Build log_gate directly — no torch.polar, no .angle()
        # This avoids the atan2(0,0) NaN gradient when decay → 0
        log_gate = torch.complex(torch.log(decay.clamp(min=1e-8)), phase)
        # Parallel-scan trick: with S_t = cumsum(log g), the recurrence
        # h_t = g_t·h_{t-1} + v_t becomes h_t = exp(S_t)·Σ_{s≤t} v_s·exp(-S_s),
        # i.e. two cumsums instead of a sequential loop.
        cum_log = torch.cumsum(log_gate, dim=1)
        CLAMP = self.clamp
        # Clamp the positive and negative exponents separately so neither
        # exp() overflows on long sequences.
        exp_real = cum_log.real.clamp(min=-CLAMP)
        exp_cum = torch.exp(torch.complex(exp_real, cum_log.imag))
        neg_real = (-cum_log.real).clamp(max=CLAMP)
        exp_neg = torch.exp(torch.complex(neg_real, -cum_log.imag))
        H = exp_cum * torch.cumsum(val * exp_neg, dim=1)
        # GRADIENT ECOLOGY: soft magnitude channel (preserves phase, smooth gradients)
        # tanh(|H|/8)/|H| squashes magnitude into (0,1) without touching phase;
        # the 1e-8 floor avoids a 0/0 when |H| underflows.
        H_mag = safe_abs(H).clamp(min=1e-8)
        H = H * (torch.tanh(H_mag / 8.0) / H_mag)
        return self.osc_out(H)
# =============================================================================
# SMOOTH SPARSE GATE — proper sigmoid
# =============================================================================
class SparseGate(nn.Module):
    """Smooth per-feature spike gate with learnable temperature.

    voltage = sigmoid(gate_w * x)                      # soft activity measure
    spike   = sigmoid((voltage - threshold) * temperature)
    output  = x * spike
    All three controls are per-feature parameters; the sigmoid spike keeps
    gradients clean where a hard threshold would kill them.
    """

    def __init__(self, num_features, threshold_init=0.3):
        super().__init__()
        self.gate_w = nn.Parameter(torch.full((num_features,), 0.25))
        self.threshold = nn.Parameter(torch.full((num_features,), threshold_init))
        self.temperature = nn.Parameter(torch.full((num_features,), 10.0))

    def forward(self, x):
        v = torch.sigmoid(x * self.gate_w)
        s = torch.sigmoid(self.temperature * (v - self.threshold))
        return s * x

    @torch.no_grad()
    def get_sparsity(self, x=None):
        """Fraction of features whose voltage exceeds threshold; 0.0 with no input."""
        if x is None:
            return 0.0
        v = torch.sigmoid(x * self.gate_w)
        return (v > self.threshold).float().mean().item()
# =============================================================================
# ZERO-LINEAR BLOCK — scanner + complex mixer/oscillator MLP
# =============================================================================
class ZeroLinearBlock(nn.Module):
    """Complete transformer-replacement block.
    Sub-block 1: OscillatingCausalScanner (replaces attention)
    Sub-block 2: ComplexMixer→Oscillator→Mixer→Oscillator (replaces MLP)
    Both sub-blocks use pre-norm residual connections.
    Complex sinc resonance coupling at the end."""

    def __init__(self, layer_idx, cfg):
        # cfg must provide: n_embd, scanner_clamp, and a per-layer
        # coupling_alpha_init sequence indexable by layer_idx.
        super().__init__()
        dim = cfg.n_embd
        self.norm1 = ComplexRMSNorm(dim)
        self.scanner = OscillatingCausalScanner(dim, clamp=cfg.scanner_clamp)
        self.norm2 = ComplexRMSNorm(dim)
        # "MLP": two fixed unitary mixes interleaved with oscillators.
        self.mix1 = ComplexMixer(dim)
        self.osc1 = ComplexOscillator(dim)
        self.mix2 = ComplexMixer(dim)
        self.osc2 = ComplexOscillator(dim)
        self.sparse_gate = SparseGate(dim)
        # Detached telemetry, overwritten on every forward pass (for logging).
        self.last_mlp_mag = None
        self.last_gate_open = None
        # Per-layer residual coupling strength as a complex scalar; the
        # imaginary part starts at 0 but is free to move during training.
        alpha_init = cfg.coupling_alpha_init[layer_idx]
        self.coupling_alpha = nn.Parameter(
            torch.complex(torch.tensor(alpha_init), torch.tensor(0.0))
        )
        print(f" Layer {layer_idx}: α = {alpha_init:.4f} (complex: {self.coupling_alpha.item()})")

    def forward(self, Z):
        # Sub-block 1: O(N) Causal Scanner (replaces attention)
        Z_res = Z
        Z_normed = self.norm1(Z)
        Z = Z_res + self.scanner(Z_normed)
        # Sub-block 2: Oscillating Zero-Linear "MLP"
        Z_res = Z
        Z_normed = self.norm2(Z)
        Z_mlp = self.mix1(Z_normed)
        Z_mlp = self.osc1(Z_mlp)
        Z_mlp = self.mix2(Z_mlp)
        Z_mlp = self.osc2(Z_mlp)
        # Voltage spike gate — feature-level sparsity
        mag = safe_abs(Z_mlp)
        self.last_mlp_mag = mag.detach()
        # Compute spike directly for clean logging
        # (inlines SparseGate.forward so the spike tensor can be stashed).
        sg = self.sparse_gate
        voltage = torch.sigmoid(sg.gate_w * mag)
        spike = torch.sigmoid((voltage - sg.threshold) * sg.temperature)
        self.last_gate_open = spike.detach()
        Z_mlp = spike * Z_mlp  # gate on spike, apply to full complex
        # Complex sinc resonance coupling
        # torch.sinc is normalized, so sinc(|Z|/π) = sin(|Z|)/|Z|:
        # strongest near zero magnitude, ringing outward.
        mag = safe_abs(Z_mlp)
        sinc_coupling = torch.sinc(mag / math.pi) * Z_mlp
        Z = Z_res + self.coupling_alpha * sinc_coupling
        return Z
# =============================================================================
# COLM — Complex Oscillating Language Model
# =============================================================================
class COLM(nn.Module):
    """Complex Oscillating Language Model.
    Architecture:
      - Real embedding → linear projection → complex conversion
      - ComplexOscillator initial oscillation
      - N × ZeroLinearBlock (scanner + oscillating MLP)
      - Complex → real concatenation → linear head
    """

    def __init__(self, cfg):
        # cfg must provide: vocab_size, embed_dim, n_embd, block_size, n_layer,
        # plus whatever ZeroLinearBlock reads (scanner_clamp, coupling_alpha_init).
        super().__init__()
        self.cfg = cfg
        # Embedding: real tokens → thin embed → linear up → convert to complex
        self.thin_embed = nn.Embedding(cfg.vocab_size, cfg.embed_dim)
        self.embed_up = nn.Linear(cfg.embed_dim, cfg.n_embd, bias=False)
        # Initial oscillation, applied right after conversion to complex
        # (see forward — the input is complex with zero imaginary part).
        self.embed_osc = ComplexOscillator(cfg.n_embd)
        # Position embedding (real-valued, added to real part)
        self.position_emb = nn.Embedding(cfg.block_size, cfg.n_embd)
        self.ln_pre = ComplexRMSNorm(cfg.n_embd)
        self.blocks = nn.ModuleList([ZeroLinearBlock(i, cfg) for i in range(cfg.n_layer)])
        self.ln_f = ComplexRMSNorm(cfg.n_embd)
        # Output head: preserve full complex information by concatenating real + imag
        self.lm_head = nn.Linear(2 * cfg.n_embd, cfg.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Same normal init for every Linear/Embedding (overrides PyTorch defaults).
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx: (B, T) long token ids; targets: optional (B, T) long labels.
        # Returns (logits, loss); loss is None when targets is None.
        B, Tseq = idx.size()
        # Real embedding path
        x_real = self.embed_up(self.thin_embed(idx))  # (B, T, n_embd) real
        # Add position embeddings (real)
        pos = torch.arange(0, Tseq, dtype=torch.long, device=idx.device)
        x_real = x_real + self.position_emb(pos)
        # Convert to complex: real part = features, imag part = 0 initially
        Z = torch.complex(x_real, torch.zeros_like(x_real))
        # Initial complex oscillation
        Z = self.embed_osc(Z)
        Z = self.ln_pre(Z)
        for block in self.blocks:
            Z = block(Z)
        Z = self.ln_f(Z)
        # Preserve both real and imaginary channels for the classifier head
        x_out = torch.cat([Z.real, Z.imag], dim=-1)  # (B, T, 2*n_embd)
        logits = self.lm_head(x_out)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(B * Tseq, -1), targets.view(B * Tseq))
        return logits, loss
|