Spaces:
Sleeping
Sleeping
File size: 2,772 Bytes
f62dc29 e9e19db f62dc29 e5ba9b9 f62dc29 9a74bcc e6eeb28 f62dc29 e6eeb28 f62dc29 db8631a f62dc29 c430d50 10eadd6 f62dc29 6383c22 1efbc24 6383c22 9a74bcc 6383c22 e6eeb28 a5f28af 6383c22 9a74bcc d420af1 e11f585 d598116 d420af1 e11f585 e9e19db d420af1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 | import os
import torch
import gradio as gr
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import RustBPETokenizer
# --- System Initialization ---
# Directory passed to RustBPETokenizer.from_directory to load the tokenizer.
TOKENIZER_DIR = "."
tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Map Special Tokens
# Resolve each chat-template special token to its integer id once and store it
# on the tokenizer under the attribute name that predict() reads.
for _attr_name, _special_token in (
    ("bos_token_id", "<|bos|>"),
    ("user_start_id", "<|user_start|>"),
    ("user_end_id", "<|user_end|>"),
    ("assistant_start_id", "<|assistant_start|>"),
    ("assistant_end_id", "<|assistant_end|>"),
):
    setattr(tokenizer, _attr_name, tokenizer.enc.encode_single_token(_special_token))
del _attr_name, _special_token
# Model Setup
# These hyperparameters are not read from the checkpoint. They must describe
# the architecture the weights were trained with: load_state_dict raises on
# tensor-shape mismatches even when strict=False.
config = GPTConfig(
    vocab_size=32768,
    n_layer=12,
    n_head=6,
    n_embd=768,
    sequence_len=2048,
)
model = GPT(config)

print("Loading model weights...")
state_dict = torch.load("model_000971.pt", map_location="cpu")
# Checkpoints saved from a torch.compile()'d module carry an "_orig_mod."
# prefix on every key; strip it so keys match the plain module's names.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# strict=False tolerates key mismatches, but any mismatch means some
# parameters keep their random initialization. Report it instead of silently
# serving a partially loaded model.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model key(s) missing from "
        f"checkpoint (left at random init): {load_result.missing_keys}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) not "
        f"used by the model: {load_result.unexpected_keys}"
    )
model.eval()
@torch.no_grad()
def predict(message, history):
    """Stream a single-turn assistant reply for Gradio's ChatInterface.

    Args:
        message: The latest user message; it is converted with ``str()`` and
            stripped of surrounding whitespace.
        history: Prior turns supplied by Gradio. Currently unused: every
            request is answered as a fresh single-turn conversation.

    Yields:
        The reply decoded so far (whitespace-stripped), growing by one token
        per step. If anything raises, yields a single
        "⚠️ System Error: ..." string instead.

    ``@torch.no_grad()`` is used as a decorator rather than a ``with`` block.
    On a generator function the decorator re-enters the no-grad context
    around every resumption. A ``with`` wrapped only around the call to
    ``model.generate`` would exit before a lazily evaluated generator
    produced any token. A ``with`` spanning the ``yield`` statements would
    leak grad-mode state to whatever code resumes this generator.
    """
    import traceback

    try:
        # Generation ends when the model closes its turn, starts a new user
        # turn, or starts a new document.
        stop_ids = {
            tokenizer.assistant_end_id,
            tokenizer.user_start_id,
            tokenizer.bos_token_id,
        }

        user_content = str(message).strip()
        tokens = [
            tokenizer.bos_token_id,
            tokenizer.user_start_id,
            *tokenizer.encode(user_content),
            tokenizer.user_end_id,
            tokenizer.assistant_start_id,
        ]
        output = model.generate(tokens, max_tokens=512, temperature=0.75, top_k=40)

        generated_ids = []
        for token in output:
            token_id = token if isinstance(token, int) else token.item()
            # Compare ids, not decoded text: a special token is one id.
            if token_id in stop_ids:
                break
            generated_ids.append(token_id)
            # Decode the whole reply each step. A single BPE token may hold
            # only part of a multi-byte UTF-8 character, so decoding tokens
            # one at a time and concatenating corrupts non-ASCII text.
            yield tokenizer.decode(generated_ids).strip()
    except Exception as e:
        # Top-level UI boundary: keep the app alive, but don't hide the cause.
        traceback.print_exc()
        yield f"⚠️ System Error: {str(e)}"
# --- UI Customization for Gradio 6.0 ---
# The chat interface is rendered inside a Blocks container, and that container
# (`demo`) is what gets launched; the theme is supplied to demo.launch().
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=predict,  # generator: each yielded string replaces the reply so far
        title="⚡ SimpleAI-259M",
        description=(
            "**Fast. Focused. Simple.** A lightweight general intelligence "
            "model optimized for reasoning and logic."
        ),
        examples=[
            "Explain neural network?",
            "Write a python function to calculate the area of a circle.",
            "Why is the sky blue?",
        ],
    )
if __name__ == "__main__":
    # Gradio 6 takes the theme as a launch() argument (it warns when it is
    # passed to gr.Blocks()). Listen on all network interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
    )