# TinkyBrain-31M / chat.py
# Uploaded by Hoodrobot with huggingface_hub — commit 43321dd (verified)
#!/usr/bin/env python3
"""
AAC Micro Brain — Interactive Chat
Generates conversational responses from the trained MicroBrain model.
"""
import json
import re
import mlx.core as mx
import mlx.nn as nn
from model import MicroBrain
# Reserved special-token ids; they match SimpleTokenizer's built-in vocabulary
# (<pad>=0, <bos>=1, <eos>=2, <sep>=3, <unk>=4).
PAD, BOS, EOS, SEP, UNK = range(5)
# Directory searched for *_best.safetensors / *_tokenizer.json / *_meta.json
# checkpoint files; find_checkpoint decides which one is preferred.
CHECKPOINT_DIR = "/Volumes/PRO-G40/models/aac-micro-brain/checkpoints"
class SimpleTokenizer:
    """Word-level tokenizer over lowercase words and the punctuation . , ! ?

    Ids 0-4 are reserved for <pad>, <bos>, <eos>, <sep> and <unk>. Words that
    are not in the vocabulary encode to the <unk> id.
    """

    # Compiled once: runs of lowercase letters/apostrophes, or one punctuation mark.
    _TOKEN_RE = re.compile(r"[a-z']+|[.,!?]")

    def __init__(self):
        self.word2idx = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<sep>": 3, "<unk>": 4}
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    def encode(self, text):
        """Lowercase ``text``, split it into tokens and map each to its id (unknown -> UNK)."""
        return [self.word2idx.get(w, UNK) for w in self._TOKEN_RE.findall(text.lower())]

    def decode(self, ids):
        """Join the words for ``ids`` with single spaces.

        Special-token ids (0..UNK) are skipped; ids missing from the vocabulary
        render as "?".
        """
        return " ".join(self.idx2word.get(i, "?") for i in ids if i > UNK)

    @classmethod
    def load(cls, path):
        """Build a tokenizer from a JSON ``{word: id}`` vocabulary file.

        The file is read as UTF-8 (the JSON standard encoding) so behaviour does
        not depend on the platform's default locale encoding.
        """
        tok = cls()
        with open(path, encoding="utf-8") as f:
            tok.word2idx = json.load(f)
        tok.idx2word = {v: k for k, v in tok.word2idx.items()}
        return tok

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary.

        NOTE(review): used as the model's embedding size by callers; that is only
        correct if the ids are contiguous 0..N-1 — confirm for saved vocabularies.
        """
        return len(self.word2idx)
def generate_greedy(model, tokenizer, prompt, max_tokens=20):
    """Decode a reply by always taking the highest-scoring next token.

    The model input is ``[BOS] + encode(prompt) + [SEP]``. Tokens are appended
    until PAD, EOS or SEP is predicted or ``max_tokens`` steps have run. Only
    the tokens after the first SEP are decoded and returned.
    """
    context = [BOS, *tokenizer.encode(prompt), SEP]
    stop_ids = (PAD, EOS, SEP)
    for _step in range(max_tokens):
        # The whole sequence is re-fed each step (no cache).
        # Assumes the model returns (batch, seq, vocab) logits — TODO confirm.
        scores = model(mx.array([context]))
        best = mx.argmax(scores[0, -1, :]).item()
        if best in stop_ids:
            break
        context.append(best)
    reply_start = context.index(SEP) + 1 if SEP in context else 0
    return tokenizer.decode(context[reply_start:])
def generate_sample(model, tokenizer, prompt, max_tokens=20, temperature=0.7, top_k=5):
    """Decode a reply by sampling with temperature and optional top-k filtering.

    Args:
        model: callable mapping an ``mx.array`` of token ids to logits.
            Assumed to have shape (batch, seq, vocab) — TODO confirm.
        tokenizer: object with ``encode``/``decode`` (e.g. SimpleTokenizer).
        prompt: user text; the input is ``[BOS] + encode(prompt) + [SEP]``.
        max_tokens: maximum number of tokens to generate.
        temperature: logits are divided by this before sampling. Values
            ``<= 0`` fall back to greedy argmax instead of dividing by zero.
        top_k: when ``0 < top_k < vocab``, only the ``top_k`` highest logits
            can be sampled; otherwise no filtering is applied.

    Returns:
        The decoded text after the first SEP. Generation stops at PAD/EOS/SEP.
    """
    tokens = [BOS] + tokenizer.encode(prompt) + [SEP]
    for _ in range(max_tokens):
        logits = model(mx.array([tokens]))
        next_logits = logits[0, -1, :]
        if temperature <= 0:
            next_token = mx.argmax(next_logits).item()
        else:
            if 0 < top_k < next_logits.shape[0]:
                top_k_indices = mx.argpartition(next_logits, kth=-top_k)[-top_k:]
                masked = mx.full(next_logits.shape, float('-inf'))
                masked[top_k_indices] = next_logits[top_k_indices]
                next_logits = masked
            # mx.random.categorical expects *unnormalized log-probabilities*.
            # Passing softmax probabilities here would sample from softmax(probs):
            # a near-uniform distribution that also revives the -inf-masked tokens.
            next_token = mx.random.categorical(next_logits / temperature).item()
        if next_token in (PAD, EOS, SEP):
            break
        tokens.append(next_token)
    sep_idx = tokens.index(SEP) + 1 if SEP in tokens else 0
    return tokenizer.decode(tokens[sep_idx:])
def generate_suggestions(model, tokenizer, prompt, n=6):
    """Collect up to ``n`` distinct replies: the greedy one first, then sampled ones.

    Sampling sweeps temperatures 0.5..1.5 and, for each, top-k values 3, 5 and 8,
    in that fixed order. It stops as soon as ``n`` replies are collected.
    Empty replies and case-insensitive duplicates are discarded.
    """
    picked = []
    lowered = set()

    def _offer(text):
        # Keep non-empty replies not already seen (case-insensitive).
        if text and text.lower() not in lowered:
            picked.append(text)
            lowered.add(text.lower())

    _offer(generate_greedy(model, tokenizer, prompt))

    settings = [(t, k) for t in (0.5, 0.7, 0.9, 1.0, 1.2, 1.5) for k in (3, 5, 8)]
    for temp, k in settings:
        if len(picked) >= n:
            break
        _offer(generate_sample(model, tokenizer, prompt, temperature=temp, top_k=k))

    return picked[:n]
def find_checkpoint():
    """Locate the preferred checkpoint whose weights and tokenizer both exist.

    If ``v3_meta.json`` is present, v3 is tried before v2. Otherwise v3 is
    assumed to be still training, so v2 is tried first. The curriculum
    checkpoint is always the last resort.

    Returns:
        ``(weights_path, tokenizer_path, description)``, or
        ``(None, None, None)`` when no candidate is complete on disk.
    """
    import os

    v2 = ("full_best.safetensors", "full_tokenizer.json", "v2 (phase 1+2)")
    if os.path.exists(os.path.join(CHECKPOINT_DIR, "v3_meta.json")):
        ordered = [("v3_best.safetensors", "v3_tokenizer.json", "v3 (all phases)"), v2]
    else:
        # No v3 meta yet: treat v3 as still training and prefer the complete v2.
        ordered = [v2, ("v3_best.safetensors", "v3_tokenizer.json", "v3 (training...)")]
    ordered.append(("curriculum_best.safetensors", "curriculum_tokenizer.json", "curriculum"))

    for weights_name, tok_name, label in ordered:
        paths = (
            os.path.join(CHECKPOINT_DIR, weights_name),
            os.path.join(CHECKPOINT_DIR, tok_name),
        )
        if all(os.path.exists(p) for p in paths):
            return (*paths, label)
    return None, None, None
def load_model_config(tokenizer_path):
    """Return ``(vocab_size, n_params)`` for the checkpoint owning ``tokenizer_path``.

    Lookup order:
      1. ``<prefix>_meta.json`` beside ``<prefix>_tokenizer.json``, i.e. the
         meta file of the same checkpoint as the tokenizer;
      2. ``v3_meta.json``, then ``full_meta.json`` in CHECKPOINT_DIR (the
         previous behaviour, kept as a fallback);
      3. the tokenizer file itself, with ``n_params`` reported as 0 (unknown).

    A meta file counts only if both ``vocab_size`` and ``n_params`` are truthy.
    """
    import os

    def _read_meta(path):
        # Return (vocab_size, n_params) from a meta file, or None if absent/incomplete.
        if not os.path.exists(path):
            return None
        with open(path, encoding="utf-8") as f:
            meta = json.load(f)
        vs = meta.get("vocab_size", 0)
        np_ = meta.get("n_params", 0)
        return (vs, np_) if vs and np_ else None

    # Prefer the meta of the tokenizer's own checkpoint. Previously a v3 meta was
    # returned even when the tokenizer belonged to v2 or the curriculum run.
    path_str = os.fspath(tokenizer_path)
    suffix = "_tokenizer.json"
    if isinstance(path_str, str) and path_str.endswith(suffix):
        found = _read_meta(path_str[: -len(suffix)] + "_meta.json")
        if found:
            return found

    for name in ("v3_meta.json", "full_meta.json"):
        found = _read_meta(os.path.join(CHECKPOINT_DIR, name))
        if found:
            return found

    # Infer from tokenizer
    tok = SimpleTokenizer.load(tokenizer_path)
    return tok.vocab_size, 0
def _pick_architecture(n_params, vocab_size):
    """Return ``(d_model, n_heads, n_layers, d_ff)`` for a checkpoint.

    Heuristic thresholds on the metadata parameter count (0 when unknown) or on
    the vocabulary size. They must match the shapes the checkpoint was trained with.
    """
    if n_params > 15_000_000 or vocab_size > 5500:
        return 512, 8, 6, 1024
    if n_params > 6_000_000 or vocab_size > 4000:
        return 384, 6, 5, 768
    if n_params > 2_000_000:
        return 256, 4, 4, 512
    if n_params > 500_000:
        return 128, 4, 3, 256
    return 64, 2, 2, 128


def _read_n_params(weights_path):
    """Return ``n_params`` from the ``*_meta.json`` beside ``weights_path``, else 0."""
    import os

    meta_path = weights_path.replace("_best.safetensors", "_meta.json")
    # If the weights file does not follow the *_best.safetensors naming, the
    # replace is a no-op; never try to parse the weights file itself as JSON.
    if meta_path == weights_path or not os.path.exists(meta_path):
        return 0
    with open(meta_path, encoding="utf-8") as f:
        return json.load(f).get("n_params", 0)


def main():
    """Load the best available checkpoint and run an interactive suggestion loop."""
    weights_path, tok_path, desc = find_checkpoint()
    if not weights_path:
        print("No checkpoint found! Train a model first.")
        return
    print("=" * 50)
    print(" AAC Micro Brain — Chat")
    print("=" * 50)
    print(f"\n Loading: {desc}")
    tokenizer = SimpleTokenizer.load(tok_path)
    vocab_size = tokenizer.vocab_size
    print(f" Vocab: {vocab_size} words")

    # Auto-detect model architecture from the metadata param count / vocab size.
    n_params = _read_n_params(weights_path)
    d, h, n_layers, dff = _pick_architecture(n_params, vocab_size)
    # NOTE(review): prompt tokens + up to 20 generated tokens can exceed 32;
    # confirm MicroBrain tolerates inputs longer than max_seq_len.
    model = MicroBrain(
        vocab_size=vocab_size,
        d_model=d, n_heads=h, n_layers=n_layers, d_ff=dff,
        max_seq_len=32,
    )
    model.load_weights(weights_path)
    mx.eval(model.parameters())
    from mlx.utils import tree_flatten
    actual_params = sum(v.size for _, v in tree_flatten(model.parameters()))
    print(f" Model: {actual_params:,} params ({actual_params/1e6:.1f}M)")
    print(f" Architecture: d={d} h={h} L={n_layers}")
    print("\n Type a phrase. The model suggests responses.")
    print(" Type 'quit' to exit.\n" + "-" * 50)

    while True:
        try:
            user_input = input("\n>> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break
        if not user_input or user_input.lower() in ("quit", "exit", "q"):
            print("Bye!")
            break
        # Greedy response
        greedy = generate_greedy(model, tokenizer, user_input)
        print(f"\n Best: {greedy}")
        # Multiple suggestions. generate_suggestions puts the greedy reply first
        # only when it is non-empty, so filter it out by value instead of
        # slicing off index 0. Slicing would drop a sampled reply whenever
        # greedy came back empty.
        suggestions = generate_suggestions(model, tokenizer, user_input, n=6)
        alternatives = [s for s in suggestions if s != greedy]
        if alternatives:
            print(" Alternatives:")
            for i, s in enumerate(alternatives, 2):
                print(f" {i}. {s}")
# Start the interactive chat only when this file is executed as a script.
if __name__ == "__main__":
    main()