Spaces:
Sleeping
Sleeping
import hmac
import json
import os
import threading
import time
import uuid
from typing import Any, Dict, List, Literal, Optional, Union

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import StreamingResponse
from llama_cpp import Llama
from pydantic import BaseModel, Field
# --- Runtime configuration; every value can be overridden via the environment ---
# Hugging Face repository and GGUF filename handed to Llama.from_pretrained().
MODEL_REPO = os.environ.get("MODEL_REPO", "bartowski/google_gemma-4-E2B-it-GGUF")
MODEL_FILE = os.environ.get("MODEL_FILE", "google_gemma-4-E2B-it-Q4_K_M.gguf")
# Model name reported to clients (health, model listing, completion responses).
MODEL_ID = os.environ.get("MODEL_ID", "gemma-4-e2b-it-gguf")
# Expected bearer token; an empty string disables authentication entirely.
API_KEY = os.environ.get("API_KEY", "")
# llama.cpp loader settings. int() raises ValueError on a non-numeric override.
N_CTX = int(os.environ.get("N_CTX", "4096"))
N_THREADS = int(os.environ.get("N_THREADS", f"{os.cpu_count() or 4}"))
N_GPU_LAYERS = int(os.environ.get("N_GPU_LAYERS", "0"))
# The ASGI application object.
# NOTE(review): none of the handlers in this view carries a route decorator and
# `Depends` is imported but unused; confirm how root/health/query/... are
# registered on `app` and whether require_api_key is attached via Depends.
app = FastAPI(
    version="1.0.0",
    title="GGUF OpenAI-Compatible Model Server",
)

