TCF-1 / chat_stage_b.py
TreeLeek's picture
Upload chat_stage_b.py with huggingface_hub
ac8476e verified
#!/usr/bin/env python3
"""
chat_stage_b.py — Chat with Leek using the Stage B checkpoint.
She responds to instructions now, not just text completion.
Type your message and press Enter. Type 'quit', 'exit', or 'q' (or submit an empty line) to exit.
Usage:
python3 chat_stage_b.py --block-size 512
python3 chat_stage_b.py --block-size 512 --temp 0.7
"""
import argparse
import sys
from pathlib import Path
import mlx.core as mx
import mlx.utils as mlx_utils
import numpy as np
import sentencepiece as spm
ROOT = Path(__file__).parent
sys.path.insert(0, str(ROOT))
from leeknet_500m import LeekNet500M, TOKENIZER_MODEL, CKPT_DIR, BLOCK_SIZE
def _checkpoint_step(path):
    """Return N from a 'stage_b_step<N>[_suffix].npz' path, or None.

    None is returned when the text after 'step' does not start with a decimal
    integer (e.g. 'stage_b_step_final.npz'), so such files are skipped rather
    than crashing the sort with ValueError.
    """
    digits = path.stem.partition('step')[2].partition('_')[0]
    return int(digits) if digits.isdecimal() else None


def _sorted_checkpoints(pattern):
    """Return CKPT_DIR files matching *pattern*, ascending by step number."""
    matches = [p for p in CKPT_DIR.glob(pattern) if _checkpoint_step(p) is not None]
    return sorted(matches, key=_checkpoint_step)


def load_best_checkpoint(model):
    """Load the highest-step Stage B checkpoint from CKPT_DIR into *model*.

    Prefers 'stage_b_step<N>_best.npz' files; if there are none, falls back to
    any 'stage_b_step<N>*.npz'. Prints a message and exits with status 1 when
    no checkpoint is found.
    """
    ckpts = _sorted_checkpoints('stage_b_step*_best.npz')
    if not ckpts:
        ckpts = _sorted_checkpoints('stage_b_step*.npz')
    if not ckpts:
        print('no Stage B checkpoint found')
        sys.exit(1)
    latest = ckpts[-1]
    print(f'loading: {latest.name}')
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager closes it once every array has been copied into MLX.
    with np.load(latest) as w:
        model.load_weights([(k, mx.array(v)) for k, v in w.items()])
def generate(model, tok, prompt_ids, max_new_tokens, temperature, block_size):
    """Sample up to *max_new_tokens* tokens after *prompt_ids*, streaming to stdout.

    Args:
        model: callable taking an int32 (1, T) token array and returning
            logits; indexed as [0, -1] below, so presumably (1, T, vocab).
        tok: SentencePieceProcessor used for eos_id() and decode().
        prompt_ids: list of prompt token ids.
        max_new_tokens: hard cap on the number of generated tokens.
        temperature: <= 0 selects greedy argmax decoding; otherwise logits are
            divided by it before softmax sampling.
        block_size: maximum context fed to the model; once exceeded, the
            oldest tokens are dropped from the left.

    Returns:
        List of generated token ids (the EOS token, if produced, is excluded).
    """
    ctx = mx.array([prompt_ids], dtype=mx.int32)
    generated = []
    # Stream by printing only the new suffix of the decoded transcript.
    # Decoding together with the prompt keeps spacing at the boundary
    # consistent with how the tokenizer renders the full text.
    text = tok.decode(prompt_ids)
    printed_len = len(text)
    for _ in range(max_new_tokens):
        if ctx.shape[1] > block_size:
            ctx = ctx[:, -block_size:]
        logits = model(ctx)
        # Compute probabilities in float32 even if the model runs in a
        # reduced-precision dtype (no-op when logits are already float32).
        next_logits = logits[0, -1].astype(mx.float32)
        if temperature <= 0.0:
            next_id = int(mx.argmax(next_logits).item())
        else:
            probs = mx.softmax(next_logits / temperature)
            mx.eval(probs)
            # np.array reads the MLX buffer directly instead of building a
            # vocab-sized Python list with tolist(); float64 plus
            # renormalisation satisfies np.random.choice's sum-to-1 check.
            p = np.array(probs, dtype=np.float64)
            p /= p.sum()
            next_id = int(np.random.choice(len(p), p=p))
        if next_id == tok.eos_id():
            break
        generated.append(next_id)
        ctx = mx.concatenate([ctx, mx.array([[next_id]], dtype=mx.int32)], axis=1)
        text = tok.decode(prompt_ids + generated)
        # A trailing U+FFFD usually means a multi-byte character is only
        # partially generated (assuming byte-fallback pieces); hold it back
        # until the next token completes it instead of printing garbage.
        if not text.endswith('\ufffd'):
            print(text[printed_len:], end='', flush=True)
            printed_len = len(text)
    # Flush anything still held back, then end the line.
    print(text[printed_len:])
    return generated
