fomext's picture
Upload app.py
bcf5e24 verified
Raw
History Blame Contribute Delete
12 kB
"""
OpenAI-compatible FastAPI wrapper for Qwen3-14B (GGUF / llama-cpp-python)
Endpoints: GET /, GET /health, GET /v1/models, POST /v1/chat/completions
Supports streaming (SSE) and non-streaming responses.
Model is downloaded automatically on first boot if not already present.
"""
import os
import time
import uuid
import json
import queue
import asyncio
import logging
import threading
from pathlib import Path
from typing import AsyncIterator, List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from llama_cpp import Llama
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Process-wide logging: INFO and above, each record prefixed with a timestamp
# and its level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-scoped logger shared by everything in this file.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Config (override via environment variables)
# ---------------------------------------------------------------------------
# Where the GGUF weights live on disk, and where to fetch them from when the
# file is missing.
MODEL_PATH = os.getenv("MODEL_PATH", "/models/qwen3-14b-q4_k_m.gguf")
MODEL_URL = os.getenv(
    "MODEL_URL",
    "https://huggingface.co/bartowski/Qwen3-14B-GGUF/resolve/main/Qwen3-14B-Q4_K_M.gguf",
)
# Model identifier reported in API responses.
MODEL_ID = os.getenv("MODEL_ID", "qwen3-14b")
# llama.cpp runtime knobs, forwarded to the Llama constructor as
# n_ctx / n_threads / n_batch / verbose.
N_CTX = int(os.getenv("N_CTX", "4096"))
N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
N_BATCH = int(os.getenv("N_BATCH", "512"))
VERBOSE = os.getenv("VERBOSE", "false").lower() == "true"  # only the literal "true" enables it
# Optional Hugging Face token, sent as a Bearer header when downloading.
HF_TOKEN = os.getenv("HF_TOKEN", "")
# ---------------------------------------------------------------------------
# Lazy model holder
# ---------------------------------------------------------------------------
# Shared model state: written by the background loader thread, read by the
# request handlers.
_llm_lock = threading.Lock()  # held while the loaded model is published into _llm
_llm_ready = threading.Event()  # set once the load attempt has finished (success or failure)
_llm_error: Optional[str] = None  # message of the exception that aborted loading, if any
_llm: Optional[Llama] = None  # the loaded model; stays None until loading succeeds
# One Llama instance means one KV cache and one sampling state, so llama.cpp
# contexts must NOT be entered concurrently. Every generation call -- streaming
# or not -- takes this lock so two requests can never run inference at the
# same time.
_inference_lock = threading.Lock()
def _download_model() -> None:
    """Download the GGUF file from MODEL_URL if MODEL_PATH doesn't exist.

    The payload is streamed into ``<MODEL_PATH>.part`` and only renamed to
    MODEL_PATH once the transfer has finished *and* the number of bytes
    received matches the server's Content-Length (when one was sent). This
    keeps a truncated file from ever being mistaken for a complete model on
    the next boot.

    Raises:
        OSError: on network/filesystem errors, a read timeout, or when the
            connection closed before the full payload arrived. The partial
            file is removed before the exception propagates.
    """
    path = Path(MODEL_PATH)
    if path.exists():
        logger.info(f"Model already present at {MODEL_PATH}")
        return
    path.parent.mkdir(parents=True, exist_ok=True)
    # Clean up any stale partial download from a previous crashed attempt
    tmp = Path(str(MODEL_PATH) + ".part")
    if tmp.exists():
        logger.warning(f"Removing stale partial download: {tmp}")
        tmp.unlink()
    logger.info(f"Model not found — downloading from {MODEL_URL} ...")
    logger.info("This will take a while on first boot (file is ~9 GB).")
    import urllib.request

    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set — download may fail for gated models.")
    req = urllib.request.Request(MODEL_URL, headers=headers)
    downloaded = 0
    try:
        # The timeout applies to each blocking socket operation, not to the
        # whole transfer: a stalled connection fails instead of hanging the
        # loader thread forever.
        with urllib.request.urlopen(req, timeout=60) as response, open(tmp, "wb") as out:
            total = int(response.headers.get("Content-Length", 0))
            last_pct = -1
            while chunk := response.read(1 << 20):  # 1 MB chunks
                out.write(chunk)
                downloaded += len(chunk)
                if total:
                    pct = min(int(downloaded * 100 / total), 100)
                    if pct != last_pct and pct % 5 == 0:
                        logger.info(f"Download progress: {pct}%")
                        last_pct = pct
        # HTTPResponse.read(n) returns b"" when the peer closes early rather
        # than raising, so a short transfer must be detected explicitly.
        if total and downloaded != total:
            raise OSError(
                f"Incomplete download: received {downloaded} of {total} bytes"
            )
        tmp.rename(path)
    except Exception:
        # Don't leave a multi-gigabyte partial file behind on failure.
        tmp.unlink(missing_ok=True)
        raise
    logger.info(f"Download complete -> {MODEL_PATH}")