# Process-wide model handle: None until get_llm() loads it; health() reports
# whether it has been populated yet.
llm: Optional[Llama] = None
class ChatMessage(BaseModel):
    """A single conversation turn as accepted by the chat completions handler."""

    # Only these three roles validate; anything else is rejected by pydantic.
    role: Literal["system", "user", "assistant"] = Field(...)
    # Plain-text message body.
    content: str = Field(...)
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-style chat completions handler.

    Fields not declared here (e.g. ``n``, ``presence_penalty``) are dropped
    under pydantic's default model config rather than rejected.
    """

    # Echoed back in non-streaming responses only; it does not select a model,
    # since get_llm() always loads the single configured one.
    model: Optional[str] = Field(default=MODEL_ID)
    # Conversation history, forwarded to llama.cpp in order.
    messages: List[ChatMessage] = Field(...)
    # Generation length and sampling controls, bounded server-side.
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.95, ge=0.0, le=1.0)
    # True selects the server-sent-events response path.
    stream: bool = Field(default=False)
    # Optional stop sequence(s), passed straight through to llama.cpp.
    stop: Optional[Union[str, List[str]]] = Field(default=None)
class QueryRequest(BaseModel):
    """Request body for the simplified single-prompt ``query`` handler."""

    # User prompt; query() pairs it with a fixed system message.
    prompt: str = Field(...)
    # Generation length and sampling temperature, bounded server-side.
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
def require_api_key(authorization: Optional[str] = Header(default=None)) -> None:
    """Reject the request unless it carries the configured bearer token.

    Authentication is disabled when ``API_KEY`` is empty. Otherwise the
    ``Authorization`` header must equal ``"Bearer <API_KEY>"`` exactly.

    NOTE(review): this dependency is not referenced anywhere in this view;
    presumably routes attach it via ``Depends(require_api_key)`` — confirm.

    Args:
        authorization: Raw ``Authorization`` header value, or None if absent.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) when the header
            is missing or does not match.
    """
    if not API_KEY:
        return
    expected = f"Bearer {API_KEY}".encode("utf-8")
    supplied = (authorization or "").encode("utf-8")
    # Constant-time comparison: a plain `!=` can leak, through response timing,
    # how many leading characters matched. Compare bytes because
    # compare_digest raises TypeError on non-ASCII `str` arguments.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
# Serializes the first-time model load so that concurrent requests cannot each
# start their own download/load of the GGUF file.
_llm_lock = threading.Lock()


def get_llm() -> Llama:
    """Return the shared Llama instance, loading it on the first call.

    The model is created with ``Llama.from_pretrained`` using MODEL_REPO /
    MODEL_FILE and the N_CTX, N_THREADS and N_GPU_LAYERS settings, then cached
    in the module-level ``llm``. If the plain ``def`` handlers are registered
    as routes, FastAPI runs them in a thread pool, so several requests can reach
    this point at once (e.g. right after the Space wakes up). The load is
    therefore guarded by a lock with a second check inside it. If loading
    raises, ``llm`` stays None and the next call retries.

    NOTE(review): the returned instance is shared by every caller; llama.cpp
    contexts are generally not safe for concurrent generation, so inference
    may also need serializing — confirm.
    """
    global llm
    if llm is None:
        with _llm_lock:
            # Another thread may have finished loading while we waited.
            if llm is None:
                llm = Llama.from_pretrained(
                    repo_id=MODEL_REPO,
                    filename=MODEL_FILE,
                    n_ctx=N_CTX,
                    n_threads=N_THREADS,
                    n_gpu_layers=N_GPU_LAYERS,
                    verbose=False,
                )
    return llm
def usage_from_response(response: Dict[str, Any]) -> Dict[str, int]:
    """Build an OpenAI-style ``usage`` dict from a completion response.

    A missing or ``None`` ``usage`` entry, and missing or ``None`` counts
    inside it, are all treated as 0. ``total_tokens`` is always recomputed
    as prompt + completion rather than copied from the response.

    Args:
        response: Completion dict, optionally containing a ``usage`` mapping.

    Returns:
        Dict with integer ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens``.
    """
    usage = response.get("usage") or {}
    # `or 0` (instead of a .get default) also covers keys that are present
    # with a None value, which would otherwise make int() raise TypeError.
    prompt_tokens = int(usage.get("prompt_tokens") or 0)
    completion_tokens = int(usage.get("completion_tokens") or 0)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
def root() -> Dict[str, str]:
    """Return a short status message plus the paths of the main API endpoints.

    NOTE(review): no route decorator is visible for this handler in this view;
    presumably it is served at "/" — confirm.
    """
    return dict(
        message="GGUF model server is running",
        chat_completions="/v1/chat/completions",
        models="/v1/models",
        health="/health",
    )
def health() -> Dict[str, Any]:
    """Report liveness, the configured model identity, and whether it is loaded.

    Only inspects the module-level ``llm`` handle, so calling this never
    triggers a model load.
    """
    model_loaded = llm is not None
    return dict(
        status="ok",
        model_id=MODEL_ID,
        model_repo=MODEL_REPO,
        model_file=MODEL_FILE,
        loaded=model_loaded,
    )
def list_models() -> Dict[str, Any]:
    """Return an OpenAI-style model list containing the single served model."""
    model_card: Dict[str, Any] = {
        "id": MODEL_ID,
        "object": "model",
        # No real creation timestamp is tracked, so a fixed 0 is reported.
        "created": 0,
        "owned_by": "huggingface-space",
    }
    return {"object": "list", "data": [model_card]}
def query(req: QueryRequest) -> Dict[str, str]:
    """Answer one prompt, prefixed by a fixed "helpful assistant" system message.

    A simplified alternative to the chat completions handler: only
    ``max_tokens`` and ``temperature`` are forwarded to llama.cpp.

    Returns:
        ``{"answer": <assistant message content>}``.
    """
    llama = get_llm()
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": req.prompt},
    ]
    completion = llama.create_chat_completion(
        messages=conversation,
        max_tokens=req.max_tokens,
        temperature=req.temperature,
    )
    first_choice = completion["choices"][0]
    return {"answer": first_choice["message"]["content"]}
def create_chat_completion(req: ChatCompletionRequest) -> Any:
    """Serve an OpenAI-compatible chat completion, streaming or not.

    Args:
        req: Validated chat completion request.

    Returns:
        A ``StreamingResponse`` of server-sent events when ``req.stream`` is
        true; otherwise an OpenAI ``chat.completion`` dict whose ``model``
        echoes ``req.model`` (falling back to MODEL_ID).

    Raises:
        HTTPException: 400 on the non-streaming path when llama.cpp rejects
            the request with ``ValueError``.
    """
    model = get_llm()
    request_args = {
        "messages": [message.model_dump() for message in req.messages],
        "max_tokens": req.max_tokens,
        "temperature": req.temperature,
        "top_p": req.top_p,
        "stop": req.stop,
        "stream": req.stream,
    }
    if req.stream:
        # Generation runs lazily inside the generator, after the 200 status
        # has been sent, so errors there cannot be mapped to a status code.
        return StreamingResponse(
            stream_chat_completion(model, request_args),
            media_type="text/event-stream",
        )

    try:
        response = model.create_chat_completion(**request_args)
    except ValueError as exc:
        # llama-cpp-python signals request problems (e.g. a prompt longer than
        # the context window, or a chat template rejecting the role sequence)
        # with ValueError. Report these as client errors instead of a bare 500.
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    choice = response["choices"][0]
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model or MODEL_ID,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": choice["message"]["content"],
                },
                # Fall back to "stop" whether the key is absent or explicitly None.
                "finish_reason": choice.get("finish_reason") or "stop",
            }
        ],
        "usage": usage_from_response(response),
    }
def stream_chat_completion(model: Llama, request_args: Dict[str, Any]):
    """Yield an OpenAI-style ``chat.completion.chunk`` server-sent-event stream.

    Args:
        model: Loaded Llama instance to generate with.
        request_args: Keyword arguments for ``create_chat_completion``. The
            dict is copied rather than modified, and ``stream`` is always
            forced to True in the copy.

    Yields:
        SSE ``data:`` lines (each followed by a blank line). There is one per
        llama.cpp chunk, all sharing one completion id and ``created``
        timestamp, followed by a final ``[DONE]`` marker. If llama.cpp raises
        ``ValueError`` (for example a prompt that does not fit in the context
        window), a single ``{"error": {...}}`` event is yielded instead and
        the stream ends without ``[DONE]``.

    Chunks always report MODEL_ID as the model name, whereas the non-streaming
    path echoes the client's ``model`` field.
    """
    # Copy so the caller's dict is left untouched.
    stream_args = {**request_args, "stream": True}
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created = int(time.time())
    try:
        for chunk in model.create_chat_completion(**stream_args):
            choice = chunk["choices"][0]
            payload = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": choice.get("index", 0),
                        "delta": choice.get("delta", {}),
                        "finish_reason": choice.get("finish_reason"),
                    }
                ],
            }
            yield f"data: {json.dumps(payload)}\n\n"
    except ValueError as exc:
        # The 200 status and headers have already been sent, so report the
        # failure in-band as an OpenAI-style error event instead of silently
        # dropping the connection.
        error = {"error": {"message": str(exc), "type": "invalid_request_error"}}
        yield f"data: {json.dumps(error)}\n\n"
        return
    yield "data: [DONE]\n\n"