# SymbioGPT-10M-space / symbio_model.py
"""SymbioGPT β€” Multi-organelle GPT with learned per-channel gating.
4 organelles: CausalConv + Monarch + LongConv + CausalSelfAttention
fused via OrganelleGate with learnable temperature.
Architecture: RoPE, RMSNorm, SwiGLU, SkipGate, weight-tied output.
"""
import math
from dataclasses import dataclass, replace
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ═══════════════════════════════════════════════════════════════════
# Configuration
# ═══════════════════════════════════════════════════════════════════
@dataclass
class SymbioConfig:
d_model: int = 320
n_layers: int = 8
n_heads: int = 5
head_dim: int = 64
ffn_mult: int = 4
dropout: float = 0.0
context_length: int = 256
vocab_size: int = 2000
weight_tying: bool = True
organelles: Tuple[str, ...] = ("causal_conv", "monarch", "long_conv", "attention")
conv_kernel_size: int = 4
n_monarch_heads: int = 1
gate_temperature_init: float = 1.0
free_energy_beta: float = 0.001
per_layer_organelles: Optional[List[Tuple[str, ...]]] = None
@property
def p(self) -> int:
return int(math.isqrt(self.context_length))
@property
def n_organelles(self) -> int:
return len(self.organelles)
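
# Note: the "monarch" organelle realizes a context_length x context_length
# mixing matrix by reshaping T into p x p blocks (p = isqrt(T)), so
# context_length must be a perfect square (256 = 16**2 with the defaults).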
# ═══════════════════════════════════════════════════════════════════
# Building blocks
# ═══════════════════════════════════════════════════════════════════
class RMSNorm(nn.Module):
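    """RMS LayerNorm: x / sqrt(mean(x**2) + eps), scaled by a learned gain."""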
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return x / rms * self.weight
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_seq_len: int = 2048):
super().__init__()
freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.arange(max_seq_len).float()
angles = torch.outer(positions, freqs)
self.register_buffer("cos_cache", angles.cos())
self.register_buffer("sin_cache", angles.sin())
def forward(self, x: torch.Tensor) -> torch.Tensor:
seq_len = x.size(2)
half = x.size(-1) // 2
x1, x2 = x[..., :half], x[..., half:]
cos = self.cos_cache[:seq_len, :half].unsqueeze(0).unsqueeze(0)
sin = self.sin_cache[:seq_len, :half].unsqueeze(0).unsqueeze(0)
return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
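
# RotaryEmbedding uses the rotate-half (GPT-NeoX style) RoPE convention: the
# head dimension is split into two halves that are rotated against each other.
# The cos/sin caches are registered as persistent buffers, so they are saved
# in the state_dict alongside the weights.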
class SwiGLU(nn.Module):
def __init__(self, d_model: int, ffn_mult: int = 4):
super().__init__()
raw_hidden = 2 * d_model * ffn_mult // 3
hidden_dim = max(64, (raw_hidden // 64) * 64)
self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
self.v = nn.Linear(d_model, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.v(x))
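
# SwiGLU hidden width is 2/3 of d_model * ffn_mult, rounded down to a
# multiple of 64: with the defaults (d_model=320, ffn_mult=4) this is
# 2 * 320 * 4 // 3 = 853, rounded down to 832.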
class CausalSelfAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, head_dim: int, dropout: float = 0.0):
super().__init__()
self.n_heads = n_heads
self.head_dim = head_dim
total_dim = n_heads * head_dim
self.wq = nn.Linear(d_model, total_dim, bias=False)
self.wk = nn.Linear(d_model, total_dim, bias=False)
self.wv = nn.Linear(d_model, total_dim, bias=False)
self.wo = nn.Linear(total_dim, d_model, bias=False)
self.attn_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
def forward(self, x, rope, mask=None):
B, T, D = x.shape
H, HD = self.n_heads, self.head_dim
q = self.wq(x).view(B, T, H, HD).transpose(1, 2)
k = self.wk(x).view(B, T, H, HD).transpose(1, 2)
v = self.wv(x).view(B, T, H, HD).transpose(1, 2)
q = rope(q)
k = rope(k)
scale = 1.0 / math.sqrt(HD)
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
if mask is not None:
attn = attn + mask
attn = F.softmax(attn, dim=-1)
attn = self.attn_dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, T, H * HD)
return self.wo(out)
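
# Note: on PyTorch >= 2.0 the explicit matmul/softmax above could be swapped
# for F.scaled_dot_product_attention(q, k, v, attn_mask=mask); the manual
# form is kept here so the additive mask and dropout placement stay visible.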
# ═══════════════════════════════════════════════════════════════════
# Organelles
# ═══════════════════════════════════════════════════════════════════
class CausalDepthwiseConv1d(nn.Module):
def __init__(self, channels: int, kernel_size: int = 4):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.weight = nn.Parameter(torch.empty(channels, 1, kernel_size))
nn.init.normal_(self.weight, std=math.sqrt(1.0 / kernel_size))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, D = x.shape
x_t = x.transpose(1, 2)
x_padded = F.pad(x_t, (self.kernel_size - 1, 0))
out = F.conv1d(x_padded, self.weight, groups=D)
return out.transpose(1, 2)
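
# Left-padding with (kernel_size - 1) zeros makes the depthwise convolution
# causal: output position t sees only inputs at positions <= t.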
class MonarchMatrix(nn.Module):
def __init__(self, seq_len: int):
super().__init__()
p = int(math.isqrt(seq_len))
self.seq_len = seq_len
self.p = p
scale = math.sqrt(2.0 / (p + p))
self.L1 = nn.Parameter(torch.randn(p, p, p) * scale)
self.L2 = nn.Parameter(torch.randn(p, p, p) * scale)
    @staticmethod
    def _julia_batched_mul(A, B):
        # Batched matmul with the batch index stored in the last dimension
        # (a column-major, Julia-style layout, hence the name).
        return torch.bmm(A.permute(2, 0, 1), B.permute(2, 0, 1)).permute(1, 2, 0)

    def realize(self):
        # Materialize the full T x T Monarch operator by pushing an identity
        # matrix through the two block-diagonal factors: L2 along one block
        # axis, a block transpose, then L1 along the other.
        p = self.p
        T = self.seq_len
        I_T = torch.eye(T, device=self.L1.device, dtype=self.L1.dtype)
        x = I_T.reshape(p, p, T)   # split the row index into (block, offset)
        x = x.permute(0, 2, 1)
        x = self._julia_batched_mul(self.L2, x)
        x = x.permute(2, 1, 0)     # block transpose between the two factors
        x = self._julia_batched_mul(self.L1, x)
        x = x.permute(2, 0, 1)
        return x.reshape(T, T)
def forward(self, x, causal_mask=None):
B, T, D_head = x.shape
M = self.realize()[:T, :T]
if causal_mask is not None:
M = M * causal_mask[:T, :T]
x_flat = x.permute(1, 0, 2).reshape(T, B * D_head)
y_flat = M @ x_flat
return y_flat.reshape(T, B, D_head).permute(1, 0, 2)
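
# Each Monarch factor holds p blocks of p x p weights (p**3 = 4096 parameters
# at p=16) versus T**2 = 65536 for a dense T x T mixing matrix; causality is
# imposed by multiplying the realized matrix with a lower-triangular mask.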
class LongConv(nn.Module):
def __init__(self, channels: int, seq_len: int):
super().__init__()
self.channels = channels
self.seq_len = seq_len
self.kernel = nn.Parameter(torch.empty(channels, 1, seq_len))
self._init_weights()
def _init_weights(self):
scale = math.sqrt(1.0 / self.seq_len)
nn.init.normal_(self.kernel, mean=0.0, std=scale)
with torch.no_grad():
decay = torch.exp(-0.1 * torch.arange(self.seq_len, dtype=torch.float32))
self.kernel.mul_(decay.unsqueeze(0).unsqueeze(0))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, D = x.shape
K = self.seq_len
x_t = x.transpose(1, 2)
x_padded = F.pad(x_t, (K - 1, 0))
out = F.conv1d(x_padded, self.kernel, groups=D)
return out.transpose(1, 2)
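
# The long-conv kernel spans the full context, so every output position has a
# receptive field over the entire prefix; the exponential decay at init biases
# it toward recent tokens. For much longer contexts, an FFT-based convolution
# (O(T log T)) would be the usual alternative to this direct conv1d.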
# ═══════════════════════════════════════════════════════════════════
# Gate and mixer
# ═══════════════════════════════════════════════════════════════════
class OrganelleGate(nn.Module):
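    """Per-channel softmax mixture over organelle outputs.

    `logits` has shape (n_organelles, d_model): each feature channel learns
    its own convex combination of the organelles. The learnable temperature
    (clamped to >= 0.01) sharpens or smooths the mixture; masked-out
    organelles get a large negative logit so their weight goes to ~0.
    """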
def __init__(self, dim: int, n_organelles: int, temperature_init: float = 1.0):
super().__init__()
self.n_organelles = n_organelles
self.logits = nn.Parameter(torch.zeros(n_organelles, dim))
self.temperature = nn.Parameter(torch.tensor([temperature_init]))
def forward(self, organelle_outputs, organelle_mask=None):
logits = self.logits
if organelle_mask is not None:
mask_additive = torch.zeros_like(logits)
for i in range(self.n_organelles):
if not organelle_mask[i]:
mask_additive[i, :] = -1e10
logits = logits + mask_additive
tau = self.temperature.clamp(min=0.01)
weights = F.softmax(logits / tau, dim=0)
out = torch.zeros_like(organelle_outputs[0])
for i in range(self.n_organelles):
w = weights[i].unsqueeze(0).unsqueeze(0)
out = out + w * organelle_outputs[i]
return out
class SkipGate(nn.Module):
def __init__(self):
super().__init__()
self.scale = nn.Parameter(torch.ones(1))
def forward(self, x):
return self.scale * x
class SymbioSequenceMixer(nn.Module):
def __init__(self, config: SymbioConfig):
super().__init__()
self.config = config
d = config.d_model
T = config.context_length
self.organelle_names = list(config.organelles)
self.organelle_modules = nn.ModuleDict()
for name in self.organelle_names:
if name == "causal_conv":
self.organelle_modules[name] = CausalDepthwiseConv1d(d, config.conv_kernel_size)
elif name == "monarch":
self.organelle_modules[name] = nn.ModuleList(
[MonarchMatrix(T) for _ in range(config.n_monarch_heads)]
)
elif name == "long_conv":
self.organelle_modules[name] = LongConv(d, T)
elif name == "attention":
self.organelle_modules[name] = CausalSelfAttention(
d, config.n_heads, config.head_dim, config.dropout
)
self.gate = OrganelleGate(d, len(self.organelle_names), config.gate_temperature_init)
if "monarch" in self.organelle_names:
self.register_buffer("monarch_causal_mask", torch.tril(torch.ones(T, T)))
def forward(self, x, rope, attn_mask=None, organelle_mask=None):
B, T, D = x.shape
outputs = []
for name in self.organelle_names:
if name == "causal_conv":
out = self.organelle_modules[name](x)
elif name == "monarch":
heads = self.organelle_modules[name]
n_mh = len(heads)
hd = D // n_mh
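                # Split channels evenly across monarch heads
                # (assumes d_model is divisible by n_monarch_heads).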
slices = []
for i, monarch in enumerate(heads):
x_slice = x[:, :, i * hd : (i + 1) * hd]
y_slice = monarch(x_slice, self.monarch_causal_mask)
slices.append(y_slice)
out = torch.cat(slices, dim=-1)
elif name == "long_conv":
out = self.organelle_modules[name](x)
elif name == "attention":
out = self.organelle_modules[name](x, rope, attn_mask)
outputs.append(out)
return self.gate(tuple(outputs), organelle_mask)
# ═══════════════════════════════════════════════════════════════════
# SymbioBlock and SymbioGPT
# ═══════════════════════════════════════════════════════════════════
class SymbioBlock(nn.Module):
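    """Pre-norm residual block: RMSNorm -> gated organelle mixer, then
    RMSNorm -> SwiGLU FFN, each residual branch scaled by a learnable
    SkipGate."""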
def __init__(self, config: SymbioConfig, layer_organelles=None):
super().__init__()
d = config.d_model
        if layer_organelles is not None:
            layer_config = replace(config, organelles=layer_organelles)
        else:
            layer_config = config
self.ln1 = RMSNorm(d)
self.seq_mixer = SymbioSequenceMixer(layer_config)
self.skip1 = SkipGate()
self.ln2 = RMSNorm(d)
self.ffn = SwiGLU(d, config.ffn_mult)
self.skip2 = SkipGate()
def forward(self, x, rope, attn_mask=None, organelle_mask=None):
x = x + self.skip1(self.seq_mixer(self.ln1(x), rope, attn_mask, organelle_mask))
x = x + self.skip2(self.ffn(self.ln2(x)))
return x
class SymbioGPT(nn.Module):
def __init__(self, config: SymbioConfig):
super().__init__()
self.config = config
self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
self.rope = RotaryEmbedding(config.head_dim, config.context_length)
blocks = []
for i in range(config.n_layers):
layer_org = None
if config.per_layer_organelles is not None:
layer_org = config.per_layer_organelles[i]
blocks.append(SymbioBlock(config, layer_org))
self.blocks = nn.ModuleList(blocks)
self.ln_f = RMSNorm(config.d_model)
if config.weight_tying:
self.head = None
else:
self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
def forward(self, input_ids, organelle_mask=None):
B, T = input_ids.shape
x = self.tok_emb(input_ids)
attn_mask = torch.triu(
torch.full((T, T), float("-inf"), device=x.device, dtype=x.dtype), diagonal=1
)
for block in self.blocks:
x = block(x, self.rope, attn_mask, organelle_mask)
x = self.ln_f(x)
if self.head is not None:
return self.head(x)
return F.linear(x, self.tok_emb.weight)
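

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; not training or eval code):
    # build the default config, run a forward pass, and check shapes.
    cfg = SymbioConfig()
    model = SymbioGPT(cfg)
    tokens = torch.randint(0, cfg.vocab_size, (2, cfg.context_length))
    logits = model(tokens)
    assert logits.shape == (2, cfg.context_length, cfg.vocab_size)
    # Organelles can be disabled at inference with a boolean mask, e.g.
    # routing everything through the attention organelle alone:
    attn_only = [name == "attention" for name in cfg.organelles]
    _ = model(tokens, organelle_mask=attn_only)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"params: {n_params / 1e6:.2f}M  logits: {tuple(logits.shape)}")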