Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import gradio as gr | |
| from nanochat.gpt import GPT, GPTConfig | |
| from nanochat.tokenizer import RustBPETokenizer | |
# --- System Initialization ---
# Directory the tokenizer files are loaded from (the current working directory).
TOKENIZER_DIR = "."
tokenizer = RustBPETokenizer.from_directory(TOKENIZER_DIR)

# Map Special Tokens
# Cache the ids of the chat-control tokens on the tokenizer object so that
# predict() can assemble a prompt without re-encoding them on every request.
tokenizer.bos_token_id = tokenizer.enc.encode_single_token("<|bos|>")
tokenizer.user_start_id = tokenizer.enc.encode_single_token("<|user_start|>")
tokenizer.user_end_id = tokenizer.enc.encode_single_token("<|user_end|>")
tokenizer.assistant_start_id = tokenizer.enc.encode_single_token("<|assistant_start|>")
tokenizer.assistant_end_id = tokenizer.enc.encode_single_token("<|assistant_end|>")

# Model Setup
# NOTE(review): these hyperparameters presumably have to match the ones the
# checkpoint below was trained with -- confirm against the training config.
config = GPTConfig(
    vocab_size=32768,
    n_layer=12,
    n_head=6,
    n_embd=768,
    sequence_len=2048,
)
model = GPT(config)

print("Loading model weights...")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load("model_000971.pt", map_location="cpu", weights_only=True)
# Drop the "_orig_mod." prefix (added to parameter names by torch.compile
# wrappers) so the keys line up with the plain GPT module built above.
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
# strict=False tolerates key mismatches; report them instead of discarding the
# result, otherwise a wrong checkpoint silently leaves weights uninitialised.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: keys missing from checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
# Switch to inference behaviour (e.g. disables dropout, if the model has any).
model.eval()
def predict(message, history):
    """Stream the assistant's reply to ``message`` for the Gradio chat UI.

    The prompt is built as
    ``<|bos|> <|user_start|> message <|user_end|> <|assistant_start|>`` and
    handed to ``model.generate``. After every generated token the reply
    decoded so far (whitespace-stripped) is yielded, so the UI can render it
    incrementally.

    Args:
        message: The user's latest chat message; converted with ``str()`` and
            stripped before encoding.
        history: Earlier turns supplied by Gradio. Currently unused -- every
            request is answered without conversational context.

    Yields:
        The accumulated reply text after each new token, or a single
        ``"⚠️ System Error: ..."`` string if anything raises.
    """
    try:
        # Ids that end the assistant's turn; checked before any decoding.
        stop_ids = {tokenizer.assistant_end_id, tokenizer.user_start_id}
        # Textual fallback for stop markers, kept from the original logic.
        stop_tags = ("<|assistant_end|>", "<|end|>", "<|user_start|>")

        user_content = str(message).strip()
        tokens = [
            tokenizer.bos_token_id,
            tokenizer.user_start_id,
            *tokenizer.encode(user_content),
            tokenizer.user_end_id,
            tokenizer.assistant_start_id,
        ]

        # NOTE(review): if model.generate is a lazy generator, this context
        # only covers its creation, not the iteration below -- confirm that
        # generate() disables autograd itself.
        with torch.no_grad():
            output = model.generate(
                tokens,
                max_tokens=512,
                temperature=0.75,
                top_k=40,
            )

        generated_ids = []
        for token in output:
            token_id = token if isinstance(token, int) else token.item()
            if token_id in stop_ids:
                break
            if any(tag in tokenizer.decode([token_id]) for tag in stop_tags):
                break
            generated_ids.append(token_id)
            # Decode the whole sequence rather than concatenating per-token
            # strings: a character whose UTF-8 bytes span several tokens
            # (emoji, many non-Latin scripts) cannot be decoded piecewise.
            yield tokenizer.decode(generated_ids).strip()
    except Exception as e:
        # UI boundary: surface the failure in the chat instead of crashing.
        yield f"⚠️ System Error: {str(e)}"
# --- UI Customization for Gradio 6.0 ---
# Prompts handed to the ChatInterface as its ready-made examples.
EXAMPLE_PROMPTS = [
    "Explain neural network?",
    "Write a python function to calculate the area of a circle.",
    "Why is the sky blue?",
]

# The chat interface is created inside a Blocks context so that `demo` is the
# object launched below; predict() supplies the (streamed) replies.
with gr.Blocks() as demo:
    gr.ChatInterface(
        fn=predict,
        title="⚡ SimpleAI-259M",
        description="**Fast. Focused. Simple.** A lightweight general intelligence model optimized for reasoning and logic.",
        examples=EXAMPLE_PROMPTS,
    )


if __name__ == "__main__":
    # The theme is passed to launch() rather than to Blocks(), as requested by
    # the Gradio 6.0 warning.
    soft_theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="slate")
    # Listen on all interfaces, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        theme=soft_theme,
    )