"""OpenAI-compatible chat-completion server for a GGUF model run via llama-cpp-python.

Configuration is read from environment variables at import time:
MODEL_REPO, MODEL_FILE, MODEL_ID, API_KEY, N_CTX, N_THREADS, N_GPU_LAYERS.
The model itself is loaded lazily on the first request that needs it.
"""

import json
import os
import secrets
import threading
import time
import uuid
from typing import Any, Dict, Iterator, List, Literal, Optional, Union

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from pydantic import BaseModel, Field

# Hugging Face repository and file name of the GGUF weights passed to
# Llama.from_pretrained.
MODEL_REPO = os.getenv("MODEL_REPO", "bartowski/google_gemma-4-E2B-it-GGUF")
MODEL_FILE = os.getenv("MODEL_FILE", "google_gemma-4-E2B-it-Q4_K_M.gguf")
# Identifier reported to clients by /v1/models, /health and in responses.
MODEL_ID = os.getenv("MODEL_ID", "gemma-4-e2b-it-gguf")
# When empty, authentication is disabled (see require_api_key).
API_KEY = os.getenv("API_KEY", "")
N_CTX = int(os.getenv("N_CTX", "4096"))
N_THREADS = int(os.getenv("N_THREADS", str(os.cpu_count() or 4)))
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))

app = FastAPI(
    title="GGUF OpenAI-Compatible Model Server",
    version="1.0.0",
)

# Shared model instance; None until get_llm() loads it.
llm: Optional[Llama] = None
# Serializes the one-time model load. FastAPI runs plain `def` endpoints in a
# thread pool, so two early requests could otherwise both load the model.
_llm_load_lock = threading.Lock()


class ChatMessage(BaseModel):
    """A single message in an OpenAI-style chat conversation."""

    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (subset of the OpenAI schema)."""

    model: Optional[str] = MODEL_ID
    messages: List[ChatMessage]
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    stream: bool = False
    stop: Optional[Union[str, List[str]]] = None


class QueryRequest(BaseModel):
    """Request body for the simplified ``POST /query`` endpoint."""

    prompt: str
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)


def require_api_key(authorization: Optional[str] = Header(default=None)) -> None:
    """FastAPI dependency enforcing ``Authorization: Bearer <API_KEY>``.

    Authentication is skipped entirely when ``API_KEY`` is empty.

    Raises:
        HTTPException: 401 if the header is missing or does not match.
    """
    if not API_KEY:
        return
    expected = f"Bearer {API_KEY}"
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest raises TypeError for
    # non-ASCII str arguments, which a client could otherwise trigger.
    if authorization is None or not secrets.compare_digest(
        authorization.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")


def get_llm() -> Llama:
    """Return the shared Llama instance, loading it via ``from_pretrained`` on first use.

    NOTE(review): the returned instance is shared by all requests, which
    FastAPI may run concurrently in its thread pool. Confirm that
    llama-cpp-python tolerates concurrent generation on one instance, or
    serialize inference.
    """
    global llm
    if llm is None:
        with _llm_load_lock:
            # Re-check under the lock: another thread may have completed the
            # load while this one was waiting.
            if llm is None:
                llm = Llama.from_pretrained(
                    repo_id=MODEL_REPO,
                    filename=MODEL_FILE,
                    n_ctx=N_CTX,
                    n_threads=N_THREADS,
                    n_gpu_layers=N_GPU_LAYERS,
                    verbose=False,
                )
    return llm


def usage_from_response(response: Dict[str, Any]) -> Dict[str, int]:
    """Build an OpenAI ``usage`` dict from a llama-cpp completion response.

    Missing or null ``usage`` data is reported as zero tokens. The
    ``total_tokens`` value is recomputed as prompt + completion.
    """
    usage = response.get("usage") or {}
    prompt_tokens = int(usage.get("prompt_tokens", 0))
    completion_tokens = int(usage.get("completion_tokens", 0))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


@app.get("/")
def root() -> Dict[str, str]:
    """List the server's main endpoints."""
    return {
        "message": "GGUF model server is running",
        "chat_completions": "/v1/chat/completions",
        "models": "/v1/models",
        "health": "/health",
    }