def _trim_history(history, tok, max_tokens, *, has_system=False):
    """Drop the oldest Human/Assistant turn pairs until the transcript fits.

    Args:
        history: transcript lines; when *has_system* is true, history[0] is
            the 'System: ...' line and is always preserved.
        tok: tokenizer exposing encode(str) -> list of ids.
        max_tokens: token budget for '\\n'.join(history).
        has_system: whether history[0] is a system line to keep.

    Returns:
        A new list. The most recent turn pair is never dropped, so the result
        can still exceed *max_tokens* if a single exchange is over budget.
    """
    head = history[:1] if has_system else []
    turns = history[len(head):]
    while len(turns) > 2 and len(tok.encode('\n'.join(head + turns))) > max_tokens:
        turns = turns[2:]
    return head + turns


def main():
    """Parse CLI flags, load the tokenizer and Stage B model, run the chat loop.

    The conversation is kept as plain transcript lines ('System: ...',
    'Human: ...', 'Assistant: ...'). Each turn, the joined transcript plus a
    trailing 'Assistant:' cue is encoded and used as the generation prompt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--block-size', type=int, default=512)
    parser.add_argument('--temp', type=float, default=0.8)
    parser.add_argument('--max-tokens', type=int, default=400)
    parser.add_argument('--system', type=str, default=None,
                        help='system prompt prepended before conversation')
    parser.add_argument('--no-system', action='store_true',
                        help='disable default system prompt')
    args = parser.parse_args()

    print('loading tokenizer...')
    tok = spm.SentencePieceProcessor(model_file=str(TOKENIZER_MODEL))
    print('building model...')
    model = LeekNet500M(block_size=args.block_size)
    load_best_checkpoint(model)

    default_system = (
        "You are a helpful, direct, and honest assistant. "
        "Answer questions clearly and accurately. "
        "Be concise. Do not ramble or use flowery language."
    )
    if args.no_system:
        system = None
    elif args.system:
        system = args.system
    else:
        system = default_system

    print(f'\nready. block_size={args.block_size} temp={args.temp}')
    if system:
        print(f'system: {system}')
    print('type your message and press Enter. quit to exit.\n')

    history = []
    if system:
        history.append(f'System: {system}')
    # Transcript token budget: leave ~100 tokens of headroom for the reply.
    # Replies may run to --max-tokens; if prompt + reply outgrows block_size,
    # generate() slides its context window and drops the oldest tokens.
    history_budget = args.block_size - 100

    while True:
        try:
            user_input = input('Human: ').strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if not user_input or user_input.lower() in ('quit', 'exit', 'q'):
            break
        history.append(f'Human: {user_input}')
        prompt = '\n'.join(history) + '\nAssistant:'
        prompt_ids = tok.encode(prompt)
        print('Assistant: ', end='', flush=True)
        generated = generate(model, tok, prompt_ids, args.max_tokens,
                             args.temp, args.block_size)
        response_text = tok.decode(generated).strip()
        history.append(f'Assistant: {response_text}')
        # Keep the system line; drop whole Human/Assistant pairs so turns
        # stay aligned (slicing [2:] off the front used to drop the system
        # prompt together with the first Human turn).
        history = _trim_history(history, tok, history_budget,
                                has_system=bool(system))


if __name__ == '__main__':
    main()