"""PicoChat: Gradio chat demo serving a nanochat GPT checkpoint on CPU."""
import os
import sys
import torch
import json
import gradio as gr
from contextlib import nullcontext
# Add current directory to path so we can import nanochat
sys.path.append(os.path.dirname(__file__))
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.engine import Engine
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
DEVICE = "cpu" # Hugging Face Free Tier is CPU only
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "MGow/PicoChat")  # overridable via env var
MODEL_FILENAME = "model.pt"  # torch state dict checkpoint
META_FILENAME = "meta.json"  # contains "model_config" used to rebuild GPTConfig
TOKENIZER_FILENAME = "tokenizer.pkl"  # pickled tiktoken Encoding object
print(f"Initializing PicoChat on {DEVICE}...")
# -----------------------------------------------------------------------------
# Load Components
# -----------------------------------------------------------------------------
from huggingface_hub import hf_hub_download
def get_file_path(filename):
    """Return a local path for *filename*, downloading it from the HF Hub if needed.

    If the file already exists in the working directory, that path is returned
    directly. Otherwise we try to fetch it from HF_MODEL_REPO and return the
    cached path hf_hub_download gives us. On any download failure we fall back
    to returning the (possibly missing) local name so the caller's open()
    produces a clear error instead of crashing here.
    """
    if os.path.exists(filename):
        return filename
    # BUG FIX: these messages printed the literal "(unknown)" instead of the
    # actual filename being fetched.
    print(f"Downloading {filename} from {HF_MODEL_REPO}...")
    try:
        return hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
    except Exception as e:
        # Broad catch is deliberate best-effort: network/auth errors should
        # not crash startup when the file might still be present locally.
        print(f"Error downloading {filename}: {e}")
        # Fallback for testing/building if files are local
        return filename
# 1. Load Metadata
# meta.json is expected to carry a "model_config" dict whose keys match
# GPTConfig's constructor arguments.
meta_path = get_file_path(META_FILENAME)
print(f"Loading metadata from {meta_path}...")
with open(meta_path, "r") as f:
    meta = json.load(f)
model_config = meta["model_config"]
print(f"Model config: {model_config}")
# 2. Load Tokenizer
tok_path = get_file_path(TOKENIZER_FILENAME)
print(f"Loading tokenizer from {tok_path}...")
with open(tok_path, "rb") as f:
    import pickle
    # The tokenizer.pkl contains the tiktoken Encoding object
    # SECURITY NOTE(review): pickle.load executes arbitrary code from the
    # file — acceptable only because the artifact comes from the trusted
    # HF_MODEL_REPO; do not point this at untrusted repos.
    enc = pickle.load(f)
# Re-construct RustBPETokenizer (wrapper around tiktoken)
# We use <|bos|> as the start token
tokenizer = RustBPETokenizer(enc, "<|bos|>")
# 3. Load Model
model_path = get_file_path(MODEL_FILENAME)
print(f"Loading model from {model_path}...")
# Initialize model with config
model = GPT(GPTConfig(**model_config))
# Load state dict
# map_location=DEVICE ensures we load directly to CPU
# weights_only=True restricts torch.load to tensor data (no pickled code)
state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)
# Fix torch compile prefix if present (remove _orig_mod.)
state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
# Ensure float32 for CPU (bfloat16 not supported on all CPUs perfectly, and float32 is safer for inference)
state_dict = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in state_dict.items()}
# Load weights
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()
print("Model loaded successfully!")
# 4. Create Engine
engine = Engine(model, tokenizer)
# -----------------------------------------------------------------------------
# Chat Logic
# -----------------------------------------------------------------------------
def chat_function(message, history):
    """Stream an assistant reply for `message` given the Gradio chat `history`.

    history is a list of [user_msg, bot_msg] pairs from previous turns.
    Yields the accumulated response text after each generated token so the
    UI can render it incrementally.
    """
    encode = tokenizer.encode
    special = tokenizer.encode_special
    user_start, user_end = special("<|user_start|>"), special("<|user_end|>")
    asst_start, asst_end = special("<|assistant_start|>"), special("<|assistant_end|>")

    def turn(text, start, end):
        # One delimited conversation turn as a flat token list.
        return [start] + encode(text) + [end]

    # Conversation starts with the BOS token, then alternating delimited turns.
    tokens = [tokenizer.get_bos_token_id()]
    for user_msg, assistant_msg in history:
        if user_msg:
            tokens += turn(user_msg, user_start, user_end)
        if assistant_msg:
            tokens += turn(assistant_msg, asst_start, asst_end)
    tokens += turn(message, user_start, user_end)
    # Prime the model to answer as the assistant.
    tokens.append(asst_start)

    reply = ""
    # Engine.generate yields (token_column, token_masks); we sample a single
    # stream, so only column index 0 matters and the masks are unused.
    for token_column, _ in engine.generate(
        tokens, num_samples=1, max_tokens=512, temperature=0.7, top_k=50
    ):
        token = token_column[0]
        if token == asst_end:
            # Assistant closed its turn — stop generating.
            break
        reply += tokenizer.decode([token])
        # Yield the partial response for the streaming UI.
        yield reply
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
"""
demo = gr.ChatInterface(
    fn=chat_function,
    title="PicoChat",
    description="""
    **PicoChat** is a 335M parameter model trained from scratch on a MacBook Air M2.
    It is based on the **NanoChat** framework built by King Andrej Karpathy,
    and ported to Apple Silicon by Duke Michal Gow.
    It knows how to chat, do basic math, and tell stories.
    It is NOT ChatGPT (it's much smaller), but it runs purely on CPU here.
    """,
    examples=[
        "Tell me a story about a robot named beep.",
        "What is 25 * 12?",
        "Explain gravity to a 5 year old.",
        "Write a python function to calculate fibonacci."
    ],
    cache_examples=False,
    # BUG FIX: custom_css was defined above but never passed in, so the
    # intended font styling was dead code.
    css=custom_css,
)
if __name__ == "__main__":
    demo.launch()