# Source: Hugging Face Space "TMVishnu/nanochat-distill-d12-int8" — app.py
# (commit 52b32be, verified). The original upload carried web-viewer chrome
# here ("TMVishnu's picture / Update app.py / 52b32be verified"), which is
# not Python and would break the script; preserved as a comment instead.
import gradio as gr
import torch
import sys
import os
import pickle
from huggingface_hub import hf_hub_download
sys.path.insert(0, os.path.dirname(__file__))
from nanochat.gpt import GPT, GPTConfig
# --- One-time startup: fetch artifacts, rebuild the model, load the tokenizer ---
_REPO_ID = "TMVishnu/nanochat-distill-d12-int8"

print("Downloading model files...")
# hf_hub_download caches locally and returns the on-disk path for each file.
model_path, tokenizer_pkl_path, token_bytes_path = (
    hf_hub_download(repo_id=_REPO_ID, filename=fname)
    for fname in ("model.pt", "tokenizer.pkl", "token_bytes.pt")
)

device = "cpu"
print(f"Loading model on {device}")
# NOTE(review): weights_only=False unpickles arbitrary objects from the
# checkpoint — acceptable only because the repo is the author's own.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
quantized_weights = checkpoint["quantized_weights"]
scales = checkpoint["scales"]
config_dict = checkpoint["config"]
bits = checkpoint["bits"]

print(f"Dequantizing INT{bits} weights...")
# Keys with a stored scale are dequantized back to float32; keys without one
# (presumably params that were never quantized) pass through untouched.
model_state = {
    key: qweight.float() * scales[key] if key in scales else qweight
    for key, qweight in quantized_weights.items()
}

config = GPTConfig(**config_dict)
model = GPT(config)
# strict=False: tolerate keys absent from the dequantized state dict
# (e.g. buffers the checkpoint does not carry).
model.load_state_dict(model_state, strict=False)
model.eval()
model.to(device)

# NOTE(review): pickle.load on a downloaded artifact is only safe for a
# trusted repo — same trust assumption as the checkpoint above.
with open(tokenizer_pkl_path, "rb") as f:
    tokenizer = pickle.load(f)
print("Model loaded successfully")
def generate_text(prompt, max_tokens=50, temperature=0.8):
    """Autoregressively sample up to ``max_tokens`` continuation tokens.

    Args:
        prompt: User text to condition on; encoded with the global tokenizer.
        max_tokens: Number of new tokens to sample. Gradio sliders may pass
            this as a float, so it is coerced to ``int`` (the original code
            would raise ``TypeError`` inside ``range`` on a float).
        temperature: Softmax temperature; higher values sample more randomly.
            Clamped to a small positive floor to avoid dividing by zero.

    Returns:
        The prompt plus the generated continuation as one decoded string, or
        an "Error: ..." string with a traceback on failure (surfaced directly
        in the Gradio output textbox rather than crashing the UI).
    """
    try:
        # BUGFIX: gr.Slider can deliver floats; range() requires an int.
        max_tokens = int(max_tokens)
        # Guard against a zero/negative temperature dividing the logits.
        temperature = max(float(temperature), 1e-6)
        tokens = tokenizer.encode(prompt)
        x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_tokens):
                # NOTE(review): assumes model(x) returns a raw logits tensor
                # of shape (1, T, vocab) — confirm against nanochat.gpt.GPT.
                logits = model(x)
                logits = logits[:, -1, :] / temperature
                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                x = torch.cat([x, next_token], dim=1)
        # Decode prompt + continuation together.
        output_tokens = x[0].tolist()
        output = tokenizer.decode(output_tokens)
        return output
    except Exception as e:
        import traceback
        return f"Error: {str(e)}\n\n{traceback.format_exc()}"
# --- Gradio UI wiring ---
# Input widgets, in the same order generate_text's parameters expect them.
_prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt", lines=3)
_tokens_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Tokens")
_temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")

demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_box, _tokens_slider, _temp_slider],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="NanoChat Distilled Model INT8",
    description=(
        "375M parameter student with MQA and INT8 quantization. "
        "Warning: Output quality is limited by undertrained teacher."
    ),
    # Each example row is (prompt, max_tokens, temperature).
    examples=[
        ["What is the capital of France?", 50, 0.7],
        ["Explain machine learning in simple terms", 100, 0.8],
        ["Write a haiku about coding", 50, 1.0],
    ],
)
demo.launch()