#!/usr/bin/env python3
"""
leeknet_500m.py — Scaled TCF-1 architecture for ~500M params.

Same hybrid attention + Mamba SSM design as the 36M character-level model.
Differences:
  - BPE tokenizer (vocab 32k) instead of character-level
  - Wider: n_embed 1024 (vs 512)
  - Deeper: 12 hybrid pairs (vs 4)
  - Longer context: block_size 2048 (vs 512)
  - Persistent SSM state still threads through all pairs and across tokens

Architecture (per hybrid pair):
  Attention (reasons over context)
  + Mamba SSM (holds and updates persistent state)
  + FeedForward (transforms)

Usage:
  python3 leeknet_500m.py info       # show parameter count
  python3 leeknet_500m.py train_a    # Stage A pretraining
  python3 leeknet_500m.py train_b    # Stage B SFT
  python3 leeknet_500m.py train_c    # Stage C voice imprint
  python3 leeknet_500m.py chat       # interactive
"""
import math
import json
import sys
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
ROOT = Path(__file__).parent
TOKENIZER_DIR = ROOT / 'tokenizer'
DATA_A = ROOT / 'data' / 'A_knowledge'
DATA_B = ROOT / 'data' / 'B_instruction'
VOICE_DIR = ROOT / 'memory' / 'corpus'
CKPT_DIR = ROOT / 'checkpoints_500m'
# NOTE: this runs at import time, so merely importing the module creates the
# checkpoint directory next to this file.
CKPT_DIR.mkdir(exist_ok=True)

TOKENIZER_MODEL = TOKENIZER_DIR / 'leek_bpe_32k.model'

# ---------------------------------------------------------------------------
# Config — scales from the 36M version
# ---------------------------------------------------------------------------
N_VOCAB = 32000      # from BPE tokenizer
N_EMBED = 1024       # was 512
N_HEAD = 16          # was 8
N_PAIRS = 12         # was 4
SSM_D_STATE = 16
SSM_D_CONV = 4
SSM_EXPAND = 2
DROPOUT = 0.0        # disabled — relying on data diversity
BLOCK_SIZE = 2048    # was 512

# Tools (still emitted as text — harness handles execution)
TOOLS = ['', 'query_soul', 'bash', 'read_file', 'write_file', 'query_memory']

# Training defaults — adjust per stage
BATCH_SIZE = 8
LEARN_RATE = 3e-4
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.1