def _load_model_background() -> None:
    """Download (if needed) then load the model. Runs in a daemon thread.

    On success the model is published into ``_llm``. On failure a non-empty
    description of the exception is stored in ``_llm_error``. In both cases
    ``_llm_ready`` is set afterwards so request handlers stop reporting
    "loading".
    """
    global _llm, _llm_error
    try:
        _download_model()
        logger.info(f"Loading model into memory from {MODEL_PATH} ...")
        llm = Llama(
            model_path=MODEL_PATH,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            n_gpu_layers=0,  # CPU only
            verbose=VERBOSE,
            chat_format="chatml",  # Qwen3 uses ChatML
        )
        with _llm_lock:
            _llm = llm
        logger.info("Model loaded and ready")
    except Exception as exc:
        # str() of an exception raised without arguments (a bare MemoryError,
        # for instance) is "", which is falsy and would read as "no error"
        # to truthiness checks. Fall back to repr() so the message is never
        # empty.
        _llm_error = str(exc) or repr(exc)
        # logger.exception keeps the traceback, which a multi-minute
        # download + load failure is hard to diagnose without.
        logger.exception(f"Failed to load model: {_llm_error}")
    finally:
        _llm_ready.set()
def _get_llm() -> Llama:
    """Return the loaded model, or raise if it cannot serve requests.

    Raises:
        HTTPException: 503 while the background load is still running;
            500 once loading has finished without producing a model.
    """
    if not _llm_ready.is_set():
        raise HTTPException(
            status_code=503,
            detail="Model is still loading (or downloading). "
            "Check /health for status and retry in a moment.",
        )
    # Test for None explicitly: an empty error string is still an error, and
    # a missing model must never be handed to a caller as if it were loaded.
    if _llm_error is not None or _llm is None:
        raise HTTPException(
            status_code=500,
            detail=f"Model failed to load: {_llm_error or 'unknown error'}",
        )
    return _llm
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# The ASGI application object that every route below is registered on.
app = FastAPI(
    version="1.0.0",
    title="Qwen3-14B OpenAI-compatible API",
)
# Fully permissive CORS: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Start the model download + load on a daemon thread.

    Nothing is awaited here, so the server begins accepting requests right
    away while the (long) load runs in the background.
    """
    threading.Thread(target=_load_model_background, daemon=True).start()
    logger.info("Server is up. Model loading in background -- see /health for status.")
# ---------------------------------------------------------------------------
# Pydantic schemas (OpenAI-compatible subset)
# ---------------------------------------------------------------------------
class Message(BaseModel):
    # One chat turn of an incoming request.
    role: str  # speaker of the turn; forwarded to the model unchanged
    content: str  # plain-text body of the turn
class ChatCompletionRequest(BaseModel):
    # Request body of POST /v1/chat/completions (OpenAI-compatible subset).

    # Accepted for client compatibility; responses always report MODEL_ID.
    model: str = MODEL_ID
    # Conversation so far, in order.
    messages: List[Message]
    # Sampling / length controls, validated to the ranges given here.
    max_tokens: Optional[int] = Field(1024, ge=1, le=8192)
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0)
    # True -> Server-Sent Events stream; False -> a single JSON body.
    stream: Optional[bool] = False
    # Optional stop sequences; omitted/None is passed on as an empty list.
    stop: Optional[List[str]] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_chunk(delta_content: str, finish_reason: Optional[str], request_id: str) -> str:
    """Serialize one streaming chunk as a Server-Sent Events ``data:`` frame.

    Args:
        delta_content: Text carried by this chunk. An empty string yields an
            empty ``delta`` object.
        finish_reason: Finish reason to report, or None while generating.
        request_id: Completion id used as the chunk's ``id``.

    Returns:
        ``"data: <json>"`` followed by a blank line.
    """
    delta = {}
    if delta_content:
        delta["content"] = delta_content
    choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
    payload = {
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [choice],
    }
    return "data: " + json.dumps(payload) + "\n\n"
async def _stream_response(llm: Llama, messages: list, request: ChatCompletionRequest, request_id: str) -> AsyncIterator[str]:
    """Yield a chat completion as Server-Sent Events frames.

    Generation runs on a worker thread under ``_inference_lock``. Chunks are
    handed to the event loop through an ``asyncio.Queue`` via
    ``call_soon_threadsafe``, so waiting for the next token neither blocks
    the loop nor occupies an executor thread.

    If the consumer goes away early (client disconnect, task cancellation)
    the worker is told to stop at the next chunk, so the inference lock is
    released promptly instead of being held until the model finishes.

    Args:
        llm: Loaded model to generate with.
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts.
        request: Validated request carrying the sampling parameters.
        request_id: Completion id stamped on every chunk.

    Yields:
        An opening empty-delta frame, one frame per content delta, a final
        frame carrying the finish reason, then ``data: [DONE]``.

    Raises:
        Exception: whatever the worker thread raised during generation is
            re-raised here, in the consumer.
    """
    loop = asyncio.get_running_loop()
    events: "asyncio.Queue" = asyncio.Queue()
    cancelled = threading.Event()
    _SENTINEL = object()

    def _emit(item) -> None:
        # asyncio.Queue is not thread-safe; the put has to be scheduled onto
        # the loop's own thread.
        try:
            loop.call_soon_threadsafe(events.put_nowait, item)
        except RuntimeError:
            # Raised once the loop is closed (server shutting down): nobody
            # is listening any more, so stop generating.
            cancelled.set()

    def _produce() -> None:
        # Holds the lock for the *entire* generation, not just creation --
        # this is the only place token generation actually happens, and it
        # must never overlap with another request's call into the same
        # Llama instance.
        with _inference_lock:
            try:
                # The client may have left while we were queued on the lock.
                if cancelled.is_set():
                    return
                gen = llm.create_chat_completion(
                    messages=messages,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stop=request.stop or [],
                    stream=True,
                )
                for chunk in gen:
                    if cancelled.is_set():
                        break
                    _emit(chunk)
            except Exception as exc:  # surfaced to the consumer below
                _emit(exc)
            finally:
                _emit(_SENTINEL)

    # Fire-and-forget: the worker reports back only through the queue.
    loop.run_in_executor(None, _produce)
    try:
        yield _make_chunk("", None, request_id)  # opening delta
        while True:
            item = await events.get()
            if item is _SENTINEL:
                break
            if isinstance(item, Exception):
                raise item
            choice = item["choices"][0]
            delta = choice.get("delta", {})
            content = delta.get("content", "")
            finish = choice.get("finish_reason")
            if content:
                yield _make_chunk(content, None, request_id)
            if finish:
                yield _make_chunk("", finish, request_id)
                break
        yield "data: [DONE]\n\n"
    finally:
        # Runs on normal completion, on error, and when the generator is
        # closed or cancelled mid-stream; tells the worker to stop.
        cancelled.set()
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/")
async def root():
    """Report coarse service status: "loading", "error" or "ready".

    "error" is returned once the background load has finished without
    producing a model, so a failed load is not reported as "loading"
    forever.
    """
    if not _llm_ready.is_set():
        status = "loading"
    elif _llm is None:
        status = "error"
    else:
        status = "ready"
    return {"status": status, "model": MODEL_ID}
@app.get("/health")
async def health():
    """Report model readiness.

    Returns a dict whose ``status`` is "loading" while the background load
    is running, "error" (with an ``error`` message) once it has finished
    without a usable model, and "healthy" otherwise. ``ready`` is True only
    in the healthy case.
    """
    if not _llm_ready.is_set():
        return {"status": "loading", "model": MODEL_ID, "ready": False}
    # Compare against None rather than relying on truthiness: an empty error
    # string, or a missing model, must not be reported as healthy.
    if _llm_error is not None or _llm is None:
        return {
            "status": "error",
            "error": _llm_error or "model failed to load (no error message)",
            "ready": False,
        }
    return {"status": "healthy", "model": MODEL_ID, "ready": True}
@app.get("/v1/models")
async def list_models():
    """Return the model list: a single entry describing MODEL_ID."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1700000000,  # fixed placeholder timestamp
        "owned_by": "local",
    }
    return {"object": "list", "data": [model_entry]}
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat completion request, streaming or not.

    With ``stream`` set, returns a ``text/event-stream`` response fed by
    ``_stream_response``. Otherwise runs the generation on an executor
    thread and returns one ``chat.completion`` JSON object.

    Raises:
        HTTPException: propagated from ``_get_llm`` when the model is not
            available.
    """
    # Resolve the model first so an unavailable model fails the request
    # *before* any response (streaming included) has been started.
    llm = _get_llm()
    chat_messages = [dict(role=m.role, content=m.content) for m in request.messages]

    if request.stream:
        stream_id = f"chatcmpl-{uuid.uuid4().hex}"
        sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
        return StreamingResponse(
            _stream_response(llm, chat_messages, request, stream_id),
            media_type="text/event-stream",
            headers=sse_headers,
        )

    # Non-streaming
    def _generate():
        # Never overlap with another generation call.
        with _inference_lock:
            return llm.create_chat_completion(
                messages=chat_messages,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=request.stop or [],
                stream=False,
            )

    completion = await asyncio.get_event_loop().run_in_executor(None, _generate)

    first_choice = completion["choices"][0]
    reported_usage = completion.get("usage", {})
    usage_keys = ("prompt_tokens", "completion_tokens", "total_tokens")
    assistant_message = {
        "role": "assistant",
        "content": first_choice["message"]["content"],
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": assistant_message,
                "finish_reason": first_choice.get("finish_reason", "stop"),
            }
        ],
        # Counters the backend did not report default to 0.
        "usage": {key: reported_usage.get(key, 0) for key in usage_keys},
    }