# SimpleAI-259M / app.py
# suraj-self — last commit d598116
import os
import torch
import gradio as gr
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
# --- System Initialization ---
# The tokenizer files are read from the directory the app is started in.
TOKENIZER_DIR = "."
tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Map Special Tokens
# Each (attribute, tag) pair attaches the id of one chat-control token to the
# tokenizer object, so later code can refer to it by name.
for _attr_name, _tag in (
    ("bos_token_id", "<|bos|>"),
    ("user_start_id", "<|user_start|>"),
    ("user_end_id", "<|user_end|>"),
    ("assistant_start_id", "<|assistant_start|>"),
    ("assistant_end_id", "<|assistant_end|>"),
):
    setattr(tokenizer, _attr_name, tokenizer.enc.encode_single_token(_tag))
# Model Setup
# Architecture hyper-parameters; they have to match the checkpoint loaded below.
config = GPTConfig(
    vocab_size=32768,
    n_layer=12,
    n_head=6,
    n_embd=768,
    sequence_len=2048,
)
model = GPT(config)

print("Loading model weights...")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Consider passing weights_only=True once it is confirmed that
# the checkpoint contains nothing but tensors.
state_dict = torch.load("model_000971.pt", map_location="cpu")
# Drop every "_orig_mod." segment from the parameter names so they line up
# with the plain model built above (presumably the checkpoint was saved from a
# torch.compile-wrapped model -- confirm against the training script).
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# strict=False tolerates key mismatches; report them instead of silently
# continuing with parameters that were never loaded from the checkpoint.
_load_result = model.load_state_dict(state_dict, strict=False)
_missing_keys = list(getattr(_load_result, "missing_keys", None) or [])
_unexpected_keys = list(getattr(_load_result, "unexpected_keys", None) or [])
if _missing_keys:
    print(f"Warning: {len(_missing_keys)} model keys missing from checkpoint: {_missing_keys}")
if _unexpected_keys:
    print(f"Warning: {len(_unexpected_keys)} unexpected checkpoint keys ignored: {_unexpected_keys}")

# Put the model in evaluation mode for inference.
model.eval()
def predict(message, history):
    """Stream the assistant's reply to a single user message.

    Builds a one-turn prompt (BOS, the user turn, then the assistant-start
    marker), samples tokens from the model, and after every new token yields
    the reply decoded so far so the chat UI can render it incrementally.

    Args:
        message: The latest user message; converted with ``str()`` and
            stripped of surrounding whitespace.
        history: Earlier turns supplied by the chat UI. Not used: every reply
            is generated from ``message`` alone.

    Yields:
        The stripped reply text accumulated so far. If anything raises, a
        single "⚠️ System Error: ..." string is yielded instead.
    """
    # A token whose decoded text contains one of these ends the reply.
    stop_tags = ("<|assistant_end|>", "<|end|>", "<|user_start|>")
    try:
        user_content = str(message).strip()
        tokens = [
            tokenizer.bos_token_id,
            tokenizer.user_start_id,
            *tokenizer.encode(user_content),
            tokenizer.user_end_id,
            tokenizer.assistant_start_id,
        ]

        exhausted = object()  # sentinel: the token stream has ended
        # Covers the case where generate() does all of its work eagerly.
        with torch.no_grad():
            stream = iter(
                model.generate(
                    tokens,
                    max_tokens=512,
                    temperature=0.75,
                    top_k=40,
                )
            )

        generated_ids = []
        while True:
            # If generate() is lazy (a generator), the forward passes only run
            # inside next(), so no_grad has to be active here. It is re-entered
            # per step, and never held across a `yield`, so the grad mode does
            # not leak into whatever code resumes this generator.
            with torch.no_grad():
                token = next(stream, exhausted)
            if token is exhausted:
                break

            token_id = token if isinstance(token, int) else token.item()
            piece = tokenizer.decode([token_id])
            if any(tag in piece for tag in stop_tags):
                break

            # Decode the whole reply rather than concatenating per-token text:
            # a multi-byte character split across several tokens cannot be
            # decoded correctly one token at a time.
            generated_ids.append(token_id)
            yield tokenizer.decode(generated_ids).strip()
    except Exception as e:
        # Top-level UI boundary: surface the failure in the chat window.
        yield f"⚠️ System Error: {str(e)}"
# --- UI Customization for Gradio 6.0 ---
# Starter prompts offered by the chat window.
_EXAMPLE_PROMPTS = [
    "Explain neural network?",
    "Write a python function to calculate the area of a circle.",
    "Why is the sky blue?",
]
_DESCRIPTION = (
    "**Fast. Focused. Simple.** A lightweight general intelligence model "
    "optimized for reasoning and logic."
)

# A single chat interface, backed by predict(), hosted inside a Blocks app.
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=predict,
        title="⚡ SimpleAI-259M",
        description=_DESCRIPTION,
        examples=_EXAMPLE_PROMPTS,
    )
if __name__ == "__main__":
    # 'theme' is passed to launch() here, as requested by the Gradio 6.0 warning.
    soft_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
    # Listen on all interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, theme=soft_theme)