import json
import threading

from flask import Flask, request, jsonify, render_template, Response, stream_with_context
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextIteratorStreamer


app = Flask(__name__)

# Model configuration.
MODEL_ID = "Qwen/Qwen3.5-0.8B"
LOAD_IN_4BIT = True

# Shared state, populated by the background loader thread.
tokenizer = None
model = None
model_lock = threading.Lock()
load_error: str | None = None


def load_model():
    """Load the tokenizer and model; runs in a background thread so startup isn't blocked."""
    global tokenizer, model, load_error
    try:
        print(f"[*] Loading tokenizer for {MODEL_ID} …")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

        print(f"[*] Loading model ({'4-bit quant' if LOAD_IN_4BIT else 'float16'}) …")
        # NF4 4-bit quantization with double quantization; compute runs in float16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        ) if LOAD_IN_4BIT else None

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        model.eval()
        print("[*] Model ready.")
    except Exception as e:
        load_error = str(e)
        print(f"[!] Model load failed: {e}")


# Start loading in a daemon thread so Flask can begin serving (e.g. /health) immediately.
threading.Thread(target=load_model, daemon=True).start()
| @app.route("/") |
| def index(): |
| return render_template("index.html") |
|
|
|
|
| @app.route("/health") |
| def health(): |
| if load_error: |
| return jsonify({"status": "error", "detail": load_error}), 500 |
| ready = model is not None and tokenizer is not None |
| return jsonify({"status": "ready" if ready else "loading"}) |


@app.route("/v1/chat/completions", methods=["POST"])
def chat_completions():
    """OpenAI-compatible chat completions endpoint (streaming and non-streaming)."""
    if model is None or tokenizer is None:
        return jsonify({"error": load_error or "Model is still loading."}), 503

    data = request.get_json(force=True)
    messages = data.get("messages", [])
    max_tokens = int(data.get("max_tokens", 1024))
    temperature = float(data.get("temperature", 0.7))
    top_p = float(data.get("top_p", 0.9))
    stream = bool(data.get("stream", False))

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Only pass sampling knobs when sampling; transformers warns about them otherwise.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    if stream:
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=30.0,
        )
        gen_kwargs["streamer"] = streamer

        def run_generate():
            # Hold the lock here too, so streaming and non-streaming requests
            # never run generate() concurrently.
            with model_lock:
                model.generate(**gen_kwargs)

        gen_thread = threading.Thread(target=run_generate, daemon=True)

        def event_stream():
            gen_thread.start()
            try:
                for token_text in streamer:
                    if token_text:
                        payload = {
                            "choices": [{
                                "delta": {"content": token_text},
                                "finish_reason": None,
                            }]
                        }
                        yield f"data: {json.dumps(payload)}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                payload = {"choices": [{"delta": {}, "finish_reason": "stop"}]}
                yield f"data: {json.dumps(payload)}\n\n"
                yield "data: [DONE]\n\n"

        return Response(
            stream_with_context(event_stream()),
            mimetype="text/event-stream",
            headers={
                "X-Accel-Buffering": "no",   # disable proxy buffering (nginx)
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    # Non-streaming path: generate under the lock, then return the full completion.
    with model_lock:
        output_ids = model.generate(**gen_kwargs)

    new_tokens = output_ids[0][input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return jsonify({
        "choices": [{"message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
        "model": MODEL_ID,
    })
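

# Minimal streaming client sketch (illustration only, not part of the app; assumes
# the server is reachable at localhost:7860 and the `requests` package is installed):
#
#   import json, requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={"messages": [{"role": "user", "content": "Hi"}], "stream": True},
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       if line.startswith(b"data: ") and line != b"data: [DONE]":
#           chunk = json.loads(line[len(b"data: "):])
#           for choice in chunk.get("choices", []):
#               print(choice["delta"].get("content", ""), end="", flush=True)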


@app.route("/generate", methods=["POST"])
def generate_simple():
    """Streaming single-turn endpoint for the UI; accepts a messages array or a plain prompt."""
    if model is None or tokenizer is None:
        return jsonify({"error": load_error or "Model is still loading."}), 503

    data = request.get_json(force=True)
    messages = data.get("messages", [{"role": "user", "content": data.get("prompt", "")}])
    max_tokens = int(data.get("max_tokens", 1024))
    temperature = float(data.get("temperature", 0.7))
    do_stream = bool(data.get("stream", True))

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        gen_kwargs["temperature"] = temperature

    if do_stream:
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            timeout=30.0,
        )
        gen_kwargs["streamer"] = streamer

        def run_generate():
            # Serialize generation, matching the non-streaming path below.
            with model_lock:
                model.generate(**gen_kwargs)

        gen_thread = threading.Thread(target=run_generate, daemon=True)

        def event_stream():
            gen_thread.start()
            try:
                for token_text in streamer:
                    if token_text:
                        yield f"data: {json.dumps({'token': token_text})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"
            finally:
                yield "data: [DONE]\n\n"

        return Response(
            stream_with_context(event_stream()),
            mimetype="text/event-stream",
            headers={
                "X-Accel-Buffering": "no",
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    # Non-streaming fallback.
    with model_lock:
        output_ids = model.generate(**gen_kwargs)
    new_tokens = output_ids[0][input_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return jsonify({"response": text})


if __name__ == "__main__":
    # threaded=True lets the dev server keep an SSE stream open while serving other requests.
    app.run(host="0.0.0.0", port=7860, debug=False, threaded=True)
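# For anything beyond local use, prefer a WSGI server over the Flask dev server.
# One possible invocation (single worker so the model is loaded only once; `app:app`
# assumes this file is saved as app.py):
#
#   gunicorn -w 1 -k gthread --threads 8 -b 0.0.0.0:7860 app:app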