import os
import json
import threading
from flask import Flask, request, jsonify, render_template, Response, stream_with_context
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer
app = Flask(__name__)
# ── Model configuration ───────────────────────────────────────────────────────
MODEL_ID = "Qwen/Qwen3.5-0.8B"  # Qwen3.5-9B
LOAD_IN_4BIT = True
tokenizer = None
model = None
model_lock = threading.Lock()  # serializes generate() calls; a single model instance is not thread-safe
load_error: str | None = None
def load_model():
    """Load the tokenizer and model in a background thread so the server can start immediately."""
    global tokenizer, model, load_error
    try:
        print(f"[*] Loading tokenizer for {MODEL_ID} …")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        print(f"[*] Loading model ({'4-bit quant' if LOAD_IN_4BIT else 'full precision'}) …")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        ) if LOAD_IN_4BIT else None
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        model.eval()
        print("[*] Model ready.")
    except Exception as e:
        load_error = str(e)
        print(f"[!] Model load failed: {e}")
threading.Thread(target=load_model, daemon=True).start()
# ── Routes ────────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    return render_template("index.html")
@app.route("/health")
def health():
    if load_error:
        return jsonify({"status": "error", "detail": load_error}), 500
    ready = model is not None and tokenizer is not None
    return jsonify({"status": "ready" if ready else "loading"})
@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
"""OpenAI-compatible streaming chat completions endpoint."""
if model is None or tokenizer is None:
return jsonify({"error": load_error or "Model is still loading."}), 503
data = request.get_json(force=True)
messages = data.get("messages", [])
max_tokens = int(data.get("max_tokens", 1024))
temperature = float(data.get("temperature", 0.7))
top_p = float(data.get("top_p", 0.9))
stream = bool(data.get("stream", False))
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to(model.device)
input_len = inputs["input_ids"].shape[-1]
gen_kwargs = dict(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0,
pad_token_id=tokenizer.eos_token_id,
)
    if stream:
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=30.0,
        )
        gen_kwargs["streamer"] = streamer

        def run_generate():
            # Hold the lock so concurrent requests cannot run generate() on the
            # same model at once (the non-streaming path below locks as well).
            with model_lock:
                model.generate(**gen_kwargs)

        gen_thread = threading.Thread(target=run_generate, daemon=True)
        def event_stream():
            gen_thread.start()
            try:
                for token_text in streamer:
                    if token_text:
                        payload = {
                            "choices": [{
                                "delta": {"content": token_text},
                                "finish_reason": None,
                            }]
                        }
                        yield f"data: {json.dumps(payload)}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                payload = {"choices": [{"delta": {}, "finish_reason": "stop"}]}
                yield f"data: {json.dumps(payload)}\n\n"
                yield "data: [DONE]\n\n"
        return Response(
            stream_with_context(event_stream()),
            mimetype="text/event-stream",
            headers={
                "X-Accel-Buffering": "no",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    # Non-streaming fallback
    with model_lock:
        output_ids = model.generate(**gen_kwargs)
    new_tokens = output_ids[0][input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return jsonify({
        "choices": [{"message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
        "model": MODEL_ID,
    })
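# Example (hypothetical client, not part of this file): the endpoint mimics the
# OpenAI chat-completions wire format, so the SSE stream can be consumed with
# plain requests. A minimal sketch, assuming the server is on localhost:7860:
#
#   import json, requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hello"}], "stream": True},
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       if line.startswith(b"data: ") and line != b"data: [DONE]":
#           chunk = json.loads(line[len(b"data: "):])
#           delta = chunk.get("choices", [{}])[0].get("delta", {})
#           print(delta.get("content", ""), end="", flush=True)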
@app.route("/generate", methods=["POST"])
def generate_simple():
"""Streaming single-turn endpoint for the UI — accepts messages array or plain prompt."""
if model is None or tokenizer is None:
return jsonify({"error": load_error or "Model is still loading."}), 503
data = request.get_json(force=True)
messages = data.get("messages", [{"role": "user", "content": data.get("prompt", "")}])
max_tokens = int(data.get("max_tokens", 1024))
temperature= float(data.get("temperature", 0.7))
do_stream = bool(data.get("stream", True))
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=temperature > 0,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_stream:
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=30.0,
        )
        gen_kwargs["streamer"] = streamer

        def run_generate():
            # Same locking as /v1/chat/completions: one generate() at a time.
            with model_lock:
                model.generate(**gen_kwargs)

        gen_thread = threading.Thread(target=run_generate, daemon=True)
        def event_stream():
            gen_thread.start()
            try:
                for token_text in streamer:
                    if token_text:
                        yield f"data: {json.dumps({'token': token_text})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                yield "data: [DONE]\n\n"
        return Response(
            stream_with_context(event_stream()),
            mimetype="text/event-stream",
            headers={
                "X-Accel-Buffering": "no",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )
    # Non-streaming fallback
    with model_lock:
        output_ids = model.generate(**gen_kwargs)
    new_tokens = output_ids[0][input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return jsonify({"response": text})
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # threaded=True is required: each SSE connection occupies its own thread
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)