| """Flask chat interface for the trained GPT model.""" |
| import os |
| from flask import Flask, request, Response, stream_with_context |
|
|
| from tokenizer import load_tokenizer |
| from generate import load_model, get_device |
|
|
# Artifact locations; override via environment variables to serve a
# different run without editing the code.
CHECKPOINT = os.getenv("CHECKPOINT", "checkpoints/best_model.pt")
TOKENIZER = os.getenv("TOKENIZER", "tokenizer.json")

# Loaded once at import time; the request handlers below reuse these
# module-level objects for every request.
device = get_device()
tokenizer = load_tokenizer(TOKENIZER)
model = load_model(CHECKPOINT, device)

app = Flask(__name__)
|
|
|
|
@app.route("/")
def index():
    """Serve the single-page chat UI held in the module-level ``HTML`` string."""
    page = HTML
    return page
|
|
|
|
def _parse_generation_params(data):
    """Return ``(prompt, max_tokens, temperature, top_k)`` from a request dict.

    Missing fields use the defaults 100 / 0.8 / 40, and a missing or blank
    prompt becomes "Once upon a time". Raises ``ValueError`` with a
    client-facing message when a field has an unusable type or value.
    """
    prompt = data.get("prompt", "")
    if prompt is None:
        prompt = ""
    if not isinstance(prompt, str):
        raise ValueError("prompt must be a string")
    prompt = prompt.strip() or "Once upon a time"

    try:
        max_tokens = int(data.get("max_tokens", 100))
        temperature = float(data.get("temperature", 0.8))
        top_k = int(data.get("top_k", 40))
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"invalid numeric parameter: {err}") from err
    return prompt, max_tokens, temperature, top_k


def _stream_decoded(token_ids, decode):
    """Yield the text of a stream of token ids, piece by piece.

    ``decode`` is applied to the whole sequence generated so far instead of to
    each id alone, and only the newly added suffix is yielded. When decoding is
    a plain per-token concatenation this produces the same overall text as
    per-token decoding; when a token is not independently decodable (e.g. a
    byte-level BPE token carrying part of a multi-byte UTF-8 character) it
    avoids emitting mangled fragments.

    Text currently ending in U+FFFD is held back because the next token may
    complete that character; anything still pending is flushed at the end.
    NOTE(review): assumes decode(ids[:k]) is a prefix of decode(ids) apart
    from a trailing incomplete character — true for common tokenizers, but
    confirm for the project's tokenizer. Re-decoding costs O(n) per step,
    which is fine for chat-sized outputs.
    """
    ids = []
    text = ""
    sent = 0  # number of characters already yielded to the client
    for token_id in token_ids:
        ids.append(token_id)
        text = decode(ids)
        if text.endswith("\ufffd"):
            continue
        if len(text) > sent:
            yield text[sent:]
            sent = len(text)
    # Flush a trailing piece that was held back waiting for more tokens.
    if len(text) > sent:
        yield text[sent:]


@app.route("/generate", methods=["POST"])
def generate():
    """Stream model output for a JSON request as ``text/plain``.

    Request body (JSON object): ``prompt`` (str), ``max_tokens`` (int,
    default 100), ``temperature`` (float, default 0.8), ``top_k`` (int,
    default 40). Responds 400 with a plain-text message when the body is not
    a JSON object or a field cannot be converted. Malformed JSON or a
    non-JSON content type is rejected by Flask's ``get_json`` itself.
    """
    import torch

    data = request.get_json()
    if not isinstance(data, dict):
        return Response("request body must be a JSON object", status=400,
                        mimetype="text/plain")
    try:
        prompt, max_tokens, temperature, top_k = _parse_generation_params(data)
    except ValueError as err:
        return Response(str(err), status=400, mimetype="text/plain")

    encoded = tokenizer.encode(prompt)
    if not encoded:
        # Fall back to token id 0 so the model always has a non-empty context.
        encoded = [0]
    idx = torch.tensor([encoded], dtype=torch.long, device=device)

    def stream():
        # Generation starts only when the response is iterated, as before.
        token_ids = model.stream(idx, max_tokens, temperature=temperature, top_k=top_k)
        yield from _stream_decoded(token_ids, tokenizer.decode)

    return Response(stream_with_context(stream()), mimetype="text/plain")
