"""Flask chat interface for the trained GPT model.""" import os from flask import Flask, request, Response, stream_with_context from tokenizer import load_tokenizer from generate import load_model, get_device CHECKPOINT = os.environ.get("CHECKPOINT", "checkpoints/best_model.pt") TOKENIZER = os.environ.get("TOKENIZER", "tokenizer.json") device = get_device() tokenizer = load_tokenizer(TOKENIZER) model = load_model(CHECKPOINT, device) app = Flask(__name__) @app.route("/") def index(): return HTML @app.route("/generate", methods=["POST"]) def generate(): import torch data = request.json prompt = data.get("prompt", "").strip() or "Once upon a time" max_tokens = int(data.get("max_tokens", 100)) temperature = float(data.get("temperature", 0.8)) top_k = int(data.get("top_k", 40)) encoded = tokenizer.encode(prompt) if not encoded: encoded = [0] idx = torch.tensor([encoded], dtype=torch.long, device=device) def stream(): for token_id in model.stream(idx, max_tokens, temperature=temperature, top_k=top_k): yield tokenizer.decode([token_id]) return Response(stream_with_context(stream()), mimetype="text/plain") HTML = """ My GPT

My GPT

Trained from scratch on Shakespeare + Alpaca — 30M parameters

""" if __name__ == "__main__": print(f"Model loaded ({model.num_params():,} params) on {device}") print("Open http://127.0.0.1:5000 in your browser") app.run(host="127.0.0.1", port=5000, debug=False)