TCF-1 / leeknet_500m.py
TreeLeek's picture
Upload leeknet_500m.py with huggingface_hub
f6e3ff4 verified
#!/usr/bin/env python3
"""
leeknet_500m.py — Scaled TCF-1 architecture for ~500M params.
Same hybrid attention + Mamba SSM design as the 36M character-level model.
Differences:
- BPE tokenizer (vocab 32k) instead of character-level
- Wider: n_embed 1024 (vs 512)
- Deeper: 12 hybrid pairs (vs 4)
- Longer context: block_size 2048 (vs 512)
- Persistent SSM state still threads through all pairs and across tokens
Architecture (per hybrid pair):
Attention (reasons over context)
+ Mamba SSM (holds and updates persistent state)
+ FeedForward (transforms)
Usage:
python3 leeknet_500m.py info # show parameter count and tokenizer info
python3 leeknet_500m.py train_a # Stage A pretraining (not wired in yet)
python3 leeknet_500m.py train_b # Stage B SFT (not wired in yet)
python3 leeknet_500m.py train_c # Stage C voice imprint (not wired in yet)
python3 leeknet_500m.py chat # interactive (not wired in yet)
"""
import math
import json
import sys
import time
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Every location is resolved relative to the directory holding this file.
ROOT = Path(__file__).parent

# Inputs: tokenizer files and the per-stage corpora.
TOKENIZER_DIR = ROOT.joinpath('tokenizer')
TOKENIZER_MODEL = TOKENIZER_DIR.joinpath('leek_bpe_32k.model')
DATA_A = ROOT.joinpath('data', 'A_knowledge')
DATA_B = ROOT.joinpath('data', 'B_instruction')
VOICE_DIR = ROOT.joinpath('memory', 'corpus')

# Outputs: the checkpoint directory is created at import time if missing.
CKPT_DIR = ROOT.joinpath('checkpoints_500m')
CKPT_DIR.mkdir(exist_ok=True)
# ---------------------------------------------------------------------------
# Config — scales from the 36M version
# ---------------------------------------------------------------------------
# Model shape. Trailing notes give the value used by the 36M model.
N_VOCAB = 32_000  # from BPE tokenizer
N_EMBED = 1024    # was 512
N_HEAD = 16       # was 8
N_PAIRS = 12      # was 4
BLOCK_SIZE = 2048  # was 512
DROPOUT = 0.0     # disabled — relying on data diversity

# Arguments forwarded to each MambaBlock (d_state, d_conv, expand).
SSM_D_STATE = 16
SSM_D_CONV = 4
SSM_EXPAND = 2

# Tools (still emitted as text — harness handles execution)
TOOLS = [
    '<none>',
    'query_soul',
    'bash',
    'read_file',
    'write_file',
    'query_memory',
]

