""" OpenAI-compatible FastAPI wrapper for Qwen3-14B (GGUF / llama-cpp-python) Endpoints: GET /v1/models, POST /v1/chat/completions Supports streaming (SSE) and non-streaming responses. Model is downloaded automatically on first boot if not already present. """ import os import time import uuid import json import queue import asyncio import logging import threading from pathlib import Path from typing import AsyncIterator, List, Optional from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field from llama_cpp import Llama # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Config (override via environment variables) # --------------------------------------------------------------------------- MODEL_PATH = os.environ.get("MODEL_PATH", "/models/qwen3-14b-q4_k_m.gguf") MODEL_URL = os.environ.get("MODEL_URL", "https://huggingface.co/bartowski/Qwen3-14B-GGUF/resolve/main/Qwen3-14B-Q4_K_M.gguf") MODEL_ID = os.environ.get("MODEL_ID", "qwen3-14b") N_CTX = int(os.environ.get("N_CTX", "4096")) N_THREADS = int(os.environ.get("N_THREADS", str(os.cpu_count() or 4))) N_BATCH = int(os.environ.get("N_BATCH", "512")) VERBOSE = os.environ.get("VERBOSE", "false").lower() == "true" HF_TOKEN = os.environ.get("HF_TOKEN", "") # --------------------------------------------------------------------------- # Lazy model holder # --------------------------------------------------------------------------- _llm: Optional[Llama] = None _llm_lock = threading.Lock() _llm_ready = threading.Event() # set once the model is loaded _llm_error: Optional[str] = None # set if loading failed # llama.cpp contexts are NOT safe to call concurrently -- there's a single # KV cache / sampling state shared by every call into the same Llama # instance. This lock serializes all generation calls (streaming and # non-streaming) so two requests can never run inference at the same time. _inference_lock = threading.Lock() def _download_model() -> None: """Download the GGUF file from MODEL_URL if MODEL_PATH doesn't exist.""" path = Path(MODEL_PATH) if path.exists(): logger.info(f"Model already present at {MODEL_PATH}") return path.parent.mkdir(parents=True, exist_ok=True) # Clean up any stale partial download from a previous crashed attempt tmp = Path(str(MODEL_PATH) + ".part") if tmp.exists(): logger.warning(f"Removing stale partial download: {tmp}") tmp.unlink() logger.info(f"Model not found — downloading from {MODEL_URL} ...") logger.info("This will take a while on first boot (file is ~9 GB).") import urllib.request headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {} if not HF_TOKEN: logger.warning("HF_TOKEN not set — download may fail for gated models.") req = urllib.request.Request(MODEL_URL, headers=headers) with urllib.request.urlopen(req) as response, open(tmp, "wb") as out: total = int(response.headers.get("Content-Length", 0)) downloaded = 0 last_pct = -1 while chunk := response.read(1 << 20): # 1 MB chunks out.write(chunk) downloaded += len(chunk) if total: pct = min(int(downloaded * 100 / total), 100) if pct != last_pct and pct % 5 == 0: logger.info(f"Download progress: {pct}%") last_pct = pct tmp.rename(path) logger.info(f"Download complete -> {MODEL_PATH}") def _load_model_background() -> None: """Download (if needed) then load the model. Runs in a daemon thread.""" global _llm, _llm_error try: _download_model() logger.info(f"Loading model into memory from {MODEL_PATH} ...") llm = Llama( model_path=MODEL_PATH, n_ctx=N_CTX, n_threads=N_THREADS, n_batch=N_BATCH, n_gpu_layers=0, # CPU only verbose=VERBOSE, chat_format="chatml", # Qwen3 uses ChatML ) with _llm_lock: _llm = llm logger.info("Model loaded and ready") except Exception as exc: _llm_error = str(exc) logger.error(f"Failed to load model: {exc}") finally: _llm_ready.set() def _get_llm() -> Llama: """Return the loaded model or raise a 503 if it's not ready yet.""" if not _llm_ready.is_set(): raise HTTPException( status_code=503, detail="Model is still loading (or downloading). " "Check /health for status and retry in a moment.", ) if _llm_error: raise HTTPException( status_code=500, detail=f"Model failed to load: {_llm_error}", ) return _llm # --------------------------------------------------------------------------- # FastAPI app # --------------------------------------------------------------------------- app = FastAPI(title="Qwen3-14B OpenAI-compatible API", version="1.0.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) @app.on_event("startup") async def startup_event(): """Kick off model download + load in a background thread so the server starts immediately and stays responsive during the (long) load phase.""" t = threading.Thread(target=_load_model_background, daemon=True) t.start() logger.info("Server is up. Model loading in background -- see /health for status.") # --------------------------------------------------------------------------- # Pydantic schemas (OpenAI-compatible subset) # --------------------------------------------------------------------------- class Message(BaseModel): role: str content: str class ChatCompletionRequest(BaseModel): model: str = MODEL_ID messages: List[Message] max_tokens: Optional[int] = Field(default=1024, ge=1, le=8192) temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0) top_p: Optional[float] = Field(default=0.9, ge=0.0, le=1.0) stream: Optional[bool] = False stop: Optional[List[str]] = None # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_chunk(delta_content: str, finish_reason: Optional[str], request_id: str) -> str: chunk = { "id": request_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": MODEL_ID, "choices": [ { "index": 0, "delta": {"content": delta_content} if delta_content else {}, "finish_reason": finish_reason, } ], } return f"data: {json.dumps(chunk)}\n\n" async def _stream_response(llm: Llama, messages: list, request: ChatCompletionRequest, request_id: str) -> AsyncIterator[str]: loop = asyncio.get_event_loop() q: "queue.Queue" = queue.Queue() _SENTINEL = object() def _produce(): # Holds the lock for the *entire* generation, not just creation -- # this is the only place token generation actually happens, and it # must never overlap with another request's call into the same # Llama instance. with _inference_lock: try: gen = llm.create_chat_completion( messages=messages, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, stop=request.stop or [], stream=True, ) for chunk in gen: q.put(chunk) except Exception as exc: # surfaced to the consumer below q.put(exc) finally: q.put(_SENTINEL) # Fire-and-forget: runs on a worker thread, the consumer below just # drains the queue without ever blocking the asyncio event loop. loop.run_in_executor(None, _produce) yield _make_chunk("", None, request_id) # opening delta while True: item = await loop.run_in_executor(None, q.get) if item is _SENTINEL: break if isinstance(item, Exception): raise item choice = item["choices"][0] delta = choice.get("delta", {}) content = delta.get("content", "") finish = choice.get("finish_reason") if content: yield _make_chunk(content, None, request_id) if finish: yield _make_chunk("", finish, request_id) break yield "data: [DONE]\n\n" # --------------------------------------------------------------------------- # Routes # --------------------------------------------------------------------------- @app.get("/") async def root(): ready = _llm_ready.is_set() and _llm is not None return {"status": "ready" if ready else "loading", "model": MODEL_ID} @app.get("/health") async def health(): if not _llm_ready.is_set(): return {"status": "loading", "model": MODEL_ID, "ready": False} if _llm_error: return {"status": "error", "error": _llm_error, "ready": False} return {"status": "healthy", "model": MODEL_ID, "ready": True} @app.get("/v1/models") async def list_models(): return { "object": "list", "data": [ { "id": MODEL_ID, "object": "model", "created": 1700000000, "owned_by": "local", } ], } @app.post("/v1/chat/completions") async def chat_completions(request: ChatCompletionRequest): llm = _get_llm() # raises 503/500 *before* we commit to a response, # including for the streaming branch below messages = [{"role": m.role, "content": m.content} for m in request.messages] if request.stream: request_id = f"chatcmpl-{uuid.uuid4().hex}" return StreamingResponse( _stream_response(llm, messages, request, request_id), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, ) # Non-streaming loop = asyncio.get_event_loop() def _run(): with _inference_lock: # never overlap with another generation call return llm.create_chat_completion( messages=messages, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, stop=request.stop or [], stream=False, ) result = await loop.run_in_executor(None, _run) choice = result["choices"][0] usage = result.get("usage", {}) return { "id": f"chatcmpl-{uuid.uuid4().hex}", "object": "chat.completion", "created": int(time.time()), "model": MODEL_ID, "choices": [ { "index": 0, "message": { "role": "assistant", "content": choice["message"]["content"], }, "finish_reason": choice.get("finish_reason", "stop"), } ], "usage": { "prompt_tokens": usage.get("prompt_tokens", 0), "completion_tokens": usage.get("completion_tokens", 0), "total_tokens": usage.get("total_tokens", 0), }, }