@app.get("/health")
def health() -> Dict[str, Any]:
    """Report configuration and whether the model has been loaded yet (unauthenticated)."""
    return {
        "status": "ok",
        "model_id": MODEL_ID,
        "model_repo": MODEL_REPO,
        "model_file": MODEL_FILE,
        "loaded": llm is not None,
    }


@app.get("/v1/models", dependencies=[Depends(require_api_key)])
def list_models() -> Dict[str, Any]:
    """OpenAI-style model listing containing the single served model."""
    return {
        "object": "list",
        "data": [
            {
                "id": MODEL_ID,
                "object": "model",
                "created": 0,
                "owned_by": "huggingface-space",
            }
        ],
    }


@app.post("/query", dependencies=[Depends(require_api_key)])
def query(req: QueryRequest) -> Dict[str, str]:
    """Answer a single prompt with a fixed system prompt; returns ``{"answer": ...}``."""
    model = get_llm()
    response = model.create_chat_completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": req.prompt},
        ],
        max_tokens=req.max_tokens,
        temperature=req.temperature,
    )
    return {"answer": response["choices"][0]["message"]["content"]}


@app.post("/v1/chat/completions", dependencies=[Depends(require_api_key)])
def create_chat_completion(req: ChatCompletionRequest) -> Any:
    """OpenAI-compatible chat completion endpoint.

    When ``req.stream`` is true, this returns a ``text/event-stream`` response
    produced by :func:`stream_chat_completion`. Otherwise it returns a single
    ``chat.completion`` object. Both paths report ``req.model`` (falling back
    to ``MODEL_ID``) as the model name.
    """
    model = get_llm()
    model_name = req.model or MODEL_ID
    messages = [message.model_dump() for message in req.messages]
    request_args = {
        "messages": messages,
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "stop": req.stop,
        "stream": req.stream,
    }

    if req.stream:
        return StreamingResponse(
            stream_chat_completion(model, request_args, model_name=model_name),
            media_type="text/event-stream",
        )

    response = model.create_chat_completion(**request_args)
    choice = response["choices"][0]
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": choice["message"]["content"],
                },
                # Non-streaming responses must carry a finish_reason, so
                # default to "stop" when it is absent or null.
                "finish_reason": choice.get("finish_reason") or "stop",
            }
        ],
        "usage": usage_from_response(response),
    }


def stream_chat_completion(
    model: Llama,
    request_args: Dict[str, Any],
    model_name: str = MODEL_ID,
) -> Iterator[str]:
    """Yield OpenAI-style server-sent events for a streaming chat completion.

    Args:
        model: The loaded Llama instance.
        request_args: Keyword arguments for ``model.create_chat_completion``.
            This dict is not modified; ``stream=True`` is forced on a copy.
        model_name: Model name reported in every chunk. It defaults to
            ``MODEL_ID`` for callers that predate this parameter.

    Yields:
        ``data: <json chunk>\\n\\n`` lines, then a final ``data: [DONE]\\n\\n``.
    """
    args = {**request_args, "stream": True}
    # One id/timestamp shared by every chunk of this completion, as OpenAI does.
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    for chunk in model.create_chat_completion(**args):
        choices = chunk.get("choices") or []
        if not choices:
            # Nothing to forward; skip instead of raising IndexError mid-stream.
            continue
        choice = choices[0]
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_name,
            "choices": [
                {
                    "index": choice.get("index", 0),
                    "delta": choice.get("delta", {}),
                    "finish_reason": choice.get("finish_reason"),
                }
            ],
        }
        yield f"data: {json.dumps(payload)}\n\n"

    yield "data: [DONE]\n\n"