| """Gradio interface for nanochat model.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from collections.abc import Generator |
| from pathlib import Path |
| from typing import Any |
|
|
| import gradio as gr |
| from huggingface_hub import snapshot_download |
|
|
| from model import NanochatModel |
|
|
# Hugging Face repo to fetch the weights from; override with $MODEL_REPO.
MODEL_REPO = os.getenv("MODEL_REPO", "sdobson/nanochat")
# Local directory the snapshot is stored in; override with $MODEL_DIR.
MODEL_DIR = os.getenv("MODEL_DIR", "./model_cache")
# Module-wide model instance; stays None until load_model() has run.
_model: NanochatModel | None = None
|
|
|
|
def download_model() -> None:
    """Fetch the model snapshot from Hugging Face unless it is already cached.

    The download is skipped when ``MODEL_DIR`` exists and contains at least
    one entry; otherwise ``MODEL_REPO`` is downloaded into ``MODEL_DIR``.
    """
    cache_dir = Path(MODEL_DIR)
    # exists() is checked first so iterdir() is never called on a missing path.
    if cache_dir.exists() and any(cache_dir.iterdir()):
        return
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR)
|
|
|
|
def load_model() -> None:
    """Initialise the module-level nanochat model once.

    Does nothing if ``_model`` is already set. Otherwise it makes sure the
    model files are present via ``download_model()`` and then builds a
    ``NanochatModel`` from ``MODEL_DIR`` on the CPU.
    """
    global _model
    if _model is not None:
        return
    download_model()
    _model = NanochatModel(model_dir=MODEL_DIR, device="cpu")
|
|
|
|
# Runs at import time, so `_model` is already populated by the time
# respond() is first called.
load_model()
|
|
|
|
def respond(
    message: str,
    history: list[dict[str, str]],
    temperature: float,
    top_k: int,
) -> Generator[str, Any, None]:
    """Stream a reply from the nanochat model for the latest user message.

    Args:
        message: The text the user just submitted.
        history: Earlier turns of the chat, in Gradio "messages" format.
        temperature: Sampling temperature forwarded to the model.
        top_k: Top-k sampling cutoff forwarded to the model.

    Yields:
        The reply accumulated so far, re-emitted after each piece the model
        produces.
    """
    # New list: previous turns followed by the new user turn. The caller's
    # history list is not mutated.
    turns = [*history, {"role": "user", "content": message}]

    partial_reply = ""
    pieces = _model.generate(
        history=turns,
        max_tokens=512,
        temperature=temperature,
        top_k=top_k,
    )
    for piece in pieces:
        partial_reply += piece
        yield partial_reply
|
|
|
|
# Chat UI backed by respond(). The two sliders supply its extra
# `temperature` and `top_k` arguments, in that order.
chatbot = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-k sampling",
            minimum=1,
            maximum=200,
            step=1,
            value=50,
        ),
    ],
)
|
|
# Page layout: a heading and a short blurb, followed by the chat interface.
with gr.Blocks(title="nanochat") as demo:
    gr.Markdown("# nanochat")
    gr.Markdown(
        "Chat with an AI trained in 4 hours for $100, "
        "[Details](https://axiilay.com/posts/andrej-karpathy-new-project-nanochat)"
    )

    # `chatbot` was created outside this Blocks context, so it is rendered
    # into the layout explicitly here.
    chatbot.render()
|
|
|
|
# Start the Gradio app only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    demo.launch()