# Training defaults — adjust per stage
BATCH_SIZE = 8
LEARN_RATE = 3e-4
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.1
# ---------------------------------------------------------------------------
# SSM block — Mamba-style selective state
# ---------------------------------------------------------------------------
class MambaBlock(nn.Module):
    """Mamba-style selective state-space block with a residual connection.

    The input is layer-normalised and projected into two ``d_inner``-wide
    streams. One stream goes through a 1-D convolution, a gated activation
    and a serial selective scan; the other stream gates the scan output.
    The gated result is projected back to ``d_model`` and added to the
    (un-normalised) input.

    Args:
        d_model: Width of the incoming and outgoing features.
        d_state: Size of the recurrent state kept per inner channel.
        d_conv: Kernel size of the 1-D convolution.
        expand: Multiplier giving ``d_inner = int(expand * d_model)``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        inner = self.d_inner

        # One projection yields both the scan stream and the gate stream.
        self.in_proj = nn.Linear(d_model, inner * 2, bias=False)
        # Padding of d_conv - 1 plus truncation in __call__ means position t
        # only sees inputs at positions <= t.
        self.conv1d = nn.Conv1d(
            in_channels=inner,
            out_channels=inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            bias=True,
        )
        # Per-token scan inputs: 1 step-size scalar, d_state B, d_state C.
        self.x_proj = nn.Linear(inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, inner, bias=True)
        self.out_proj = nn.Linear(inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

        # A is stored as log(1..d_state); __call__ uses -exp(A_log).
        self.A_log = mx.array(np.log(np.arange(1, d_state + 1, dtype=np.float32)))
        # Weight of the direct skip from the activated stream to the output.
        self.D = mx.ones(inner)

    def __call__(self, x, h_prev=None):
        """Run the block over a sequence.

        Args:
            x: Array of shape ``(B, T, d_model)``.
            h_prev: Optional starting state of shape
                ``(B, d_inner, d_state)``; zeros are used when omitted.

        Returns:
            Tuple ``(out, h)``: ``out`` is ``x`` plus the block output and
            ``h`` is the state after the last timestep.
        """
        batch, seq_len, _ = x.shape

        streams = self.in_proj(self.norm(x))
        u = streams[..., :self.d_inner]
        gate = streams[..., self.d_inner:]

        # Keep only the first T outputs of the padded convolution.
        u = self.conv1d(u)[:, :seq_len, :]
        # relu(u) * sigmoid(u): matches SiLU for u > 0 and is zero otherwise.
        u = mx.maximum(u, 0) * mx.sigmoid(u)  # silu-ish

        scan_in = self.x_proj(u)
        step = nn.softplus(self.dt_proj(scan_in[..., :1]))
        in_mat = scan_in[..., 1:1 + self.d_state]
        out_mat = scan_in[..., 1 + self.d_state:]
        a_neg = -mx.exp(self.A_log)

        # serial scan with persistent state
        if h_prev is None:
            state = mx.zeros((batch, self.d_inner, self.d_state))
        else:
            state = h_prev

        outputs = []
        for t in range(seq_len):
            step_t = step[:, t, :][:, :, None]      # (B, d_inner, 1)
            drive_t = u[:, t, :][:, :, None]        # (B, d_inner, 1)
            in_t = in_mat[:, t, :][:, None, :]      # (B, 1, d_state)
            out_t = out_mat[:, t, :][:, None, :]    # (B, 1, d_state)

            # Discretise A and B for this timestep.
            decay = mx.exp(step_t * a_neg[None, None, :])  # (B, d_inner, d_state)
            inject = step_t * in_t                          # (B, d_inner, d_state)

            # h_t = dA * h_{t-1} + dB * x_t
            state = decay * state + inject * drive_t
            # y_t = sum over the state axis of h_t * C_t
            y_t = (state * out_t).sum(axis=-1)              # (B, d_inner)
            outputs.append(y_t[:, None, :])

        scanned = mx.concatenate(outputs, axis=1)
        gated = (scanned + self.D * u) * mx.sigmoid(gate)
        return x + self.out_proj(gated), state
# ---------------------------------------------------------------------------
# Attention block
# ---------------------------------------------------------------------------
class AttentionBlock(nn.Module):
    """Pre-norm causal multi-head self-attention with a residual connection.

    Args:
        n_embed: Feature width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        dropout: Dropout probability applied to the output projection.
    """

    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        # Queries, keys and values come from a single fused projection.
        self.qkv = nn.Linear(n_embed, 3 * n_embed, bias=False)
        self.proj = nn.Linear(n_embed, n_embed, bias=False)
        self.norm = nn.LayerNorm(n_embed)
        self.drop = nn.Dropout(dropout)

    def __call__(self, x):
        """Attend over ``x`` of shape ``(B, T, n_embed)``; same shape out."""
        batch, seq_len, width = x.shape

        fused = self.qkv(self.norm(x))
        # (B, T, 3, heads, head_dim) -> (3, B, heads, T, head_dim)
        fused = fused.reshape(batch, seq_len, 3, self.n_head, self.head_dim)
        fused = fused.transpose(2, 0, 3, 1, 4)
        q = fused[0]
        k = fused[1]
        v = fused[2]

        logits = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(self.head_dim)

        # True strictly above the diagonal, i.e. at future positions.
        future = mx.tril(mx.ones((seq_len, seq_len))) == 0
        logits = mx.where(future, -1e9, logits)
        weights = mx.softmax(logits, axis=-1)

        # (B, heads, T, head_dim) -> (B, T, n_embed)
        mixed = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq_len, width)
        return x + self.drop(self.proj(mixed))
# ---------------------------------------------------------------------------
# FeedForward
# ---------------------------------------------------------------------------
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP (4x expansion, GELU) with a residual connection.

    Args:
        n_embed: Input and output feature width.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embed, dropout):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden, bias=False),
            nn.GELU(),
            nn.Linear(hidden, n_embed, bias=False),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(n_embed)

    def __call__(self, x):
        """Return ``x`` plus the MLP applied to the layer-normalised ``x``."""
        normed = self.norm(x)
        return x + self.net(normed)
# ---------------------------------------------------------------------------
# Hybrid pair: Attention + SSM + FFN
# ---------------------------------------------------------------------------
class HybridPair(nn.Module):
    """One hybrid layer: attention, then the SSM block, then the feed-forward.

    Args:
        n_embed: Feature width shared by the three sub-blocks.
        n_head: Number of attention heads.
        dropout: Dropout probability for the attention and feed-forward blocks.
    """

    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        self.attn = AttentionBlock(n_embed, n_head, dropout)
        # SSM hyper-parameters come from the module-level config.
        self.ssm = MambaBlock(
            n_embed,
            d_state=SSM_D_STATE,
            d_conv=SSM_D_CONV,
            expand=SSM_EXPAND,
        )
        self.ff = FeedForward(n_embed, dropout)

    def __call__(self, x, h=None):
        """Apply the three sub-blocks in order.

        Args:
            x: Input activations.
            h: Optional SSM state handed to the SSM block.

        Returns:
            Tuple of the output activations and the updated SSM state.
        """
        attended = self.attn(x)
        mixed, h = self.ssm(attended, h)
        return self.ff(mixed), h
# ---------------------------------------------------------------------------
# LeekNet 500M
# ---------------------------------------------------------------------------
class LeekNet500M(nn.Module):
    """Hybrid attention + SSM language model.

    Token and learned position embeddings feed a stack of ``HybridPair``
    layers, followed by a final layer norm and a linear head over the
    vocabulary. Each pair keeps its own SSM state.

    Args:
        vocab_size: Number of token ids.
        n_embed: Feature width.
        n_head: Attention heads per pair.
        n_pairs: Number of ``HybridPair`` layers.
        block_size: Maximum sequence length (size of the position table).
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size=N_VOCAB, n_embed=N_EMBED, n_head=N_HEAD,
                 n_pairs=N_PAIRS, block_size=BLOCK_SIZE, dropout=DROPOUT):
        super().__init__()
        self.block_size = block_size
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.pairs = [HybridPair(n_embed, n_head, dropout) for _ in range(n_pairs)]
        self.ln_final = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, idx, states=None):
        """Run the layer stack once.

        Args:
            idx: Token ids of shape ``(B, T)`` with ``T <= block_size``.
            states: Optional sequence with one SSM state (or ``None``) per
                pair. ``None`` starts every pair from a fresh state.

        Returns:
            Tuple ``(x, new_states)``: the final-normed activations of shape
            ``(B, T, n_embed)`` and the list of per-pair SSM states.

        Raises:
            ValueError: If ``T`` exceeds ``block_size`` or ``states`` does not
                hold exactly one entry per pair.
        """
        _, T = idx.shape
        # MLX indexing is not bounds-checked, so an over-long sequence would
        # read undefined position embeddings instead of failing.
        if T > self.block_size:
            raise ValueError(
                f'sequence length {T} exceeds block_size {self.block_size}'
            )

        if states is None:
            states = [None] * len(self.pairs)
        else:
            states = list(states)
            # zip() would silently skip trailing pairs on a short list.
            if len(states) != len(self.pairs):
                raise ValueError(
                    f'expected {len(self.pairs)} states, got {len(states)}'
                )

        pos = mx.arange(T)
        x = self.drop(self.tok_embed(idx) + self.pos_embed(pos))

        new_states = []
        for pair, h in zip(self.pairs, states):
            x, h = pair(x, h)
            new_states.append(h)
        x = self.ln_final(x)
        return x, new_states

    def __call__(self, idx, n_think=1):
        """Return vocabulary logits for ``idx``.

        Args:
            idx: Token ids of shape ``(B, T)``.
            n_think: Number of passes over ``idx``. Each pass after the first
                starts from the SSM states left by the previous pass; only
                the last pass's activations reach the head.

        Returns:
            Logits of shape ``(B, T, vocab_size)``.

        Raises:
            ValueError: If ``n_think`` is less than 1.
        """
        # With zero passes there would be no activations to project.
        if n_think < 1:
            raise ValueError(f'n_think must be >= 1, got {n_think}')

        states = None
        for _ in range(n_think):
            x, states = self.forward(idx, states)
        return self.lm_head(x)
