# GopinathV19's picture
# Deploy v1.1
# c1f4a8f
import os
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from pydantic import BaseModel, Field
# --- Runtime configuration, read once from the environment at import time ---

# Repository and file name handed to the model loader.
MODEL_REPO = os.environ.get("MODEL_REPO", "bartowski/google_gemma-4-E2B-it-GGUF")
MODEL_FILE = os.environ.get("MODEL_FILE", "google_gemma-4-E2B-it-Q4_K_M.gguf")

# Identifier reported to API clients for the served model.
MODEL_ID = os.environ.get("MODEL_ID", "gemma-4-e2b-it-gguf")

# Shared secret for bearer authentication; an empty value disables the check.
API_KEY = os.environ.get("API_KEY", "")

# Loader tuning knobs. Non-numeric values raise ValueError here, at import.
N_CTX = int(os.environ.get("N_CTX", "4096"))
# Falls back to the detected CPU count, or 4 when it cannot be determined.
N_THREADS = int(os.environ.get("N_THREADS", str(os.cpu_count() or 4)))
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "0"))
# ASGI application object that every route below registers itself on.
app = FastAPI(title="GGUF OpenAI-Compatible Model Server", version="1.0.0")

# Process-wide model handle. Stays None until get_llm() loads the model on
# first use; /health reports whether it has been populated.
llm: Optional[Llama] = None
class ChatMessage(BaseModel):
    # One message of a chat conversation, as supplied by the client.

    # Speaker of the message; only these three literal values are accepted.
    role: Literal["system", "user", "assistant"]
    # Text of the message.
    content: str
class ChatCompletionRequest(BaseModel):
    # Request body accepted by POST /v1/chat/completions.

    # Model name echoed back in the response; defaults to the served MODEL_ID.
    model: Union[str, None] = MODEL_ID
    # Conversation forwarded to the model.
    messages: List[ChatMessage]
    # Sampling / length controls, range-checked during validation.
    max_tokens: int = Field(512, ge=1, le=4096)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.95, ge=0.0, le=1.0)
    # When true the endpoint answers with a server-sent event stream.
    stream: bool = False
    # Optional stop sequence(s), passed through to the model unchanged.
    stop: Union[str, List[str], None] = None
class QueryRequest(BaseModel):
    # Request body accepted by POST /query: a single free-form prompt.

    # User text sent to the model as the "user" message.
    prompt: str
    # Generation controls, range-checked during validation.
    max_tokens: int = Field(512, ge=1, le=4096)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
def require_api_key(authorization: Optional[str] = Header(default=None)) -> None:
    """Reject the request unless it carries the configured bearer token.

    Used as a FastAPI dependency on the protected routes. When ``API_KEY`` is
    empty, authentication is disabled and every request is allowed through.

    Args:
        authorization: Value of the ``Authorization`` request header, or
            ``None`` when the header is absent.

    Raises:
        HTTPException: 401 when the header is missing or does not equal
            ``"Bearer <API_KEY>"``.
    """
    # Local import keeps this block self-contained; the file does not import
    # ``secrets`` at module level.
    import secrets

    if not API_KEY:
        return

    expected = f"Bearer {API_KEY}".encode("utf-8")
    supplied = (authorization or "").encode("utf-8")

    # Constant-time comparison: a plain ``!=`` returns as soon as the first
    # differing character is found, which can leak the key through timing.
    # Bytes are compared because compare_digest rejects non-ASCII ``str``.
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            # Tells clients which authentication scheme the server expects.
            headers={"WWW-Authenticate": "Bearer"},
        )
def get_llm() -> Llama:
    """Return the shared model instance, loading it on the first call.

    The loaded instance is stored in the module-level ``llm`` global so later
    calls reuse it instead of loading again.

    Returns:
        The ``Llama`` instance built from the module's configuration constants.
    """
    global llm

    # Fast path: the model has already been loaded by an earlier call.
    if llm is not None:
        return llm

    load_options = dict(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE,
        n_ctx=N_CTX,
        n_threads=N_THREADS,
        n_gpu_layers=N_GPU_LAYERS,
        verbose=False,
    )
    llm = Llama.from_pretrained(**load_options)
    return llm
