#!/usr/bin/env python3
"""
chat_stage_b.py — Chat with Leek using the Stage B checkpoint.

She responds to instructions now, not just text completion.
Type your message and press Enter. Type 'quit' (or 'exit' / 'q') to leave.

Usage:
    python3 chat_stage_b.py --block-size 512
    python3 chat_stage_b.py --block-size 512 --temp 0.7
"""

import argparse
import sys
from pathlib import Path

import mlx.core as mx
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm

ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))

from leeknet_500m import LeekNet500M, TOKENIZER_MODEL, CKPT_DIR, BLOCK_SIZE

# Turn marker of the prompt format built in main(). If the model emits it, it
# has finished its own turn and started writing the user's next message.
STOP_SEQUENCES = ('\nHuman:',)

# Tokens of headroom kept free when trimming the chat history.
HISTORY_MARGIN = 100


def _step_number(path):
    """Return the step encoded in a name like ``stage_b_step1200_best.npz``."""
    return int(path.stem.split('step')[1].split('_')[0])


def load_best_checkpoint(model):
    """Load the highest-step Stage B checkpoint from CKPT_DIR into ``model``.

    Files matching ``stage_b_step*_best.npz`` are preferred; only if none
    exist is any ``stage_b_step*.npz`` considered. Among the candidates the
    one with the largest step number is loaded. Exits with status 1 when no
    checkpoint is found.
    """
    ckpts = sorted(CKPT_DIR.glob('stage_b_step*_best.npz'), key=_step_number)
    if not ckpts:
        ckpts = sorted(CKPT_DIR.glob('stage_b_step*.npz'), key=_step_number)
    if not ckpts:
        print('no Stage B checkpoint found', file=sys.stderr)
        sys.exit(1)

    latest = ckpts[-1]
    print(f'loading: {latest.name}')
    # np.load returns an NpzFile that holds the archive open; the context
    # manager closes it once every array has been copied into an mx.array.
    with np.load(latest) as weights:
        model.load_weights([(k, mx.array(v)) for k, v in weights.items()])


def _truncate_at_stop(text, stops):
    """Return ``text`` cut just before the earliest occurrence of any stop string."""
    hits = [i for i in (text.find(s) for s in stops if s) if i >= 0]
    return text[:min(hits)] if hits else text


def _stop_holdback(text, stops):
    """Length of the longest suffix of ``text`` that is a proper prefix of a stop string.

    That many trailing characters must not be printed yet, because the next
    token might complete a stop sequence.
    """
    longest = 0
    for stop in stops:
        for n in range(min(len(stop) - 1, len(text)), longest, -1):
            if text.endswith(stop[:n]):
                longest = n
                break
    return longest


def generate(model, tok, prompt_ids, max_new_tokens, temperature, block_size,
             stop=()):
    """Sample a continuation of ``prompt_ids`` and stream it to stdout.

    Args:
        model: callable mapping a (1, T) int32 context to logits; presumably
            shaped (batch, time, vocab), since the last position of batch 0
            is read.
        tok: SentencePiece processor used for EOS detection and decoding.
        prompt_ids: list of prompt token ids.
        max_new_tokens: maximum number of tokens to sample.
        temperature: sampling temperature; ``<= 0`` means greedy argmax.
        block_size: the context is cropped to its last ``block_size`` tokens.
        stop: optional strings; generation ends as soon as the decoded reply
            contains one of them, and the stop text itself is not printed.

    Returns:
        The list of sampled token ids, excluding EOS. If a stop string
        matched, the ids that produced it are still included, so callers
        should trim the decoded text.
    """
    ctx = mx.array([prompt_ids], dtype=mx.int32)
    generated = []
    printed = 0      # characters of the decoded reply already written
    stopped = False

    for _ in range(max_new_tokens):
        if ctx.shape[1] > block_size:
            ctx = ctx[:, -block_size:]
        # No KV cache: the whole (cropped) window is re-run every step.
        logits = model(ctx)
        next_logits = logits[0, -1]
        if temperature <= 0.0:
            next_id = int(mx.argmax(next_logits).item())
        else:
            probs = mx.softmax(next_logits / temperature)
            mx.eval(probs)
            p = np.array(probs.tolist())
            p = p / p.sum()  # renormalise away float rounding for np.random.choice
            next_id = int(np.random.choice(len(p), p=p))

        if next_id == tok.eos_id():
            break
        generated.append(next_id)
        ctx = mx.concatenate([ctx, mx.array([[next_id]], dtype=mx.int32)], axis=1)

        # Decode only the reply (not the whole prompt) and print the new tail.
        text = tok.decode(generated)
        head = _truncate_at_stop(text, stop)
        if len(head) < len(text):
            print(head[printed:], end='', flush=True)
            stopped = True
            break
        if text.endswith('\ufffd'):
            # Probably an incomplete multi-byte character; wait for more bytes.
            continue
        safe = len(text) - _stop_holdback(text, stop)
        if safe > printed:
            print(text[printed:safe], end='', flush=True)
            printed = safe

    if not stopped:
        # Flush anything held back as a possible stop-sequence prefix.
        print(tok.decode(generated)[printed:], end='', flush=True)
    print()
    return generated


def main():
    """Parse CLI args, load the model, and run the interactive chat loop."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--block-size', type=int, default=512)
    parser.add_argument('--temp', type=float, default=0.8)
    parser.add_argument('--max-tokens', type=int, default=400)
    parser.add_argument('--system', type=str, default=None,
                        help='system prompt prepended before conversation')
    parser.add_argument('--no-system', action='store_true',
                        help='disable default system prompt')
    args = parser.parse_args()

    print('loading tokenizer...')
    tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))

    print('building model...')
    model = LeekNet500M(block_size=args.block_size)
    load_best_checkpoint(model)

    default_system = (
        "You are a helpful, direct, and honest assistant. "
        "Answer questions clearly and accurately. "
        "Be concise. Do not ramble or use flowery language."
    )
    if args.no_system:
        system = None
    elif args.system:
        system = args.system
    else:
        system = default_system

    print(f'\nready. block_size={args.block_size} temp={args.temp}')
    if system:
        print(f'system: {system}')
    print('type your message and press Enter. quit to exit.\n')

    # The system line lives outside the rolling window so trimming never drops
    # it; ``turns`` alternates Human/Assistant lines and is trimmed in pairs.
    header = [f'System: {system}'] if system else []
    turns = []
    budget = args.block_size - HISTORY_MARGIN

    while True:
        try:
            user_input = input('Human: ').strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not user_input:
            continue  # a stray Enter should not end the session
        if user_input.lower() in ('quit', 'exit', 'q'):
            break

        turns.append(f'Human: {user_input}')
        prompt = '\n'.join(header + turns) + '\nAssistant:'
        prompt_ids = tok.encode(prompt)

        print('Assistant: ', end='', flush=True)
        generated = generate(model, tok, prompt_ids, args.max_tokens, args.temp,
                             args.block_size, stop=STOP_SEQUENCES)
        response_text = _truncate_at_stop(tok.decode(generated),
                                          STOP_SEQUENCES).strip()
        turns.append(f'Assistant: {response_text}')

        # Drop the oldest Human/Assistant exchange until the transcript fits
        # the budget, always keeping the most recent exchange.
        while len(turns) > 2 and len(tok.encode('\n'.join(header + turns))) > budget:
            turns = turns[2:]


if __name__ == '__main__':
    main()