# ---------------------------------------------------------------------------
# Quick sanity / param count
# ---------------------------------------------------------------------------
def info():
    """Print the model configuration, parameter count and tokenizer details.

    Builds a default ``LeekNet500M`` to count its parameters, then loads the
    SentencePiece model at ``TOKENIZER_MODEL`` to report its vocabulary size.
    """
    model = LeekNet500M()

    # Total element count across every parameter array.
    n_params = 0
    for _, arr in mlx_utils.tree_flatten(model.parameters()):
        n_params += arr.size

    print(f'\nLeekNet 500M:')
    print(f' vocab: {N_VOCAB:,}')
    print(f' n_embed: {N_EMBED}')
    print(f' n_pairs: {N_PAIRS}')
    print(f' n_head: {N_HEAD}')
    print(f' block_size: {BLOCK_SIZE}')
    print(f' parameters: {n_params/1e6:.1f}M')

    # The tokenizer is loaded only after the model summary has been printed.
    model_path = str(TOKENIZER_MODEL)
    tok = spm.SentencePieceProcessor(model_file=model_path)
    print(f' tokenizer: {TOKENIZER_MODEL.name}')
    print(f' vocab_size: {tok.vocab_size()}')
# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # The first CLI argument selects the command; no argument means `info`.
    cmd = (sys.argv[1:2] or ['info'])[0]
    if cmd != 'info':
        # Every other command is a placeholder for now.
        print(f'training entry points (train_a/b/c) will be wired in next.')
        print(f'usage: python3 leeknet_500m.py info')
    else:
        info()