#!/usr/bin/env python3
"""
leeknet_500m.py: Scaled TCF-1 architecture for ~500M params.

Same hybrid attention + Mamba SSM design as the 36M character-level model.

Differences:
  - BPE tokenizer (vocab 32k) instead of character-level
  - Wider: n_embed 1024 (vs 512)
  - Deeper: 12 hybrid pairs (vs 4)
  - Longer context: block_size 2048 (vs 512)
  - Persistent SSM state still threads through all pairs and across tokens

Architecture (per hybrid pair):
  Attention      (reasons over context)
  + Mamba SSM    (holds and updates persistent state)
  + FeedForward  (transforms)

Usage:
  python3 leeknet_500m.py info      # show parameter count
  python3 leeknet_500m.py train_a   # Stage A pretraining
  python3 leeknet_500m.py train_b   # Stage B SFT
  python3 leeknet_500m.py train_c   # Stage C voice imprint
  python3 leeknet_500m.py chat      # interactive
"""
import math
import json
import sys
import time
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
ROOT = Path(__file__).parent
TOKENIZER_DIR = ROOT / 'tokenizer'
DATA_A = ROOT / 'data' / 'A_knowledge'
DATA_B = ROOT / 'data' / 'B_instruction'
VOICE_DIR = ROOT / 'memory' / 'corpus'
CKPT_DIR = ROOT / 'checkpoints_500m'
CKPT_DIR.mkdir(exist_ok=True)
TOKENIZER_MODEL = TOKENIZER_DIR / 'leek_bpe_32k.model'

# ---------------------------------------------------------------------------
# Config: scaled up from the 36M version
# ---------------------------------------------------------------------------
N_VOCAB = 32000      # from BPE tokenizer
N_EMBED = 1024       # was 512
N_HEAD = 16          # was 8
N_PAIRS = 12         # was 4
SSM_D_STATE = 16
SSM_D_CONV = 4
SSM_EXPAND = 2
DROPOUT = 0.0        # disabled; relying on data diversity
BLOCK_SIZE = 2048    # was 512

# Tools (still emitted as text; the harness handles execution)
TOOLS = ['<none>', 'query_soul', 'bash', 'read_file', 'write_file', 'query_memory']

# Training defaults; adjust per stage
BATCH_SIZE = 8
LEARN_RATE = 3e-4
WARMUP_STEPS = 500
WEIGHT_DECAY = 0.1

# ---------------------------------------------------------------------------
# SSM block: Mamba-style selective state
# ---------------------------------------------------------------------------
class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = int(expand * d_model)
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Padding of d_conv - 1, together with the [:, :T] slice in __call__,
        # keeps the convolution causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            padding=d_conv - 1,
            bias=True,
        )
        # x_proj emits one dt scalar plus B and C (d_state each) per position.
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
        self.dt_proj = nn.Linear(1, self.d_inner, bias=True)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
        self.norm = nn.LayerNorm(d_model)
        A = np.arange(1, d_state + 1, dtype=np.float32)
        self.A_log = mx.array(np.log(A))
        self.D = mx.ones(self.d_inner)

    def __call__(self, x, h_prev=None):
        B, T, D = x.shape
        x_in = self.norm(x)
        xz = self.in_proj(x_in)
        x_, z = xz[..., :self.d_inner], xz[..., self.d_inner:]
        x_conv = self.conv1d(x_)[:, :T, :]
        # SiLU-like gate: relu(x) * sigmoid(x), not the exact x * sigmoid(x)
        x_act = mx.maximum(x_conv, 0) * mx.sigmoid(x_conv)
        xproj = self.x_proj(x_act)
        dt = xproj[..., :1]
        B_ = xproj[..., 1:1 + self.d_state]
        C = xproj[..., 1 + self.d_state:]
        delta = nn.softplus(self.dt_proj(dt))
        A = -mx.exp(self.A_log)

        # serial scan with persistent state
        h = h_prev if h_prev is not None else mx.zeros((B, self.d_inner, self.d_state))
        ys = []
        for t in range(T):
            dt_t = delta[:, t, :]   # (B, d_inner)
            x_t = x_act[:, t, :]    # (B, d_inner)
            B_t = B_[:, t, :]       # (B, d_state)
            C_t = C[:, t, :]        # (B, d_state)
            # discretize A and B per timestep
            dA = mx.exp(dt_t[:, :, None] * A[None, None, :])   # (B, d_inner, d_state)
            dB = dt_t[:, :, None] * B_t[:, None, :]            # (B, d_inner, d_state)
            # state update: h_t = dA * h_{t-1} + dB * x_t
            h = dA * h + dB * x_t[:, :, None]                  # (B, d_inner, d_state)
            # output projection: y_t = sum_state(h_t * C_t)
            y = (h * C_t[:, None, :]).sum(axis=-1)             # (B, d_inner)
            ys.append(y[:, None, :])
        y_out = mx.concatenate(ys, axis=1)
        y_out = y_out + self.D * x_act      # skip connection scaled by D
        y_out = y_out * mx.sigmoid(z)       # output gate
        # Return the final state so the caller can carry it forward.
        return x + self.out_proj(y_out), h
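
The serial scan in `MambaBlock.__call__` applies the recurrence h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t, followed by y_t = sum(h_t * C_t). The following is a minimal pure-Python sketch of that recurrence for a single channel. The function name `scan_channel` and the toy inputs are hypothetical and not part of leeknet_500m.py. It illustrates why returning `h` matters: a sequence processed in two chunks, with the state carried across, should give the same outputs as one pass.

```python
import math


def scan_channel(xs, deltas, a, bs, cs, h=None):
    """Selective scan for one channel (illustrative helper, not from the source file).

    xs, deltas: length-T lists of floats (input and step size per timestep)
    a:          length-N list (the diagonal of A, negative in the model)
    bs, cs:     T x N nested lists (input-dependent B_t and C_t)
    h:          optional length-N initial state; zeros if omitted
    """
    h = [0.0] * len(a) if h is None else list(h)
    ys = []
    for x_t, dt, b_t, c_t in zip(xs, deltas, bs, cs):
        # h_t = exp(dt * A) * h_{t-1} + dt * B_t * x_t
        h = [math.exp(dt * a_i) * h_i + dt * b_i * x_t
             for a_i, h_i, b_i in zip(a, h, b_t)]
        # y_t = sum over the state dimension of h_t * C_t
        ys.append(sum(h_i * c_i for h_i, c_i in zip(h, c_t)))
    return ys, h


# Toy inputs: T = 4 timesteps, N = 2 state dimensions.
xs = [1.0, -0.5, 2.0, 0.25]
deltas = [0.1, 0.2, 0.05, 0.3]
a = [-1.0, -2.0]
bs = [[0.5, 1.0], [1.0, -1.0], [0.2, 0.3], [0.7, 0.1]]
cs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, 1.0]]

# One pass over the whole sequence...
ys_full, h_full = scan_channel(xs, deltas, a, bs, cs)
# ...versus two chunks with the state handed from the first to the second.
ys_a, h_a = scan_channel(xs[:2], deltas[:2], a, bs[:2], cs[:2])
ys_b, h_b = scan_channel(xs[2:], deltas[2:], a, bs[2:], cs[2:], h=h_a)
```

In the model the same update runs for every batch element and every one of the `d_inner` channels at once through broadcasting. The sketch drops those two axes to keep the arithmetic visible.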

# ---------------------------------------------------------------------------
# Attention block
# ---------------------------------------------------------------------------
class AttentionBlock(nn.Module):
    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        assert n_embed % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embed // n_head
        self.qkv = nn.Linear(n_embed, 3 * n_embed, bias=False)
        self.proj = nn.Linear(n_embed, n_embed, bias=False)
        self.norm = nn.LayerNorm(n_embed)
        self.drop = nn.Dropout(dropout)

    def __call__(self, x):
        B, T, D = x.shape
        x_in = self.norm(x)
        qkv = self.qkv(x_in)
        qkv = qkv.reshape(B, T, 3, self.n_head, self.head_dim).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        scores = (q @ k.transpose(0, 1, 3, 2)) / math.sqrt(self.head_dim)
        mask = mx.tril(mx.ones((T, T))) == 0
        scores = mx.where(mask, -1e9, scores)
        attn = mx.softmax(scores, axis=-1)
        out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, D)
        return x + self.drop(self.proj(out))
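
`AttentionBlock` enforces causality by overwriting scores above the diagonal with `-1e9` before the softmax. A small pure-Python sketch of that masking step follows. The helper name `causal_attention_weights` is hypothetical and the scores are toy values. It shows that each query position ends up with zero weight on later positions.

```python
import math


def causal_attention_weights(scores):
    """Row-wise softmax over a T x T score matrix with future keys masked.

    scores[i][j] is the raw score of query i against key j. Illustrative
    helper only; the model does this with mx.tril, mx.where and mx.softmax.
    """
    T = len(scores)
    weights = []
    for i in range(T):
        # Keys j > i lie in the future, so replace their scores with -1e9.
        row = [scores[i][j] if j <= i else -1e9 for j in range(T)]
        m = max(row)                         # subtract the max for stability
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        weights.append([e / z for e in exps])
    return weights


# With all-equal scores, each row spreads its weight evenly over the
# positions it is allowed to see.
w = causal_attention_weights([[0.0] * 3 for _ in range(3)])
```

Here `math.exp(-1e9)` underflows to exactly 0.0, so the masked positions contribute nothing to the normalization.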

# ---------------------------------------------------------------------------
# FeedForward
# ---------------------------------------------------------------------------
class FeedForward(nn.Module):
    def __init__(self, n_embed, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed, bias=False),
            nn.GELU(),
            nn.Linear(4 * n_embed, n_embed, bias=False),
            nn.Dropout(dropout),
        )
        self.norm = nn.LayerNorm(n_embed)

    def __call__(self, x):
        return x + self.net(self.norm(x))


# ---------------------------------------------------------------------------
# Hybrid pair: Attention + SSM + FFN
# ---------------------------------------------------------------------------
class HybridPair(nn.Module):
    def __init__(self, n_embed, n_head, dropout):
        super().__init__()
        self.attn = AttentionBlock(n_embed, n_head, dropout)
        self.ssm = MambaBlock(n_embed, SSM_D_STATE, SSM_D_CONV, SSM_EXPAND)
        self.ff = FeedForward(n_embed, dropout)

    def __call__(self, x, h=None):
        x = self.attn(x)
        x, h = self.ssm(x, h)
        x = self.ff(x)
        return x, h

# ---------------------------------------------------------------------------
# LeekNet 500M
# ---------------------------------------------------------------------------
class LeekNet500M(nn.Module):
    def __init__(self, vocab_size=N_VOCAB, n_embed=N_EMBED, n_head=N_HEAD,
                 n_pairs=N_PAIRS, block_size=BLOCK_SIZE, dropout=DROPOUT):
        super().__init__()
        self.block_size = block_size
        self.tok_embed = nn.Embedding(vocab_size, n_embed)
        self.pos_embed = nn.Embedding(block_size, n_embed)
        self.drop = nn.Dropout(dropout)
        self.pairs = [HybridPair(n_embed, n_head, dropout) for _ in range(n_pairs)]
        self.ln_final = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size, bias=False)

    def forward(self, idx, states=None):
        # idx: (B, T) token ids with T <= block_size.
        # states: one SSM state per hybrid pair, or None to start from zeros.
        B, T = idx.shape
        pos = mx.arange(T)
        x = self.drop(self.tok_embed(idx) + self.pos_embed(pos))
        if states is None:
            states = [None] * len(self.pairs)
        new_states = []
        for pair, h in zip(self.pairs, states):
            x, h = pair(x, h)
            new_states.append(h)
        x = self.ln_final(x)
        return x, new_states

    def __call__(self, idx, n_think=1):
        # Each pass after the first re-reads idx, starting from the SSM states
        # left by the previous pass. Only the last pass's activations reach lm_head.
        states = None
        for _ in range(n_think):
            x, states = self.forward(idx, states)
        return self.lm_head(x)

# ---------------------------------------------------------------------------
# Quick sanity / param count
# ---------------------------------------------------------------------------
def info():
    model = LeekNet500M()
    n_params = sum(v.size for _, v in mlx_utils.tree_flatten(model.parameters()))
    print('\nLeekNet 500M:')
    print(f' vocab: {N_VOCAB:,}')
    print(f' n_embed: {N_EMBED}')
    print(f' n_pairs: {N_PAIRS}')
    print(f' n_head: {N_HEAD}')
    print(f' block_size: {BLOCK_SIZE}')
    print(f' parameters: {n_params/1e6:.1f}M')
    tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
    print(f' tokenizer: {TOKENIZER_MODEL.name}')
    print(f' vocab_size: {tok.vocab_size()}')
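
The "~500M" figure can be cross-checked by hand from the layer shapes, without building the model. The sketch below is a hypothetical helper (`estimate_params` is not part of leeknet_500m.py) and rests on two assumptions: that `nn.Conv1d` as constructed in `MambaBlock` is a dense convolution with a weight of `d_inner * d_conv * d_inner` entries, and that `A_log` and `D` are counted as parameters. It has not been compared against the output of `info()`.

```python
def estimate_params(vocab=32000, d=1024, n_pairs=12, block=2048,
                    d_state=16, d_conv=4, expand=2):
    """Back-of-the-envelope parameter count from the config constants."""
    d_inner = expand * d
    ln = 2 * d                                   # LayerNorm weight + bias

    attn = d * 3 * d + d * d + ln                # qkv + proj + norm
    mamba = (
        d * 2 * d_inner                          # in_proj
        + d_inner * d_conv * d_inner + d_inner   # dense conv1d weight + bias
        + d_inner * (2 * d_state + 1)            # x_proj (dt, B, C)
        + d_inner + d_inner                      # dt_proj weight (1 -> d_inner) + bias
        + d_inner * d                            # out_proj
        + ln                                     # norm
        + d_state + d_inner                      # A_log and D
    )
    ff = 2 * (d * 4 * d) + ln                    # two bias-free linears + norm
    per_pair = attn + mamba + ff

    embed = vocab * d + block * d                # token + position embeddings
    head = ln + d * vocab                        # ln_final + lm_head
    return {
        'attn': attn,
        'mamba': mamba,
        'ff': ff,
        'per_pair': per_pair,
        'total': n_pairs * per_pair + embed + head,
    }


counts = estimate_params()
print(f"estimated total: {counts['total'] / 1e6:.1f}M")
```

Under these assumptions the total comes to roughly 496M. The dense convolution contributes about 16.8M of the roughly 23.1M parameters in each Mamba block, which makes it the largest single layer in a hybrid pair.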

# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    cmd = sys.argv[1] if len(sys.argv) > 1 else 'info'
    if cmd == 'info':
        info()
    else:
        print('training entry points (train_a/b/c) will be wired in next.')
        print('usage: python3 leeknet_500m.py info')