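"""An OpenAI-compatible chat completions API for a local GGUF model.

Serves the quantized CapybaraHermes 2.5 Mistral 7B model through
llama-cpp-python, exposing /v1/models and /v1/chat/completions (with
optional server-sent-event streaming) behind bearer-token authentication.
"""
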
import os
import uuid
import time
import json
from typing import Optional, List, Union, Literal

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status
from fastapi.concurrency import run_in_threadpool
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from llama_cpp import Llama

# Demo keys for local testing only; a real deployment should load keys from
# the environment or a secret store instead of hardcoding them.
VALID_API_KEYS = {
    "sk-adminkey02",
    "sk-testkey123",
    "sk-userkey456",
    "sk-demokey789",
}

MODEL_PATH = "capybarahermes-2.5-mistral-7b.Q5_K_M.gguf"
MODEL_NAME = "capybarahermes-2.5-mistral-7b"

# The model handle is populated once at startup and shared across requests.
llm = None
security = HTTPBearer()


# Pydantic schemas mirroring the OpenAI chat completions API.
class Message(BaseModel):
    role: Literal["system", "user", "assistant"]
    content: str


class ChatCompletionRequest(BaseModel):
    model: str = MODEL_NAME
    messages: List[Message]
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    n: Optional[int] = 1  # accepted for compatibility; only one choice is generated
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
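
# A minimal request body for this schema looks like:
#   {"model": "capybarahermes-2.5-mistral-7b",
#    "messages": [{"role": "user", "content": "Hello!"}]}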


class ChatCompletionChoice(BaseModel):
    index: int
    message: Message
    finish_reason: Optional[Literal["stop", "length"]] = None


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str = MODEL_NAME
    choices: List[ChatCompletionChoice]
    usage: Usage


class ModelData(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "user"


class ModelsResponse(BaseModel):
    object: str = "list"
    data: List[ModelData]


app = FastAPI(
    title="CapybaraHermes OpenAI-Compatible API",
    description=f"An OpenAI-compatible API for the {MODEL_NAME} model.",
    version="1.0.0",
    docs_url="/v1/docs",
    redoc_url="/v1/redoc",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # permissive for development; restrict in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validates the bearer token against the configured API keys."""
    if credentials.credentials not in VALID_API_KEYS:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing API key",
        )
    return credentials.credentials
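
# Clients authenticate with a standard bearer header, e.g.:
#   Authorization: Bearer sk-testkey123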


@app.on_event("startup")
def load_model():
    """Loads the GGUF model into memory when the server starts."""
    global llm
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")

    print("Loading GGUF model...")
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=4096,      # context window size in tokens
        n_threads=2,     # CPU threads used for inference
        n_batch=512,     # prompt tokens evaluated per batch
        verbose=False,
        use_mlock=True,  # lock model pages in RAM to avoid swapping
        n_gpu_layers=0,  # CPU-only inference
    )
    print("Model loaded successfully!")


def format_messages(messages: List[Message]) -> str:
    """Formats messages into the ChatML prompt format expected by the model."""
    formatted = ""
    for message in messages:
        formatted += f"<|im_start|>{message.role}\n{message.content}<|im_end|>\n"
    # Leave the prompt open on an assistant turn so the model completes it.
    formatted += "<|im_start|>assistant\n"
    return formatted
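
# For example, a single user message "Hi" produces the prompt:
#   <|im_start|>user
#   Hi<|im_end|>
#   <|im_start|>assistant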


def count_tokens_rough(text: str) -> int:
    """A rough whitespace-based approximation of token counting."""
    return len(text.split())


@app.get("/v1/health")
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "model_loaded": llm is not None}


@app.get("/v1/models", response_model=ModelsResponse)
async def list_models(api_key: str = Depends(verify_api_key)):
    """Lists the available models."""
    return ModelsResponse(data=[ModelData(id=MODEL_NAME)])


@app.post("/v1/chat/completions")
async def create_chat_completion(
    request: ChatCompletionRequest,
    api_key: str = Depends(verify_api_key)
):
    """Creates a model response for the given chat conversation."""
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    prompt = format_messages(request.messages)

    # request.stop may be a single string or a list of strings; normalize it
    # to a list before combining it with the ChatML delimiters.
    stop_sequences = ["<|im_end|>", "<|im_start|>"]
    if isinstance(request.stop, str):
        stop_sequences.append(request.stop)
    elif request.stop:
        stop_sequences.extend(request.stop)

    if request.stream:
        # A plain (sync) generator lets Starlette iterate it in a worker
        # thread, so the blocking llama.cpp calls do not stall the event loop.
        def stream_generator():
            completion_id = f"chatcmpl-{uuid.uuid4().hex}"
            created_time = int(time.time())

            stream = llm(
                prompt,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=stop_sequences,
                stream=True,
                echo=False,
            )

            # Re-wrap each raw completion chunk as an OpenAI-style
            # chat.completion.chunk server-sent event.
            for output in stream:
                if output.get('choices'):
                    delta_content = output['choices'][0].get('text', '')
                    chunk = {
                        "id": completion_id,
                        "object": "chat.completion.chunk",
                        "created": created_time,
                        "model": MODEL_NAME,
                        "choices": [{"index": 0, "delta": {"content": delta_content}, "finish_reason": None}],
                    }
                    yield f"data: {json.dumps(chunk)}\n\n"

            # Emit a final chunk carrying the finish_reason, then the
            # [DONE] sentinel that OpenAI clients expect.
            final_chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created_time,
                "model": MODEL_NAME,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
            yield f"data: {json.dumps(final_chunk)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")
    else:
        # Run the blocking llama.cpp call in a worker thread so it does not
        # stall the event loop while tokens are generated.
        response = await run_in_threadpool(
            llm,
            prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=stop_sequences,
            echo=False,
        )

        response_text = response['choices'][0]['text'].strip()

        # Usage numbers are approximations; see count_tokens_rough above.
        prompt_tokens = count_tokens_rough(prompt)
        completion_tokens = count_tokens_rough(response_text)

        return ChatCompletionResponse(
            model=MODEL_NAME,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=Message(role="assistant", content=response_text),
                    finish_reason="stop",
                )
            ],
            usage=Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
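
# Example request against a locally running server, using one of the demo
# keys configured above:
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer sk-testkey123" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "capybarahermes-2.5-mistral-7b",
#          "messages": [{"role": "user", "content": "Hello!"}]}'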