import os
import sys
import torch
import json
import gradio as gr
from contextlib import nullcontext

# Add current directory to path so we can import nanochat
sys.path.append(os.path.dirname(__file__))

from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
from nanochat.engine import Engine

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
DEVICE = "cpu" # Hugging Face Free Tier is CPU only
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "MGow/PicoChat")
MODEL_FILENAME = "model.pt"
META_FILENAME = "meta.json"
TOKENIZER_FILENAME = "tokenizer.pkl"

print(f"Initializing PicoChat on {DEVICE}...")

# -----------------------------------------------------------------------------
# Load Components
# -----------------------------------------------------------------------------

from huggingface_hub import hf_hub_download
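# hf_hub_download caches files locally, so repeated launches don't re-download them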

def get_file_path(filename):
    """Return a local path for filename, downloading it from the HF Hub if needed."""
    if os.path.exists(filename):
        return filename
    print(f"Downloading {filename} from {HF_MODEL_REPO}...")
    try:
        return hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename)
    except Exception as e:
        print(f"Error downloading {filename}: {e}")
        # Fall back to the bare filename (useful for local builds/tests);
        # the subsequent open() will fail loudly if the file is truly missing.
        return filename

# 1. Load Metadata
meta_path = get_file_path(META_FILENAME)
print(f"Loading metadata from {meta_path}...")
with open(meta_path, "r") as f:
    meta = json.load(f)
model_config = meta["model_config"]
print(f"Model config: {model_config}")

# 2. Load Tokenizer
tok_path = get_file_path(TOKENIZER_FILENAME)
print(f"Loading tokenizer from {tok_path}...")
import pickle
with open(tok_path, "rb") as f:
    # tokenizer.pkl contains the pickled tiktoken Encoding object
    enc = pickle.load(f)
# Re-construct RustBPETokenizer (wrapper around tiktoken)
# We use <|bos|> as the start token
tokenizer = RustBPETokenizer(enc, "<|bos|>")

# 3. Load Model
model_path = get_file_path(MODEL_FILENAME)
print(f"Loading model from {model_path}...")
# Initialize model with config
model = GPT(GPTConfig(**model_config))

# Load state dict
# map_location=DEVICE ensures we load directly to CPU
state_dict = torch.load(model_path, map_location=DEVICE, weights_only=True)

# Fix torch compile prefix if present (remove _orig_mod.)
state_dict = {k.removeprefix("_orig_mod."): v for k, v in state_dict.items()}
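# (note: str.removeprefix requires Python 3.9+)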

# Cast bfloat16 weights to float32: bfloat16 kernels are slow or missing on many CPUs,
# and float32 is the safer default for CPU inference
state_dict = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in state_dict.items()}

# Load weights
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()

print("Model loaded successfully!")

# 4. Create Engine
engine = Engine(model, tokenizer)
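# The Engine drives autoregressive, token-by-token sampling from the model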

# -----------------------------------------------------------------------------
# Chat Logic
# -----------------------------------------------------------------------------

def chat_function(message, history):
    """
    message: str, current user message
    history: list of [user_msg, bot_msg] from previous turns
    """

    # Prepare special tokens
    bos = tokenizer.get_bos_token_id()
    user_start = tokenizer.encode_special("<|user_start|>")
    user_end = tokenizer.encode_special("<|user_end|>")
    assistant_start = tokenizer.encode_special("<|assistant_start|>")
    assistant_end = tokenizer.encode_special("<|assistant_end|>")

    # Build conversation tokens
    conversation_tokens = [bos]

    # Add history
    for user_msg, assistant_msg in history:
        if user_msg:
            conversation_tokens.append(user_start)
            conversation_tokens.extend(tokenizer.encode(user_msg))
            conversation_tokens.append(user_end)
        if assistant_msg:
            conversation_tokens.append(assistant_start)
            conversation_tokens.extend(tokenizer.encode(assistant_msg))
            conversation_tokens.append(assistant_end)

    # Add current message
    conversation_tokens.append(user_start)
    conversation_tokens.extend(tokenizer.encode(message))
    conversation_tokens.append(user_end)

    # Prime assistant
    conversation_tokens.append(assistant_start)

    # Generation parameters
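    # (num_samples=1 keeps a single stream; temperature<1 sharpens sampling; top_k limits choices to the 50 most likely tokens)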
    generate_kwargs = {
        "num_samples": 1,
        "max_tokens": 512,
        "temperature": 0.7,
        "top_k": 50,
    }

    response_text = ""

    # Generate stream
    # Engine.generate yields (token_column, token_masks)
    for token_column, token_masks in engine.generate(conversation_tokens, **generate_kwargs):
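        # num_samples=1, so each column holds a single token; take it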
        token = token_column[0]

        # Stop if assistant ends
        if token == assistant_end:
            break

        # Decode and append
        text_chunk = tokenizer.decode([token])
        response_text += text_chunk

        # Yield partial response for streaming UI
        yield response_text

# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------

custom_css = """
.gradio-container {
    font-family: 'Inter', sans-serif;
}
"""

demo = gr.ChatInterface(
    fn=chat_function,
    title="PicoChat",
    description="""
    **PicoChat** is a 335M parameter model trained from scratch on a MacBook Air M2.
    It is based on the **NanoChat** framework built by King Andrej Karpathy,
    and ported to Apple Silicon by Duke Michal Gow.
    It knows how to chat, do basic math, and tell stories.
    It is NOT ChatGPT (it's much smaller), but it runs purely on CPU here.
    """,
    examples=[
        "Tell me a story about a robot named beep.",
        "What is 25 * 12?",
        "Explain gravity to a 5 year old.",
        "Write a python function to calculate fibonacci."
    ],
    cache_examples=False,
    css=custom_css,  # apply the custom font defined above
)

if __name__ == "__main__":
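    # Locally this serves on http://127.0.0.1:7860 by default; on HF Spaces the host/port come from environment variables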
    demo.launch()