# ---------------------------------------------------------------------------
# SSM block — Mamba-style selective state
# ---------------------------------------------------------------------------
class MambaBlock(nn.Module):
    """Pre-norm, residual Mamba-style selective state-space block.

    The block normalises its input, projects it to two ``d_inner``-wide
    streams (a signal path and a gate), runs the signal path through a 1-D
    convolution and an activation, and then performs a serial (step-by-step)
    state-space scan whose step size, input matrix and output matrix are all
    computed from the signal itself.

    Args:
        d_model: Width of the block's input and output features.
        d_state: Size of the recurrent state kept per inner channel.
        d_conv: Kernel size of the 1-D convolution on the signal path.
        expand: Multiplier such that ``d_inner = int(expand * d_model)``.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)

        # One projection yields both the signal path and the gate (hence * 2).
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            bias=True,
        )
        # Emits [dt (1) | B (d_state) | C (d_state)] for every position.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        # Expands the scalar dt to one step size per inner channel.
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)

        # A is stored as log(1..d_state) and used as -exp(A_log) in the scan,
        # i.e. the initial decay rates are -1, -2, ..., -d_state.  It has shape
        # (d_state,), so the same rates are shared by every inner channel.
        # NOTE(review): A_log and D are plain array attributes; presumably MLX
        # registers them as trainable parameters — confirm against
        # mlx.nn.Module.
        A = np.arange(1, d_state + 1, dtype=np.float32)
        self.A_log = mx.array(np.log(A))
        self.D = mx.ones(self.d_inner)

    def __call__(self, x, h_prev=None):
        """Run the block over a sequence.

        Args:
            x: Rank-3 input, unpacked as ``(B, T, d_model)``.
            h_prev: Optional starting state of shape
                ``(B, d_inner, d_state)``; zeros are used when it is ``None``.

        Returns:
            A tuple ``(out, h)`` where ``out`` is ``x`` plus the block's
            projected output and ``h`` is the state after the last time step.
        """
        # The feature width is not needed below; the three-way unpack is kept
        # so that a non-rank-3 input still fails here.
        B, T, _ = x.shape
        x_in = self.norm(x)

        xz = self.in_proj(x_in)
        x_, z = xz[..., :self.d_inner], xz[..., self.d_inner:]

        # The conv is built with padding=d_conv - 1, so its output is longer
        # than T; only the first T steps are kept.
        # NOTE(review): this is causal if Conv1d pads symmetrically with zeros
        # — confirm against mlx.nn.Conv1d.
        x_conv = self.conv1d(x_)[:, :T, :]
        # relu(x) * sigmoid(x): unlike true SiLU (x * sigmoid(x)) this is
        # exactly zero for negative inputs.
        x_act = mx.maximum(x_conv, 0) * mx.sigmoid(x_conv)  # silu-ish

        xproj = self.x_proj(x_act)
        dt = xproj[..., :1]
        B_ = xproj[..., 1:1+self.d_state]
        C = xproj[..., 1+self.d_state:]

        # softplus keeps the per-channel step size strictly positive.
        delta = nn.softplus(self.dt_proj(dt))
        A = -mx.exp(self.A_log)

        # serial scan with persistent state
        h = h_prev if h_prev is not None else mx.zeros((B, self.d_inner, self.d_state))
        ys = []
        for t in range(T):
            dt_t = delta[:, t, :]  # (B, d_inner)
            x_t = x_act[:, t, :]   # (B, d_inner)
            B_t = B_[:, t, :]      # (B, d_state)
            C_t = C[:, t, :]       # (B, d_state)

            # discretize A and B per timestep
            dA = mx.exp(dt_t[:, :, None] * A[None, None, :])  # (B, d_inner, d_state)
            dB = dt_t[:, :, None] * B_t[:, None, :]           # (B, d_inner, d_state)

            # state update: h_t = dA * h_{t-1} + dB * x_t
            h = dA * h + dB * x_t[:, :, None]                 # (B, d_inner, d_state)

            # output projection: y_t = sum_state(h_t * C_t)
            y = (h * C_t[:, None, :]).sum(axis=-1)            # (B, d_inner)
            ys.append(y[:, None, :])

        y_out = mx.concatenate(ys, axis=1)
        # Skip connection from the activated signal, scaled per channel by D.
        y_out = y_out + self.D * x_act
        # Gate with the second half of in_proj's output.
        y_out = y_out * mx.sigmoid(z)

        return x + self.out_proj(y_out), h


# ---------------------------------------------------------------------------
# Attention block
# ---------------------------------------------------------------------------
class AttentionBlock(nn.Module):
    """Pre-norm, residual multi-head causal self-attention.

    Args:
        n_embed: Feature width; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        dropout: Dropout probability applied to the output projection.
    """

    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        assert n_embed % n_head == 0, 'n_embed must be divisible by n_head'
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        # Single fused projection producing queries, keys and values.
        self.qkv = nn.Linear(n_embed, 3 * n_embed, bias=False)
        self.proj = nn.Linear(n_embed, n_embed, bias=False)
        self.norm = nn.LayerNorm(n_embed)
        self.drop = nn.Dropout(dropout)

    def __call__(self, x):
        """Attend over ``x`` (unpacked as ``(B, T, D)``) and add the residual."""
        B, T, D = x.shape
        x_in = self.norm(x)
        qkv = self.qkv(x_in)
        # (B, T, 3, n_head, head_dim) -> (3, B, n_head, T, head_dim)
        qkv = qkv.reshape(B, T, 3, self.n_head, self.head_dim).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(self.head_dim)
        # True strictly above the diagonal, i.e. at "future" positions.
        mask = mx.tril(mx.ones((T, T))) == 0
        # A large negative score makes masked positions vanish in the softmax.
        scores = mx.where(mask, -1e9, scores)
        attn = mx.softmax(scores, axis=-1)
        # (B, n_head, T, head_dim) -> (B, T, n_head * head_dim)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
        return x + self.drop(self.proj(out))


# ---------------------------------------------------------------------------
# FeedForward
# ---------------------------------------------------------------------------
class FeedForward(nn.Module):
    """Pre-norm, residual two-layer MLP with a 4x hidden expansion and GELU.

    Args:
        n_embed: Input and output feature width.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed, bias=False),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed, bias=False),
            nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(n_embed)

    def __call__(self, x):
        """Return ``x`` plus the MLP applied to the layer-normalised ``x``."""
        return x + self.net(self.norm(x))


# ---------------------------------------------------------------------------
# Hybrid pair: Attention + SSM + FFN
# ---------------------------------------------------------------------------
class HybridPair(nn.Module):
    """One layer of the stack: attention, then the SSM block, then the MLP.

    The SSM hyper-parameters are read from the module-level ``SSM_D_STATE``,
    ``SSM_D_CONV`` and ``SSM_EXPAND`` constants.

    Args:
        n_embed: Feature width shared by all three sub-blocks.
        n_head: Number of attention heads.
        dropout: Dropout probability for the attention and MLP sub-blocks.
    """

    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        self.attn = AttentionBlock(n_embed, n_head, dropout)
        self.ssm = MambaBlock(n_embed, SSM_D_STATE, SSM_D_CONV, SSM_EXPAND)
        self.ff = FeedForward(n_embed, dropout)

    def __call__(self, x, h=None):
        """Apply the three sub-blocks in order.

        Args:
            x: Input activations.
            h: Optional starting SSM state for this layer (``None`` lets the
                SSM block start from zeros).

        Returns:
            A tuple ``(x, h)`` of the transformed activations and the SSM
            state returned by this layer's ``MambaBlock``.
        """
        x = self.attn(x)
        x, h = self.ssm(x, h)
        x = self.ff(x)
        return x, h


