Text Generation
MLX
English
mamba
ssm
hybrid
transformer
from-scratch
custom-architecture
apple-silicon
Instructions to use TreeLeek/TCF-1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use TreeLeek/TCF-1 with MLX:
# Make sure mlx-lm is installed
# pip install --upgrade mlx-lm
# if on a CUDA device, also pip install mlx[cuda]

# Generate text with mlx-lm
from mlx_lm import load, generate

model, tokenizer = load("TreeLeek/TCF-1")

prompt = "Once upon a time in"
text = generate(model, tokenizer, prompt=prompt, verbose=True)
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
- MLX LM
How to use TreeLeek/TCF-1 with MLX LM:
Generate or start a chat session
# Install MLX LM
uv tool install mlx-lm
# Generate some text
mlx_lm.generate --model "TreeLeek/TCF-1" --prompt "Once upon a time"
File size: 4,502 Bytes
ac8476e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | #!/usr/bin/env python3
"""
chat_stage_b.py — Chat with Leek using the Stage B checkpoint.
She responds to instructions now, not just text completion.
Type your message, press Enter. Type 'quit' to exit.
Usage:
python3 chat_stage_b.py --block-size 512
python3 chat_stage_b.py --block-size 512 --temp 0.7
"""
import argparse
import sys
from pathlib import Path
import mlx.core as mx
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm
ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))
from leeknet_500m import LeekNet500M, TOKENIZER_MODEL, CKPT_DIR, BLOCK_SIZE
def load_best_checkpoint(model):
ckpts = sorted(CKPT_DIR.glob('stage_b_step*_best.npz'),
key=lambda p: int(p.stem.split('step')[1].split('_')[0]))
if not ckpts:
ckpts = sorted(CKPT_DIR.glob('stage_b_step*.npz'),
key=lambda p: int(p.stem.split('step')[1].split('_')[0]))
if not ckpts:
print('no Stage B checkpoint found')
sys.exit(1)
latest = ckpts[-1]
print(f'loading: {latest.name}')
w = np.load(latest)
model.load_weights([(k, mx.array(v)) for k, v in w.items()])
def generate(model, tok, prompt_ids, max_new_tokens, temperature, block_size):
ctx = mx.array([prompt_ids], dtype=mx.int32)
generated = []
for _ in range(max_new_tokens):
if ctx.shape[1] > block_size:
ctx = ctx[:, -block_size:]
logits = model(ctx)
next_logits = logits[0, -1]
if temperature <= 0.0:
next_id = int(mx.argmax(next_logits).item())
else:
next_logits = next_logits / temperature
probs = mx.softmax(next_logits)
mx.eval(probs)
p = np.array(probs.tolist())
p = p / p.sum()
next_id = int(np.random.choice(len(p), p=p))
if next_id == tok.eos_id():
break
generated.append(next_id)
ctx = mx.concatenate([ctx, mx.array([[next_id]])], axis=1)
full_text = tok.decode(prompt_ids + generated)
prev_text = tok.decode(prompt_ids + generated[:-1])
print(full_text[len(prev_text):], end='', flush=True)
print()
return generated
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--block-size', type=int, default=512)
parser.add_argument('--temp', type=float, default=0.8)
parser.add_argument('--max-tokens', type=int, default=400)
parser.add_argument('--system', type=str, default=None,
help='system prompt prepended before conversation')
parser.add_argument('--no-system', action='store_true',
help='disable default system prompt')
args = parser.parse_args()
print('loading tokenizer...')
tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
print('building model...')
model = LeekNet500M(block_size=args.block_size)
load_best_checkpoint(model)
default_system = (
"You are a helpful, direct, and honest assistant. "
"Answer questions clearly and accurately. "
"Be concise. Do not ramble or use flowery language."
)
if args.no_system:
system = None
elif args.system:
system = args.system
else:
system = default_system
print(f'\nready. block_size={args.block_size} temp={args.temp}')
if system:
print(f'system: {system}')
print('type your message and press Enter. quit to exit.\n')
history = []
if system:
history.append(f'System: {system}')
while True:
try:
user_input = input('Human: ').strip()
except (EOFError, KeyboardInterrupt):
print()
break
if not user_input or user_input.lower() in ('quit', 'exit', 'q'):
break
history.append(f'Human: {user_input}')
prompt = '\n'.join(history) + '\nAssistant:'
prompt_ids = tok.encode(prompt)
print('Assistant: ', end='', flush=True)
generated = generate(model, tok, prompt_ids, args.max_tokens, args.temp, args.block_size)
response_text = tok.decode(generated).strip()
history.append(f'Assistant: {response_text}')
# keep history from growing past block_size
while len(tok.encode('\n'.join(history))) > args.block_size - 100:
if len(history) > 2:
history = history[2:]
else:
break
# Start the interactive chat only when run as a script, not when imported.
if __name__ == '__main__':
    main()
|