"""Gradio demo for the NanoChat distilled INT8 model.

Downloads INT8-quantized weights from the Hugging Face Hub, dequantizes
them to float32 on CPU, and serves a simple text-generation UI.
"""

import os
import pickle
import sys

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Make the vendored nanochat package importable when run as a script.
sys.path.insert(0, os.path.dirname(__file__))
from nanochat.gpt import GPT, GPTConfig

REPO_ID = "TMVishnu/nanochat-distill-d12-int8"

print("Downloading model files...")
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.pt")
tokenizer_pkl_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.pkl")
token_bytes_path = hf_hub_download(repo_id=REPO_ID, filename="token_bytes.pt")

device = "cpu"
print(f"Loading model on {device}")

# SECURITY NOTE: weights_only=False lets unpickling execute arbitrary code;
# acceptable here only because the checkpoint comes from a trusted repo.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
quantized_weights = checkpoint["quantized_weights"]
scales = checkpoint["scales"]
config_dict = checkpoint["config"]
bits = checkpoint["bits"]

print(f"Dequantizing INT{bits} weights...")
# Reconstruct float weights as (quantized tensor * scale). Entries with no
# recorded scale are passed through unchanged.
model_state = {}
for key, qweight in quantized_weights.items():
    if key in scales:
        model_state[key] = qweight.float() * scales[key]
    else:
        model_state[key] = qweight

config = GPTConfig(**config_dict)
model = GPT(config)
# strict=False: the dequantized state dict may omit entries the module
# recreates itself (e.g. non-quantized buffers) — intentional best-effort load.
model.load_state_dict(model_state, strict=False)
model.eval()
model.to(device)

# SECURITY NOTE: pickle.load executes arbitrary code on load; trusted repo only.
with open(tokenizer_pkl_path, "rb") as f:
    tokenizer = pickle.load(f)

print("Model loaded successfully")


def generate_text(prompt, max_tokens=50, temperature=0.8):
    """Autoregressively sample up to ``max_tokens`` tokens after ``prompt``.

    Args:
        prompt: Input text to condition on.
        max_tokens: Number of tokens to sample. Gradio sliders can deliver
            floats, so the value is truncated to int before use.
        temperature: Softmax temperature; higher values sample more randomly
            (UI slider keeps this >= 0.1, so no division by zero).

    Returns:
        The prompt plus the generated continuation, or an ``"Error: ..."``
        string with a traceback if generation fails.
    """
    try:
        tokens = tokenizer.encode(prompt)
        if not tokens:
            # An empty token sequence would feed a zero-length input to the
            # model; fail gracefully instead.
            return "Error: prompt produced no tokens - please enter some text."
        x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            # NOTE(review): x is never cropped to the model's context window;
            # a long prompt plus many sampled tokens could exceed it — confirm
            # GPTConfig's sequence-length field and crop if needed.
            for _ in range(int(max_tokens)):
                logits = model(x)
                # Sample only from the distribution at the final position.
                logits = logits[:, -1, :] / temperature
                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                x = torch.cat([x, next_token], dim=1)
        return tokenizer.decode(x[0].tolist())
    except Exception as e:
        import traceback
        return f"Error: {str(e)}\n\n{traceback.format_exc()}"


demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Enter your prompt", lines=3),
        gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=10),
    title="NanoChat Distilled Model INT8",
    description="375M parameter student with MQA and INT8 quantization. Warning: Output quality is limited by undertrained teacher.",
    examples=[
        ["What is the capital of France?", 50, 0.7],
        ["Explain machine learning in simple terms", 100, 0.8],
        ["Write a haiku about coding", 50, 1.0],
    ],
)

# Launch only when executed directly (standard Gradio pattern; importing the
# module — e.g. by a Spaces runner that uses `demo` — no longer starts a server
# as an import side effect).
if __name__ == "__main__":
    demo.launch()