# ---------------------------------------------------------------------------
# LeekNet 500M
# ---------------------------------------------------------------------------
class LeekNet500M(nn.Module):
    """Token + learned-position embeddings, a stack of hybrid pairs, LM head.

    Args:
        vocab_size: Number of rows in the token embedding / LM head outputs.
        n_embed: Feature width of the whole stack.
        n_head: Attention heads per hybrid pair.
        n_pairs: Number of hybrid pairs.
        block_size: Number of learned position embeddings, and therefore the
            longest sequence ``forward`` accepts.
        dropout: Dropout probability used after the embeddings and inside
            each pair.
    """

    def __init__(self, vocab_size=N_VOCAB, n_embed=N_EMBED, n_head=N_HEAD,
                 n_pairs=N_PAIRS, block_size=BLOCK_SIZE, dropout=DROPOUT):
        super().__init__()
        self.block_size = block_size
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.pairs = [HybridPair(n_embed, n_head, dropout) for _ in range(n_pairs)]
        self.ln_final = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, idx, states=None):
        """Run one pass through the stack, without the LM head.

        Args:
            idx: Rank-2 array of token ids, unpacked as ``(B, T)``.
            states: Optional sequence with one starting SSM state (or
                ``None``) per hybrid pair.  ``None`` starts every pair from
                its default state.

        Returns:
            A tuple ``(x, new_states)``: the final layer-normalised hidden
            states and the list of SSM states returned by each pair.

        Raises:
            ValueError: If ``T`` exceeds ``block_size`` or ``states`` does not
                hold exactly one entry per pair.
        """
        _, T = idx.shape
        # Positions index pos_embed directly; MLX does not bounds-check
        # indexing, so an over-long sequence must be rejected explicitly.
        if T > self.block_size:
            raise ValueError(
                f'sequence length {T} exceeds block_size {self.block_size}'
            )
        pos = mx.arange(T)
        x = self.drop(self.tok_embed(idx) + self.pos_embed(pos))

        if states is None:
            states = [None] * len(self.pairs)
        else:
            states = list(states)
            # zip() below would silently stop early and skip layers.
            if len(states) != len(self.pairs):
                raise ValueError(
                    f'expected {len(self.pairs)} states, got {len(states)}'
                )

        new_states = []
        for pair, h in zip(self.pairs, states):
            x, h = pair(x, h)
            new_states.append(h)

        x = self.ln_final(x)
        return x, new_states

    def __call__(self, idx, n_think=1):
        """Return logits for ``idx``.

        The stack is run ``n_think`` times over the same tokens; each pass
        starts from the SSM states the previous pass ended with, and only the
        last pass's hidden states are fed to the LM head.  The SSM states are
        not returned — call ``forward`` directly to carry them across calls.

        Args:
            idx: Rank-2 array of token ids.
            n_think: Number of passes over the stack; must be at least 1.

        Raises:
            ValueError: If ``n_think`` is less than 1.
        """
        if n_think < 1:
            raise ValueError(f'n_think must be >= 1, got {n_think}')
        states = None
        for _ in range(n_think):
            x, states = self.forward(idx, states)
        return self.lm_head(x)


# ---------------------------------------------------------------------------
# Quick sanity / param count
# ---------------------------------------------------------------------------
def info():
    """Build the default model and print its configuration and size.

    Also loads the SentencePiece model at ``TOKENIZER_MODEL`` and prints its
    vocabulary size.  The model summary is printed before the tokenizer is
    loaded.
    """
    model = LeekNet500M()
    n_params = sum(v.size for _, v in mlx_utils.tree_flatten(model.parameters()))
    print(f'\nLeekNet 500M:')
    print(f'  vocab:       {N_VOCAB:,}')
    print(f'  n_embed:     {N_EMBED}')
    print(f'  n_pairs:     {N_PAIRS}')
    print(f'  n_head:      {N_HEAD}')
    print(f'  block_size:  {BLOCK_SIZE}')
    print(f'  parameters:  {n_params/1e6:.1f}M')

    tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
    print(f'  tokenizer:   {TOKENIZER_MODEL.name}')
    print(f'  vocab_size:  {tok.vocab_size()}')


# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # With no argument the script behaves like `info`.
    cmd = sys.argv[1] if len(sys.argv) > 1 else 'info'
    if cmd == 'info':
        info()
    else:
        print(f'training entry points (train_a/b/c) will be wired in next.')
        print(f'usage: python3 leeknet_500m.py info')