import os
import sys
import json
import pickle
import torch
import gradio as gr
from contextlib import nullcontext

# Add current directory to path so we can import nanochat
sys.path.append(os.path.dirname(__file__))

from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.engine import Engine

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
DEVICE = "cpu"  # Hugging Face Free Tier is CPU only
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "MGow/PicoChat")
MODEL_FILENAME = "model.pt"
META_FILENAME = "meta.json"
TOKENIZER_FILENAME = "tokenizer.pkl"

print(f"Initializing PicoChat on {DEVICE}...")

# -----------------------------------------------------------------------------
# Load Components
# -----------------------------------------------------------------------------
from huggingface_hub import hf_hub_download

def get_file_path(filename):
    """Download a file from the HF Hub, or return the local path if it already exists."""
    if os.path.exists(filename):
        return filename
    print(f"Downloading {filename} from {HF_MODEL_REPO}...")
    try:
        return hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Fallback for testing/building if files are local
        return filename

# 1. Load Metadata
meta_path = get_file_path(META_FILENAME)
print(f"Loading metadata from {meta_path}...")
with open(meta_path, "r") as f:
    meta = json.load(f)
model_config = meta["model_config"]
print(f"Model config: {model_config}")

# 2. Load Tokenizer
tok_path = get_file_path(TOKENIZER_FILENAME)
print(f"Loading tokenizer from {tok_path}...")
with open(tok_path, "rb") as f:
    # tokenizer.pkl contains the pickled tiktoken Encoding object
    enc = pickle.load(f)
# Reconstruct RustBPETokenizer (a wrapper around tiktoken);
# <|bos|> is used as the start token
tokenizer = RustBPETokenizer(enc, "<|bos|>")

# 3. Load Model
model_path = get_file_path(MODEL_FILENAME)
print(f"Loading model from {model_path}...")

# Initialize model with config
model = GPT(GPTConfig(**model_config))

# Load the state dict; map_location=DEVICE ensures we load directly to CPU
state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
# Strip the torch.compile prefix if present (remove "_orig_mod.")
state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
# Cast to float32 for CPU (bfloat16 is poorly supported on many CPUs, and float32 is safer for inference)
state_dict = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in state_dict.items()}

# Load weights
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")
# 4. Create Engine
engine = Engine(model, tokenizer)

# -----------------------------------------------------------------------------
# Chat Logic
# -----------------------------------------------------------------------------
def chat_function(message, history):
    """
    message: str, the current user message
    history: list of [user_msg, bot_msg] pairs from previous turns
    """
    # Prepare special tokens
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    # Build conversation tokens
    conversation_tokens = [bos]

    # Add history
    for user_msg, assistant_msg in history:
        if user_msg:
            conversation_tokens.append(user_start)
            conversation_tokens.extend(tokenizer.encode(user_msg))
            conversation_tokens.append(user_end)
        if assistant_msg:
            conversation_tokens.append(assistant_start)
            conversation_tokens.extend(tokenizer.encode(assistant_msg))
            conversation_tokens.append(assistant_end)

    # Add the current message
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(message))
    conversation_tokens.append(user_end)

    # Prime the assistant turn
    conversation_tokens.append(assistant_start)

    # Generation parameters
    generate_kwargs = {
        "num_samples": 1,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_k": 50,
    }

    response_text = ""

    # Stream the generation; Engine.generate yields (token_column, token_masks)
    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
        token = token_column[0]

        # Stop once the assistant turn ends
        if token == assistant_end:
            break

        # Decode and append
        text_chunk = tokenizer.decode([token])
        response_text += text_chunk

        # Yield the partial response for the streaming UI
        yield response_text

# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
custom_css = """
.gradio-container { font-family: 'Inter', sans-serif; }
"""

demo = gr.ChatInterface(
    fn=chat_function,
    title="PicoChat",
    description="""
**PicoChat** is a 335M parameter model trained from scratch on a MacBook Air M2.
It is based on the **NanoChat** framework built by King Andrej Karpathy, and ported to Apple Silicon by Duke Michal Gow.
It knows how to chat, do basic math, and tell stories.
It is NOT ChatGPT (it's much smaller), but it runs purely on CPU here.
    """,
    examples=[
        "Tell me a story about a robot named beep.",
        "What is 25 * 12?",
        "Explain gravity to a 5 year old.",
        "Write a python function to calculate fibonacci.",
    ],
    cache_examples=False,
    css=custom_css,
)

if __name__ == "__main__":
    demo.launch()
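
# -----------------------------------------------------------------------------
# Usage (a sketch, not executed by the app)
# -----------------------------------------------------------------------------
# The notes below only illustrate how this script is assumed to be run; the file
# name app.py and the exact dependency list are assumptions, not taken from the repo.
#
#   pip install torch gradio huggingface_hub tiktoken   # plus the nanochat sources on sys.path
#   python app.py                                        # pulls model.pt / meta.json / tokenizer.pkl from HF_MODEL_REPO
#   HF_MODEL_REPO=MGow/PicoChat python app.py            # or point at a different checkpoint repo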