# Hugging Face Space demo app (page-status header removed from scrape)
| import gradio as gr | |
| import torch | |
| import sys | |
| import os | |
| import pickle | |
| from huggingface_hub import hf_hub_download | |
| sys.path.insert(0, os.path.dirname(__file__)) | |
| from nanochat.gpt import GPT, GPTConfig | |
print("Downloading model files...")
# All artifacts live in one HF Hub repo; hoist the id so it is written once.
_REPO_ID = "TMVishnu/nanochat-distill-d12-int8"
model_path = hf_hub_download(repo_id=_REPO_ID, filename="model.pt")
tokenizer_pkl_path = hf_hub_download(repo_id=_REPO_ID, filename="tokenizer.pkl")
token_bytes_path = hf_hub_download(repo_id=_REPO_ID, filename="token_bytes.pt")

device = "cpu"
print(f"Loading model on {device}")
# NOTE(review): weights_only=False deserializes arbitrary pickled objects;
# acceptable here only because the checkpoint comes from a known repo.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
quantized_weights = checkpoint["quantized_weights"]
scales = checkpoint["scales"]
config_dict = checkpoint["config"]
bits = checkpoint["bits"]

print(f"Dequantizing INT{bits} weights...")
# Rebuild a float state dict: tensors with a recorded scale are dequantized
# back to float; anything without a scale is passed through untouched.
model_state = {
    key: qweight.float() * scales[key] if key in scales else qweight
    for key, qweight in quantized_weights.items()
}

config = GPTConfig(**config_dict)
model = GPT(config)
# strict=False: checkpoint may omit buffers that the module re-creates itself.
model.load_state_dict(model_state, strict=False)
model.eval()
model.to(device)

with open(tokenizer_pkl_path, "rb") as f:
    tokenizer = pickle.load(f)
print("Model loaded successfully")
def generate_text(prompt, max_tokens=50, temperature=0.8):
    """Autoregressively sample a continuation of `prompt` from the model.

    Args:
        prompt: Input text to condition on.
        max_tokens: Number of new tokens to sample. Gradio sliders may pass
            this as a float, so it is cast to int before use (the original
            `range(max_tokens)` raised TypeError on a float).
        temperature: Softmax temperature; the UI slider keeps it >= 0.1,
            so no divide-by-zero guard is needed here.

    Returns:
        The prompt plus the generated continuation, or an error string
        (prefixed "Error:") with a traceback if anything goes wrong.
    """
    try:
        tokens = tokenizer.encode(prompt)
        # Guard: an empty token sequence would make torch.multinomial fail
        # with a cryptic shape error; report something actionable instead.
        if not tokens:
            return "Error: prompt produced no tokens; please enter some text."
        x = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(int(max_tokens)):
                logits = model(x)
                # Keep only the last position's logits and apply temperature.
                logits = logits[:, -1, :] / temperature
                probs = torch.nn.functional.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                x = torch.cat([x, next_token], dim=1)
        output_tokens = x[0].tolist()
        output = tokenizer.decode(output_tokens)
        return output
    except Exception as e:
        # Surface the full traceback in the UI textbox for easy debugging.
        import traceback
        return f"Error: {str(e)}\n\n{traceback.format_exc()}"
# Gradio UI: a single text-generation form — prompt plus sampling controls.
_prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt", lines=3)
_max_tokens_slider = gr.Slider(minimum=10, maximum=200, value=50, step=1, label="Max Tokens")
_temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature")
_output_box = gr.Textbox(label="Generated Text", lines=10)

demo = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_box, _max_tokens_slider, _temperature_slider],
    outputs=_output_box,
    title="NanoChat Distilled Model INT8",
    description="375M parameter student with MQA and INT8 quantization. Warning: Output quality is limited by undertrained teacher.",
    examples=[
        ["What is the capital of France?", 50, 0.7],
        ["Explain machine learning in simple terms", 100, 0.8],
        ["Write a haiku about coding", 50, 1.0],
    ],
)
demo.launch()