Spaces:
Sleeping
Sleeping
| """ | |
| OpenAI-compatible FastAPI wrapper for Qwen3-14B (GGUF / llama-cpp-python) | |
| Endpoints: GET /v1/models, POST /v1/chat/completions | |
| Supports streaming (SSE) and non-streaming responses. | |
| Model is downloaded automatically on first boot if not already present. | |
| """ | |
| import os | |
| import time | |
| import uuid | |
| import json | |
| import queue | |
| import asyncio | |
| import logging | |
| import threading | |
| from pathlib import Path | |
| from typing import AsyncIterator, List, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from llama_cpp import Llama | |
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
# Configure the root logger once at import time: INFO and above, one
# timestamped line per record.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
| # --------------------------------------------------------------------------- | |
| # Config (override via environment variables) | |
| # --------------------------------------------------------------------------- | |
def _env_str(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* if unset or blank.

    Surrounding whitespace is stripped so a value pasted with a trailing
    newline (common for secrets) is still usable.
    """
    return os.environ.get(name, "").strip() or default


def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* as an int.

    Falls back to *default* when the variable is unset, blank, or not a
    valid integer, instead of crashing the module import with ValueError.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring invalid %s=%r; using default %d", name, raw, default
        )
        return default


def _env_bool(name: str, default: bool = False) -> bool:
    """Return environment variable *name* as a bool.

    Accepts 1/true/yes/on (case-insensitive) as true; any other non-blank
    value is false; unset or blank yields *default*.
    """
    raw = os.environ.get(name, "").strip().lower()
    if not raw:
        return default
    return raw in {"1", "true", "yes", "on"}


# Where the GGUF file lives on disk, and where to fetch it from if missing.
MODEL_PATH = _env_str("MODEL_PATH", "/models/qwen3-14b-q4_k_m.gguf")
MODEL_URL = _env_str(
    "MODEL_URL",
    "https://huggingface.co/bartowski/Qwen3-14B-GGUF/resolve/main/Qwen3-14B-Q4_K_M.gguf",
)
# Model name reported to API clients.
MODEL_ID = _env_str("MODEL_ID", "qwen3-14b")
# llama.cpp runtime parameters.
N_CTX = _env_int("N_CTX", 4096)
N_THREADS = _env_int("N_THREADS", os.cpu_count() or 4)
N_BATCH = _env_int("N_BATCH", 512)
VERBOSE = _env_bool("VERBOSE", False)
# Optional Hugging Face access token; empty string means "no token".
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()
| # --------------------------------------------------------------------------- | |
| # Lazy model holder | |
| # --------------------------------------------------------------------------- | |
# The single shared Llama instance. Stays None until the background loader
# has finished successfully.
_llm: Optional[Llama] = None
# Held by the loader thread while it publishes the instance into _llm.
_llm_lock = threading.Lock()
# Set when the loader has finished -- on success *and* on failure; check
# _llm_error to tell the two apart.
_llm_ready = threading.Event()
# Failure description written by the loader; None while no failure occurred.
_llm_error: Optional[str] = None
# llama.cpp contexts are NOT safe to call concurrently: every call into the
# same Llama instance shares one KV cache / sampling state. All generation
# calls (streaming and non-streaming) take this lock, so two requests can
# never run inference at the same time.
_inference_lock = threading.Lock()
def _download_model() -> None:
    """Download the GGUF file from MODEL_URL if MODEL_PATH doesn't exist.

    The payload is streamed to ``MODEL_PATH + ".part"`` and only moved to
    MODEL_PATH once the transfer has completed, so an interrupted download
    never leaves a file that a later boot would mistake for a usable model.

    Raises:
        OSError: on network or filesystem failure (urllib's URLError and
            HTTPError are OSError subclasses), or when fewer bytes than the
            advertised Content-Length were received.
    """
    import urllib.request

    path = Path(MODEL_PATH)
    if path.exists():
        logger.info(f"Model already present at {MODEL_PATH}")
        return
    path.parent.mkdir(parents=True, exist_ok=True)

    # Clean up any stale partial download from a previous crashed attempt.
    tmp = Path(str(MODEL_PATH) + ".part")
    if tmp.exists():
        logger.warning(f"Removing stale partial download: {tmp}")
        tmp.unlink()

    logger.info(f"Model not found — downloading from {MODEL_URL} ...")
    logger.info("This will take a while on first boot (file is ~9 GB).")

    headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set — download may fail for gated models.")
    req = urllib.request.Request(MODEL_URL, headers=headers)

    try:
        # The timeout applies to each blocking socket operation, so a stalled
        # connection fails instead of leaving the server "loading" forever.
        with urllib.request.urlopen(req, timeout=60) as response, open(tmp, "wb") as out:
            total = int(response.headers.get("Content-Length", 0))
            downloaded = 0
            last_logged = -1
            while chunk := response.read(1 << 20):  # 1 MB chunks
                out.write(chunk)
                downloaded += len(chunk)
                if total:
                    pct = min(int(downloaded * 100 / total), 100)
                    # Round down to the nearest 5% so a step is still logged
                    # when a chunk jumps past the exact multiple.
                    step = pct - pct % 5
                    if step > last_logged:
                        logger.info(f"Download progress: {step}%")
                        last_logged = step
        # A connection that closes early can look like a normal end of
        # stream; never promote a truncated file to MODEL_PATH.
        if total and downloaded != total:
            raise OSError(
                f"Incomplete download: received {downloaded} of {total} bytes"
            )
    except BaseException:
        # Free the disk space right away; the partial file is unusable.
        tmp.unlink(missing_ok=True)
        raise

    # replace() overwrites atomically on the same filesystem.
    tmp.replace(path)
    logger.info(f"Download complete -> {MODEL_PATH}")
def _load_model_background() -> None:
    """Download (if needed) then load the model. Runs in a daemon thread.

    On success the instance is published in the module-level ``_llm``. On
    failure a non-empty description is stored in ``_llm_error``. In both
    cases ``_llm_ready`` is set afterwards so waiters stop reporting
    "loading".
    """
    global _llm, _llm_error
    try:
        _download_model()
        logger.info(f"Loading model into memory from {MODEL_PATH} ...")
        llm = Llama(
            model_path=MODEL_PATH,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            n_batch=N_BATCH,
            n_gpu_layers=0,  # CPU only
            verbose=VERBOSE,
            chat_format="chatml",  # Qwen3 uses ChatML
        )
        with _llm_lock:
            _llm = llm
        logger.info("Model loaded and ready")
    except Exception as exc:
        # str(exc) can be empty (e.g. a bare ``raise SomeError()``); an empty
        # string is falsy and would make the failure look like success to
        # truthiness checks, so fall back to repr().
        _llm_error = str(exc) or repr(exc)
        # logger.exception keeps the traceback for diagnosing load failures.
        logger.exception(f"Failed to load model: {exc}")
    finally:
        _llm_ready.set()
def _get_llm() -> Llama:
    """Return the loaded model, or raise if it is unavailable.

    Raises:
        HTTPException: 503 while the background load is still running;
            500 if loading finished without producing a model.
    """
    if not _llm_ready.is_set():
        raise HTTPException(
            status_code=503,
            detail="Model is still loading (or downloading). "
            "Check /health for status and retry in a moment.",
        )
    # Check for None explicitly rather than by truthiness: an empty error
    # string must still count as a failure, and a missing instance must
    # never be handed to callers as if it were a model.
    if _llm_error is not None or _llm is None:
        raise HTTPException(
            status_code=500,
            detail=f"Model failed to load: {_llm_error or 'unknown error'}",
        )
    return _llm
| # --------------------------------------------------------------------------- | |
| # FastAPI app | |
| # --------------------------------------------------------------------------- | |
# The ASGI application object that all middleware and handlers attach to.
app = FastAPI(
    title="Qwen3-14B OpenAI-compatible API",
    version="1.0.0",
)

# CORS is left wide open: every origin, method and header is allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Registered as a startup hook: without the registration nothing ever calls
# this coroutine and the model would never begin loading.
@app.on_event("startup")
async def startup_event():
    """Start the model download + load in a background daemon thread.

    Returning immediately lets the server accept requests (and report its
    loading status) during the long load phase.
    """
    loader = threading.Thread(target=_load_model_background, daemon=True)
    loader.start()
    logger.info("Server is up. Model loading in background -- see /health for status.")
| # --------------------------------------------------------------------------- | |
| # Pydantic schemas (OpenAI-compatible subset) | |
| # --------------------------------------------------------------------------- | |
class Message(BaseModel):
    """One chat turn as sent by the client."""

    # Speaker of the turn; accepted as any string, not validated here.
    role: str
    # Plain-text body of the turn.
    content: str
class ChatCompletionRequest(BaseModel):
    """Request body for the chat-completions handler (OpenAI-compatible subset)."""

    # Accepted for client compatibility; defaults to the configured model id.
    model: str = Field(default=MODEL_ID)
    # Conversation so far; the only field without a default.
    messages: List[Message]
    # Generation budget, bounded to 1..8192 tokens.
    max_tokens: Optional[int] = Field(1024, ge=1, le=8192)
    # Sampling temperature, bounded to 0.0..2.0.
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0)
    # Nucleus-sampling mass, bounded to 0.0..1.0.
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0)
    # When true the reply is sent as a server-sent-event stream.
    stream: Optional[bool] = Field(default=False)
    # Optional stop sequences; None means no extra stop strings.
    stop: Optional[List[str]] = Field(default=None)
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _make_chunk(delta_content: str, finish_reason: Optional[str], request_id: str) -> str:
    """Build one ``chat.completion.chunk`` and frame it as an SSE ``data:`` event.

    Args:
        delta_content: Text for this chunk; when empty the delta is ``{}``.
        finish_reason: Copied into the choice as-is; ``None`` serializes to null.
        request_id: Identifier used as the chunk's ``id``.

    Returns:
        ``"data: <json>"`` terminated by a blank line.
    """
    delta = {}
    if delta_content:
        delta["content"] = delta_content

    only_choice = {
        "index": 0,
        "delta": delta,
        "finish_reason": finish_reason,
    }
    payload = {
        "id": request_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [only_choice],
    }
    return "data: " + json.dumps(payload) + "\n\n"
async def _stream_response(llm: Llama, messages: list, request: ChatCompletionRequest, request_id: str) -> AsyncIterator[str]:
    """Yield an SSE stream of chat-completion chunks for one request.

    Generation runs on a worker thread that feeds a thread-safe queue; this
    coroutine drains the queue without blocking the event loop.

    Args:
        llm: Loaded model to generate with.
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        request: Validated request carrying the sampling parameters.
        request_id: Id stamped on every emitted chunk.

    Yields:
        SSE frames: an empty opening delta, one frame per non-empty content
        delta, a frame carrying the finish reason, then ``data: [DONE]``.

    Raises:
        Exception: whatever the generation call raised, re-raised here.
    """
    loop = asyncio.get_running_loop()
    q: "queue.Queue" = queue.Queue()
    _SENTINEL = object()
    # Set when this generator is finished or torn down (for example because
    # the client disconnected), so the worker stops generating and releases
    # the inference lock instead of running on to max_tokens for nobody.
    cancelled = threading.Event()

    def _produce():
        # Holds the lock for the *entire* generation, not just creation --
        # this is the only place token generation actually happens, and it
        # must never overlap with another request's call into the same
        # Llama instance.
        with _inference_lock:
            try:
                # The consumer may have gone away while we waited for the lock.
                if cancelled.is_set():
                    return
                gen = llm.create_chat_completion(
                    messages=messages,
                    max_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    stop=request.stop or [],
                    stream=True,
                )
                for chunk in gen:
                    if cancelled.is_set():
                        break
                    q.put(chunk)
            except Exception as exc:  # surfaced to the consumer below
                q.put(exc)
            finally:
                # Always wake the consumer (or its pending q.get call).
                q.put(_SENTINEL)

    # Fire-and-forget: runs on a worker thread, the consumer below just
    # drains the queue without ever blocking the asyncio event loop.
    loop.run_in_executor(None, _produce)
    try:
        yield _make_chunk("", None, request_id)  # opening delta
        while True:
            item = await loop.run_in_executor(None, q.get)
            if item is _SENTINEL:
                break
            if isinstance(item, Exception):
                raise item
            choice = item["choices"][0]
            delta = choice.get("delta", {})
            content = delta.get("content", "")
            finish = choice.get("finish_reason")
            if content:
                yield _make_chunk(content, None, request_id)
            if finish:
                yield _make_chunk("", finish, request_id)
                break
        yield "data: [DONE]\n\n"
    finally:
        # Runs on normal completion, on error, and when the generator is
        # closed or cancelled early.
        cancelled.set()
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
# NOTE(review): the "/" path is inferred from the handler name -- confirm.
@app.get("/")
async def root():
    """Return a minimal status summary for the service.

    ``status`` is "loading" while the background load is running, "ready"
    once a model instance exists, and "error" when loading finished without
    producing one (previously this case was reported as "loading" forever).
    """
    if not _llm_ready.is_set():
        status = "loading"
    elif _llm is None:
        status = "error"
    else:
        status = "ready"
    return {"status": status, "model": MODEL_ID}
# Served at /health, the path the log and 503 messages point clients to.
@app.get("/health")
async def health():
    """Report whether the model is loading, failed, or ready to serve.

    Every payload carries ``status``, ``model`` and ``ready``; the error
    payload additionally carries ``error``.
    """
    if not _llm_ready.is_set():
        return {"status": "loading", "model": MODEL_ID, "ready": False}
    # Check for None explicitly: an empty error string is still a failure,
    # and "healthy" must never be reported without a model instance.
    if _llm_error is not None or _llm is None:
        return {
            "status": "error",
            "model": MODEL_ID,
            "error": _llm_error or "unknown error",
            "ready": False,
        }
    return {"status": "healthy", "model": MODEL_ID, "ready": True}
@app.get("/v1/models")
async def list_models():
    """Return the OpenAI-style model list, which holds the single served model."""
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 1700000000,  # fixed placeholder timestamp
        "owned_by": "local",
    }
    return {"object": "list", "data": [model_entry]}
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat-completion request, streaming or not.

    Args:
        request: Validated request body.

    Returns:
        A ``StreamingResponse`` of SSE chunks when ``request.stream`` is
        truthy, otherwise a ``chat.completion`` dict with the message and
        token usage.

    Raises:
        HTTPException: 503 while the model is loading, 500 if it failed to
            load (raised by ``_get_llm`` before any response is started).
    """
    llm = _get_llm()  # raises 503/500 *before* we commit to a response,
    # including for the streaming branch below
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    if request.stream:
        request_id = f"chatcmpl-{uuid.uuid4().hex}"
        return StreamingResponse(
            _stream_response(llm, messages, request, request_id),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    # Non-streaming. We are inside a coroutine, so get_running_loop() is the
    # supported way to reach the loop (get_event_loop() is deprecated here).
    loop = asyncio.get_running_loop()

    def _run():
        with _inference_lock:  # never overlap with another generation call
            return llm.create_chat_completion(
                messages=messages,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=request.stop or [],
                stream=False,
            )

    # Inference blocks, so run it on a worker thread.
    result = await loop.run_in_executor(None, _run)
    choice = result["choices"][0]
    usage = result.get("usage", {})
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": choice["message"]["content"],
                },
                "finish_reason": choice.get("finish_reason", "stop"),
            }
        ],
        "usage": {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "total_tokens": usage.get("total_tokens", 0),
        },
    }