|
|
|
|
# Single-page chat UI returned by index(). The client POSTs the prompt as JSON
# to /generate and renders the text/plain response incrementally as it streams.
HTML = """<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
  <title>My GPT</title>
  <style>
    * { box-sizing: border-box; margin: 0; padding: 0; }

    body {
      font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
      background: #1a1a2e;
      color: #e0e0e0;
      height: 100vh;
      display: flex;
      flex-direction: column;
      align-items: center;
    }

    header {
      width: 100%;
      max-width: 800px;
      padding: 20px 24px 12px;
      border-bottom: 1px solid #2a2a4a;
    }
    header h1 { font-size: 1.3rem; color: #a78bfa; }
    header p { font-size: 0.8rem; color: #666; margin-top: 2px; }

    #chat {
      flex: 1;
      width: 100%;
      max-width: 800px;
      overflow-y: auto;
      padding: 24px 24px 8px;
      display: flex;
      flex-direction: column;
      gap: 16px;
    }

    .bubble {
      max-width: 80%;
      padding: 12px 16px;
      border-radius: 18px;
      line-height: 1.6;
      font-size: 0.95rem;
      white-space: pre-wrap;
      word-break: break-word;
    }
    .user-wrap { display: flex; justify-content: flex-end; }
    .model-wrap { display: flex; justify-content: flex-start; }

    .user { background: #4f46e5; color: #fff; border-bottom-right-radius: 4px; }
    .model { background: #1e1e3a; color: #e0e0e0; border-bottom-left-radius: 4px; border: 1px solid #2a2a4a; }
    .model.error { color: #f87171; border-color: #7f1d1d; }

    .typing::after {
      content: "▌";
      animation: blink 0.7s step-end infinite;
    }
    @keyframes blink { 50% { opacity: 0; } }

    #controls {
      width: 100%;
      max-width: 800px;
      padding: 12px 24px 24px;
    }

    #input-row {
      display: flex;
      gap: 10px;
    }

    #prompt {
      flex: 1;
      background: #1e1e3a;
      border: 1px solid #2a2a4a;
      border-radius: 12px;
      color: #e0e0e0;
      padding: 12px 16px;
      font-size: 0.95rem;
      outline: none;
      resize: none;
      height: 50px;
      line-height: 1.5;
      transition: border-color 0.2s;
    }
    #prompt:focus { border-color: #4f46e5; }

    #send {
      background: #4f46e5;
      color: white;
      border: none;
      border-radius: 12px;
      padding: 0 20px;
      font-size: 1.1rem;
      cursor: pointer;
      transition: background 0.2s;
    }
    #send:hover { background: #4338ca; }
    #send:disabled { background: #333; cursor: not-allowed; }
  </style>
</head>
<body>

<header>
  <h1>My GPT</h1>
  <p>Trained from scratch on Shakespeare + Alpaca — 30M parameters</p>
</header>

<div id="chat"></div>

<div id="controls">
  <div id="input-row">
    <textarea id="prompt" placeholder="Type a prompt and press Enter…"></textarea>
    <button id="send" onclick="sendMessage()">►</button>
  </div>
</div>

<script>
  const chat = document.getElementById("chat");
  const prompt = document.getElementById("prompt");
  const send = document.getElementById("send");

  // Enter sends, Shift+Enter inserts a newline. Enter that confirms an IME
  // composition is ignored so picking a candidate does not send the message.
  prompt.addEventListener("keydown", e => {
    if (e.key === "Enter" && !e.shiftKey && !e.isComposing) {
      e.preventDefault();
      sendMessage();
    }
  });

  function addBubble(text, role) {
    const wrap = document.createElement("div");
    wrap.className = role === "user" ? "user-wrap" : "model-wrap";
    const bubble = document.createElement("div");
    bubble.className = "bubble " + role;
    bubble.textContent = text;
    wrap.appendChild(bubble);
    chat.appendChild(wrap);
    chat.scrollTop = chat.scrollHeight;
    return bubble;
  }

  async function sendMessage() {
    // The keydown handler bypasses the disabled button, so guard here too:
    // only one generation request may be in flight at a time.
    if (send.disabled) return;
    const text = prompt.value.trim();
    if (!text) return;
    prompt.value = "";
    send.disabled = true;

    addBubble(text, "user");

    const modelBubble = addBubble("", "model");
    modelBubble.classList.add("typing");

    let output = "";
    try {
      const res = await fetch("/generate", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          prompt: text,
          max_tokens: 150,
          temperature: 0.8,
          top_k: 40,
        })
      });
      if (!res.ok) {
        const detail = await res.text();
        throw new Error(`HTTP ${res.status}: ${detail || res.statusText}`);
      }

      const reader = res.body.getReader();
      const decoder = new TextDecoder();

      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        // stream: true keeps a multi-byte UTF-8 character that is split
        // across network chunks buffered until its remaining bytes arrive.
        output += decoder.decode(value, { stream: true });
        modelBubble.textContent = output;
        chat.scrollTop = chat.scrollHeight;
      }
      // Flush any bytes still buffered in the decoder.
      output += decoder.decode();
      modelBubble.textContent = output;
    } catch (err) {
      modelBubble.classList.add("error");
      const message = "Error: " + err.message;
      modelBubble.textContent = output ? output + " [" + message + "]" : message;
    } finally {
      // Always re-enable input, even when the request or stream failed.
      modelBubble.classList.remove("typing");
      send.disabled = false;
      prompt.focus();
    }
  }
</script>
</body>
</html>
"""
|
|
if __name__ == "__main__":
    # One source of truth for the bind address, so the printed URL cannot
    # drift from what app.run() actually uses. Defaults keep the server
    # local-only; the variable names match the ones `flask run` honours.
    host = os.environ.get("FLASK_RUN_HOST", "127.0.0.1")
    port = int(os.environ.get("FLASK_RUN_PORT", "5000"))
    print(f"Model loaded ({model.num_params():,} params) on {device}")
    print(f"Open http://{host}:{port} in your browser")
    app.run(host=host, port=port, debug=False)
|
|