"""Gradio interface for nanochat model.""" from __future__ import annotations import os from collections.abc import Generator from pathlib import Path from typing import Any import gradio as gr from huggingface_hub import snapshot_download from model import NanochatModel MODEL_REPO = os.environ.get("MODEL_REPO", "Guilherme34/nanochat-retrained-pytorch-duplicated") MODEL_DIR = os.environ.get("MODEL_DIR", "./model_cache") _model: NanochatModel | None = None def download_model() -> None: """Download the model from Hugging Face if needed.""" model_path = Path(MODEL_DIR) if not model_path.exists() or not any(model_path.iterdir()): snapshot_download( repo_id=MODEL_REPO, local_dir=MODEL_DIR, ) def load_model() -> None: """Load the nanochat model.""" global _model if _model is None: download_model() _model = NanochatModel(model_dir=MODEL_DIR, device="cpu") load_model() def respond( message: str, history: list[dict[str, str]], temperature: float, top_k: int, system_prompt: str, # NEW ) -> Generator[str, Any, None]: """Generate a response using the nanochat model. Args: message: User's input message history: Chat history in Gradio messages format temperature: Sampling temperature top_k: Top-k sampling parameter system_prompt: Optional system message to steer behavior Yields: Incrementally generated response text """ conversation: list[dict[str, str]] = [] # If a system message is provided, put it at the start of the conversation. conversation.append({"role": "system", "content": system_prompt.strip()}) # Replay prior turns for msg in history: conversation.append(msg) # Current user turn conversation.append({"role": "user", "content": message}) response = "" for token in _model.generate( history=conversation, max_tokens=512, temperature=temperature, top_k=top_k, ): response += token yield response chatbot = gr.ChatInterface( respond, type="messages", additional_inputs=[ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"), gr.Slider( minimum=1, maximum=200, value=50, step=1, label="Top-k sampling", ), gr.Textbox( # NEW label="System message (optional)", placeholder="e.g., You are a concise assistant that answers in markdown.", lines=3, ), ], ) with gr.Blocks(title="nanochat") as demo: gr.Markdown("# nanochat") gr.Markdown("Chat with an AI trained in 4 hours for $100") gr.Markdown( "**Note:** This model is a research experiment. " "Obviously do not rely on the outputs!", ) chatbot.render() if __name__ == "__main__": demo.launch()