# Source: app.py uploaded with huggingface_hub by edgemindroboticslabs
# (commit 512ca63, verified).
"""Flask chat interface for the trained GPT model."""
import os
from flask import Flask, request, Response, stream_with_context
from tokenizer import load_tokenizer
from generate import load_model, get_device
# Artifact locations; each can be overridden through an environment variable
# of the same name.
CHECKPOINT = os.getenv("CHECKPOINT", "checkpoints/best_model.pt")
TOKENIZER = os.getenv("TOKENIZER", "tokenizer.json")

# Everything below runs once at import time, so the device, tokenizer and
# model are shared by every request handled by this process.
device = get_device()
tokenizer = load_tokenizer(TOKENIZER)
model = load_model(CHECKPOINT, device)

app = Flask(__name__)
@app.route("/")
def index():
    """Serve the single-page chat UI stored in the module-level ``HTML`` string."""
    page = HTML
    return page
def _parse_generation_params(data):
    """Validate the decoded JSON body of ``/generate``.

    Args:
        data: The decoded request body; expected to be a ``dict`` that may
            carry ``prompt``, ``max_tokens``, ``temperature`` and ``top_k``.

    Returns:
        A tuple ``(prompt, max_tokens, temperature, top_k)``. Missing fields
        fall back to ``"Once upon a time"``, ``100``, ``0.8`` and ``40``.

    Raises:
        ValueError: If the body is not a JSON object, or a field has an
            unusable type or value.
    """
    import math

    # Upper bound on tokens generated per request, so a single client cannot
    # keep the model busy indefinitely.
    max_tokens_cap = 2048

    if not isinstance(data, dict):
        raise ValueError("request body must be a JSON object")

    prompt = data.get("prompt")
    if prompt is None:
        prompt = ""
    elif not isinstance(prompt, str):
        raise ValueError("'prompt' must be a string")
    prompt = prompt.strip() or "Once upon a time"

    try:
        max_tokens = int(data.get("max_tokens", 100))
        temperature = float(data.get("temperature", 0.8))
        top_k = int(data.get("top_k", 40))
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(
            "'max_tokens', 'temperature' and 'top_k' must be numbers"
        ) from exc

    # NOTE(review): assumes the sampler divides by temperature and needs at
    # least one candidate token -- confirm against model.stream.
    if not math.isfinite(temperature) or temperature <= 0:
        raise ValueError("'temperature' must be a finite number greater than 0")
    if top_k < 1:
        raise ValueError("'top_k' must be at least 1")

    max_tokens = max(0, min(max_tokens, max_tokens_cap))
    return prompt, max_tokens, temperature, top_k


@app.route("/generate", methods=["POST"])
def generate():
    """Stream generated text for the prompt in the JSON request body.

    The body may contain ``prompt``, ``max_tokens``, ``temperature`` and
    ``top_k``. Input is validated before streaming starts, because once the
    first chunk is sent the status code can no longer be changed.

    Returns:
        A ``text/plain`` streaming response with the decoded tokens, or a
        400 ``text/plain`` response describing the problem with the input.
    """
    import torch

    try:
        # silent=True yields None instead of raising on a missing or
        # malformed JSON body; the helper then rejects it uniformly.
        prompt, max_tokens, temperature, top_k = _parse_generation_params(
            request.get_json(silent=True)
        )
    except ValueError as exc:
        return Response(
            f"Invalid request: {exc}", status=400, mimetype="text/plain"
        )

    encoded = tokenizer.encode(prompt)
    if not encoded:
        # Never feed the model an empty sequence; fall back to token id 0.
        encoded = [0]
    idx = torch.tensor([encoded], dtype=torch.long, device=device)

    def stream():
        # NOTE(review): each id is decoded on its own; if the tokenizer can
        # split one character across several ids, this may emit partial
        # characters -- confirm against the tokenizer.
        for token_id in model.stream(
            idx, max_tokens, temperature=temperature, top_k=top_k
        ):
            yield tokenizer.decode([token_id])

    return Response(stream_with_context(stream()), mimetype="text/plain")