def usage_from_response(response: Dict[str, Any]) -> Dict[str, int]:
    """Extract token accounting from a completion response.

    Args:
        response: Completion response mapping; its optional ``"usage"`` entry
            may be missing, ``None`` or a mapping of token counts.

    Returns:
        A dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens`` (always the sum of the other two). Counts that are
        missing or ``None`` are reported as ``0``.
    """
    usage = response.get("usage") or {}

    # ``or 0`` covers both an absent key and an explicit ``None`` value;
    # ``int(None)`` would otherwise raise TypeError.
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)

    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
@app.get("/")
def root() -> Dict[str, str]:
    # Landing route: a status message plus the paths of the main endpoints.
    endpoints = dict(
        chat_completions="/v1/chat/completions",
        models="/v1/models",
        health="/health",
    )
    return {"message": "GGUF model server is running", **endpoints}
@app.get("/health")
def health() -> Dict[str, Any]:
    # Health probe: reports the configured model and whether it has been
    # loaded yet. Reading the global does not trigger a load.
    is_loaded = llm is not None
    report: Dict[str, Any] = {"status": "ok"}
    report.update(
        model_id=MODEL_ID,
        model_repo=MODEL_REPO,
        model_file=MODEL_FILE,
        loaded=is_loaded,
    )
    return report
@app.get("/v1/models", dependencies=[Depends(require_api_key)])
def list_models() -> Dict[str, Any]:
    # Model listing in the OpenAI list shape; this server exposes exactly one
    # entry, identified by MODEL_ID.
    model_entry = {
        "id": MODEL_ID,
        "object": "model",
        "created": 0,
        "owned_by": "huggingface-space",
    }
    return {"object": "list", "data": [model_entry]}
@app.post("/query", dependencies=[Depends(require_api_key)])
def query(req: QueryRequest) -> Dict[str, str]:
    # Simple one-shot endpoint: wraps the prompt in a fixed system/user
    # conversation and returns only the text of the first choice.
    engine = get_llm()
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": req.prompt},
    ]
    completion = engine.create_chat_completion(
        messages=conversation,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
    )
    first_choice = completion["choices"][0]
    return {"answer": first_choice["message"]["content"]}
@app.post("/v1/chat/completions", dependencies=[Depends(require_api_key)])
def create_chat_completion(req: ChatCompletionRequest) -> Any:
    # OpenAI-style chat completion endpoint.
    #
    # Returns a StreamingResponse of server-sent events when ``req.stream`` is
    # true, otherwise a single "chat.completion" JSON object.
    model = get_llm()

    # Name echoed back to the client; used by BOTH the streaming and the
    # non-streaming path so the two report the same model.
    model_name = req.model or MODEL_ID

    request_args = {
        "messages": [message.model_dump() for message in req.messages],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "stop": req.stop,
        "stream": req.stream,
    }

    if req.stream:
        return StreamingResponse(
            stream_chat_completion(model, request_args, model_name),
            media_type="text/event-stream",
        )

    response = model.create_chat_completion(**request_args)
    first_choice = response["choices"][0]
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": first_choice["message"]["content"],
                },
                "finish_reason": first_choice.get("finish_reason", "stop"),
            }
        ],
        "usage": usage_from_response(response),
    }


def stream_chat_completion(
    model: Llama,
    request_args: Dict[str, Any],
    model_name: Optional[str] = None,
):
    """Yield a chat completion as server-sent event lines.

    Each chunk produced by the model is re-wrapped as a
    ``chat.completion.chunk`` payload sharing one completion id and creation
    timestamp, and the stream is terminated with ``data: [DONE]``.

    Args:
        model: Loaded model whose ``create_chat_completion`` is called.
        request_args: Keyword arguments for ``create_chat_completion``. The
            dict is not modified; streaming is forced on a copy.
        model_name: Model name to report in every chunk. Defaults to
            ``MODEL_ID`` when omitted or empty, so two-argument callers keep
            working.

    Yields:
        ``str`` lines of the form ``"data: <json>\\n\\n"``.
    """
    import json

    # Copy instead of mutating the caller's dict.
    stream_args = {**request_args, "stream": True}
    reported_model = model_name or MODEL_ID
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())

    for chunk in model.create_chat_completion(**stream_args):
        choice = chunk["choices"][0]
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": reported_model,
            "choices": [
                {
                    "index": choice.get("index", 0),
                    "delta": choice.get("delta", {}),
                    "finish_reason": choice.get("finish_reason"),
                }
            ],
        }
        yield f"data: {json.dumps(payload)}\n\n"

    yield "data: [DONE]\n\n"