# Single-page chat UI served by the "/" route. The embedded script POSTs the
# prompt to /generate and renders the streamed plain-text reply as it arrives.
# Keep this literal free of backslashes: it is not a raw string, so an escape
# sequence would be interpreted by Python before the browser sees it.
HTML = """<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
  <title>My GPT</title>
  <style>
    * { box-sizing: border-box; margin: 0; padding: 0; }
    body {
      font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
      background: #1a1a2e;
      color: #e0e0e0;
      height: 100vh;
      display: flex;
      flex-direction: column;
      align-items: center;
    }
    header {
      width: 100%;
      max-width: 800px;
      padding: 20px 24px 12px;
      border-bottom: 1px solid #2a2a4a;
    }
    header h1 { font-size: 1.3rem; color: #a78bfa; }
    header p { font-size: 0.8rem; color: #666; margin-top: 2px; }
    #chat {
      flex: 1;
      width: 100%;
      max-width: 800px;
      overflow-y: auto;
      padding: 24px 24px 8px;
      display: flex;
      flex-direction: column;
      gap: 16px;
    }
    .bubble {
      max-width: 80%;
      padding: 12px 16px;
      border-radius: 18px;
      line-height: 1.6;
      font-size: 0.95rem;
      white-space: pre-wrap;
      word-break: break-word;
    }
    .user-wrap { display: flex; justify-content: flex-end; }
    .model-wrap { display: flex; justify-content: flex-start; }
    .user { background: #4f46e5; color: #fff; border-bottom-right-radius: 4px; }
    .model { background: #1e1e3a; color: #e0e0e0; border-bottom-left-radius: 4px; border: 1px solid #2a2a4a; }
    .typing::after {
      content: "▌";
      animation: blink 0.7s step-end infinite;
    }
    @keyframes blink { 50% { opacity: 0; } }
    #controls {
      width: 100%;
      max-width: 800px;
      padding: 12px 24px 24px;
    }
    #input-row {
      display: flex;
      gap: 10px;
    }
    #prompt {
      flex: 1;
      background: #1e1e3a;
      border: 1px solid #2a2a4a;
      border-radius: 12px;
      color: #e0e0e0;
      padding: 12px 16px;
      font-size: 0.95rem;
      outline: none;
      resize: none;
      height: 50px;
      line-height: 1.5;
      transition: border-color 0.2s;
    }
    #prompt:focus { border-color: #4f46e5; }
    #send {
      background: #4f46e5;
      color: white;
      border: none;
      border-radius: 12px;
      padding: 0 20px;
      font-size: 1.1rem;
      cursor: pointer;
      transition: background 0.2s;
    }
    #send:hover { background: #4338ca; }
    #send:disabled { background: #333; cursor: not-allowed; }
  </style>
</head>
<body>
  <header>
    <h1>My GPT</h1>
    <p>Trained from scratch on Shakespeare + Alpaca &mdash; 30M parameters</p>
  </header>
  <div id="chat"></div>
  <div id="controls">
    <div id="input-row">
      <textarea id="prompt" placeholder="Type a prompt and press Enter…"></textarea>
      <button id="send" onclick="sendMessage()">&#9658;</button>
    </div>
  </div>
  <script>
    const chat = document.getElementById("chat");
    const prompt = document.getElementById("prompt");
    const send = document.getElementById("send");

    prompt.addEventListener("keydown", e => {
      if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); sendMessage(); }
    });

    function addBubble(text, role) {
      const wrap = document.createElement("div");
      wrap.className = role === "user" ? "user-wrap" : "model-wrap";
      const bubble = document.createElement("div");
      bubble.className = "bubble " + role;
      bubble.textContent = text;
      wrap.appendChild(bubble);
      chat.appendChild(wrap);
      chat.scrollTop = chat.scrollHeight;
      return bubble;
    }

    async function sendMessage() {
      const text = prompt.value.trim();
      // Ignore empty input, and Enter pressed while a reply is still streaming.
      if (!text || send.disabled) return;
      prompt.value = "";
      send.disabled = true;
      addBubble(text, "user");
      const modelBubble = addBubble("", "model");
      modelBubble.classList.add("typing");
      try {
        const res = await fetch("/generate", {
          method: "POST",
          headers: { "Content-Type": "application/json" },
          body: JSON.stringify({
            prompt: text,
            max_tokens: 150,
            temperature: 0.8,
            top_k: 40,
          })
        });
        const reader = res.body.getReader();
        const decoder = new TextDecoder();
        let output = "";
        while (true) {
          const { done, value } = await reader.read();
          if (done) break;
          // stream: true keeps a multi-byte character that is split across
          // two network chunks from being decoded as garbage.
          output += decoder.decode(value, { stream: true });
          modelBubble.textContent = output;
          chat.scrollTop = chat.scrollHeight;
        }
        // Flush any bytes the decoder is still holding.
        output += decoder.decode();
        modelBubble.textContent = output;
      } catch (err) {
        modelBubble.textContent += " [Error: " + err.message + "]";
      } finally {
        // Always restore the controls, even when the request failed.
        modelBubble.classList.remove("typing");
        send.disabled = false;
        prompt.focus();
      }
    }
  </script>
</body>
</html>
"""
def main():
    """Print a short startup banner and run the Flask development server."""
    print(f"Model loaded ({model.num_params():,} params) on {device}")
    print("Open http://127.0.0.1:5000 in your browser")
    # Bound to the loopback interface only, with the debugger disabled.
    app.run(host="127.0.0.1", port=5000, debug=False)


if __name__ == "__main__":
    main()