| from __future__ import annotations |
|
|
| import asyncio |
| import json |
| import logging |
| import os |
| import time |
| import uuid |
| from contextlib import asynccontextmanager |
| from typing import Dict, List, Optional, Union, Any |
|
|
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse, StreamingResponse |
| from huggingface_hub import hf_hub_download |
| from pydantic import BaseModel, Field, ValidationError |
| from llama_cpp import Llama |
|
|
| |
# Service-level settings, read once from the environment at import time.
# Each one falls back to the literal default shown when the variable is unset.
DEFAULT_MODEL_NAME = os.environ.get("DEFAULT_MODEL_NAME", "bonsai-1.7b")
LOCAL_MODEL_DIR = os.environ.get("LOCAL_MODEL_DIR", "/data/models")
MAX_NEW_TOKENS_DEFAULT = int(os.environ.get("MAX_NEW_TOKENS_DEFAULT", "256"))
API_KEY = os.environ.get("API_KEY")  # stays None when unset
HF_TOKEN = os.environ.get("HF_TOKEN")  # stays None when unset

# Integer tuning values handed to the Llama constructor.
N_CTX = int(os.environ.get("N_CTX", "4096"))
N_THREADS = int(os.environ.get("N_THREADS", "4"))
N_BATCH = int(os.environ.get("N_BATCH", "512"))
|
|
| |
# Maps each public model id to the Hugging Face repo and GGUF file backing it.
# Rows are (model id, repo_id, filename); insertion order is the listing order.
MODEL_REGISTRY: Dict[str, Dict[str, str]] = {
    model_id: {"repo_id": repo_id, "filename": filename}
    for model_id, repo_id, filename in (
        ("bonsai-1.7b", "lilyanatia/Bonsai-1.7B-requantized", "Bonsai-1.7B-IQ1_S.gguf"),
        ("bonsai-4b", "lilyanatia/Bonsai-4B-requantized", "Bonsai-4B-IQ1_S.gguf"),
        ("bonsai-8b", "lilyanatia/Bonsai-8B-requantized", "Bonsai-8B-IQ1_S.gguf"),
    )
}
|
|
# Configure root logging at INFO, then log through the "uvicorn.error" logger
# for every message this module emits.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("uvicorn.error")
|
|
| |
class Message(BaseModel):
    # One chat message. The same shape is used for incoming request messages
    # and for the assistant message placed in a non-streaming response.

    # Speaker role; the regex admits only system, user, assistant or tool.
    role: str = Field(..., pattern="^(system|user|assistant|tool)$")
    # Message text; optional, None by default.
    content: Optional[str] = None
    # Tool invocations carried by the message, kept as free-form dicts.
    tool_calls: Optional[List[Dict[str, Any]]] = None
    # Optional tool-call identifier; copied into the prompt when set.
    tool_call_id: Optional[str] = None
    # Optional name; copied into the prompt when set.
    name: Optional[str] = None
|
|
class ToolFunction(BaseModel):
    # Description of one callable function offered to the model.

    # Function name (required).
    name: str
    # Optional free-text description.
    description: Optional[str] = None
    # Optional parameter specification, kept as a free-form dict.
    parameters: Optional[Dict[str, Any]] = None
|
|
class Tool(BaseModel):
    # Wrapper pairing a tool type with its function description.

    # Tool kind; defaults to "function".
    type: str = "function"
    # The function this tool exposes.
    function: ToolFunction
|
|
class ChatCompletionRequest(BaseModel):
    # Request body accepted by POST /v1/chat/completions.

    # Conversation so far, oldest message first.
    messages: List[Message]
    # Registry key of the model to run; defaults to DEFAULT_MODEL_NAME.
    model: str = Field(default=DEFAULT_MODEL_NAME)
    # Generation cap, bounded to 1..2048 tokens.
    max_tokens: int = Field(default=MAX_NEW_TOKENS_DEFAULT, ge=1, le=2048)
    # Sampling temperature in [0.0, 2.0].
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    # Nucleus-sampling mass in (0.0, 1.0].
    top_p: float = Field(default=0.95, gt=0.0, le=1.0)
    # When True the endpoint answers with server-sent events.
    stream: bool = False
    # One stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = None
    # Tool definitions offered to the model.
    tools: Optional[List[Tool]] = None
    # Either a mode string or an object selecting a specific tool.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
    # Values are typed Any (not str) so that a nested schema object, e.g.
    # {"type": "json_object", "schema": {...}}, passes validation; plain
    # string-valued payloads remain valid.
    response_format: Optional[Dict[str, Any]] = None
|
|
class ChatCompletionResponseChoice(BaseModel):
    # One generated alternative inside a chat completion response.

    # Position of this choice in the `choices` list.
    index: int
    # The generated assistant message.
    message: Message
    # Reason generation ended, as a string.
    finish_reason: str
|
|
class Usage(BaseModel):
    # Token counts reported for one completion.

    prompt_tokens: int  # tokens in the prompt
    completion_tokens: int  # tokens generated
    total_tokens: int  # total reported alongside the two counts above
|
|
class ChatCompletionResponse(BaseModel):
    # Body returned by the non-streaming chat completion endpoint.

    # Completion identifier.
    id: str
    # Fixed object tag for this payload type.
    object: str = "chat.completion"
    # Creation time as an integer timestamp.
    created: int
    # Name of the model that produced the completion.
    model: str
    # Generated alternatives.
    choices: List[ChatCompletionResponseChoice]
    # Token accounting for the request.
    usage: Usage
|
|
class ModelInfo(BaseModel):
    # Entry describing one model in the /v1/models responses.

    # Model identifier (a MODEL_REGISTRY key in this file's endpoints).
    id: str
    # Fixed object tag for this payload type.
    object: str = "model"
    # Integer timestamp supplied by the caller.
    created: int
    # Owner label; defaults to "lilyanatia".
    owned_by: str = "lilyanatia"
|
|
class ModelListResponse(BaseModel):
    # Envelope returned by GET /v1/models.

    # Fixed object tag for this payload type.
    object: str = "list"
    # One entry per listed model.
    data: List[ModelInfo]
|
|
class ErrorResponse(BaseModel):
    # JSON body produced by this module's exception handlers.

    # Short error message.
    error: str
    # Optional longer explanation.
    detail: Optional[str] = None
|
|
| |
# Mutable process-wide state shared by the functions and handlers in this module.
current_model_name: Optional[str] = None  # registry key of the model held in `llm`
llm: Optional[Llama] = None  # the loaded Llama instance, or None when nothing is loaded
model_load_error: Optional[str] = None  # text of a recorded model-load failure, if any
MODEL_LOCK = asyncio.Lock()  # taken by _ensure_model_loaded while it switches models
DOWNLOADED_MODELS = set()  # model names added by _precache_all_models on success
|
|
| |
def _verify_api_key(request: Request) -> None:
    """Reject the request unless it carries the configured API key.

    The check is skipped entirely when ``API_KEY`` is ``None``. Otherwise the
    key is read from the ``X-API-Key`` header or, when that header is absent
    or empty, from an ``Authorization`` header using the ``Bearer`` scheme.

    Args:
        request: Incoming request whose headers are inspected.

    Raises:
        HTTPException: 401 when no key is supplied or it does not match.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if API_KEY is None:
        return

    supplied = request.headers.get("X-API-Key")
    if not supplied:
        # Fall back to "Authorization: Bearer KEY".
        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        if scheme.lower() == "bearer":
            supplied = credentials.strip()

    # compare_digest does not short-circuit on the first differing byte, so the
    # response time does not reveal how much of a guessed key was correct.
    # Both sides are encoded because compare_digest rejects non-ASCII str.
    if not supplied or not hmac.compare_digest(
        supplied.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
|
|
def _download_model(model_name: str) -> str:
    """Return the local file path for a registry model, fetching it when absent.

    Args:
        model_name: Key into ``MODEL_REGISTRY``.

    Returns:
        ``LOCAL_MODEL_DIR`` joined with the model's registry filename.

    Raises:
        HTTPException: 400 when ``model_name`` is not a registry key, 500 when
            the download raises.
    """
    entry = MODEL_REGISTRY.get(model_name)
    if entry is None:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found in registry.")

    repo_id, filename = entry["repo_id"], entry["filename"]
    os.makedirs(LOCAL_MODEL_DIR, exist_ok=True)
    local_path = os.path.join(LOCAL_MODEL_DIR, filename)

    if not os.path.exists(local_path):
        logger.info(f"Downloading model '{model_name}' from {repo_id}/{filename}...")
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=LOCAL_MODEL_DIR,
                token=HF_TOKEN,
            )
        except Exception as e:
            logger.error(f"Model download failed for '{model_name}': {e}")
            raise HTTPException(status_code=500, detail=f"Failed to download model: {str(e)}")
        logger.info(f"Model '{model_name}' downloaded successfully.")
        return local_path

    # The file is already on disk, so no network access is needed.
    logger.info(f"Model '{model_name}' already downloaded at {local_path}")
    return local_path
|
|
async def _precache_all_models():
    """Fetch every registry model concurrently and record which ones succeeded.

    Each download runs in a worker thread. Failures are logged and do not
    propagate; names of successful downloads are added to ``DOWNLOADED_MODELS``.
    """
    logger.info("Pre-caching all models in registry...")
    names = list(MODEL_REGISTRY)
    outcomes = await asyncio.gather(
        *(asyncio.to_thread(_download_model, name) for name in names),
        return_exceptions=True,  # collect failures instead of aborting the batch
    )

    for name, outcome in zip(names, outcomes):
        if not isinstance(outcome, Exception):
            DOWNLOADED_MODELS.add(name)
            logger.info(f"Model '{name}' is ready.")
            continue
        logger.error(f"Failed to pre-cache model '{name}': {outcome}")

    logger.info(f"Pre-caching complete. {len(DOWNLOADED_MODELS)}/{len(MODEL_REGISTRY)} models cached.")
|
|
async def _ensure_model_loaded(model_name: str):
    """Make ``model_name`` the active model, downloading and loading it if needed.

    Returns immediately when that model is already loaded. Otherwise the model
    file is resolved first, the previous model is released, and the new one is
    loaded. Blocking work runs in worker threads so the event loop keeps
    serving other requests while ``MODEL_LOCK`` is held.

    Args:
        model_name: Key into ``MODEL_REGISTRY``.

    Raises:
        HTTPException: 400 when ``model_name`` is not a registry key; 503 when
            the download or the load fails.
    """
    global llm, current_model_name, model_load_error

    # Reject unknown names before touching the lock or the loaded model, so a
    # bad request cannot evict a working model or be reported as a 503.
    if model_name not in MODEL_REGISTRY:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found in registry.")

    async with MODEL_LOCK:
        if current_model_name == model_name and llm is not None:
            return

        try:
            # Resolve the file while the previous model is still loaded: a
            # failed download then leaves the service in its prior state.
            model_path = await asyncio.to_thread(_download_model, model_name)

            if llm is not None:
                logger.info(f"Unloading previous model '{current_model_name}'...")
                # Drop the reference before loading so both models are not
                # kept referenced at the same time.
                llm = None
                current_model_name = None

            llm = await asyncio.to_thread(
                Llama,
                model_path=model_path,
                n_ctx=N_CTX,
                n_threads=N_THREADS,
                n_batch=N_BATCH,
                verbose=False,
            )
            current_model_name = model_name
            model_load_error = None  # a successful load clears any earlier failure
            logger.info(f"Model '{model_name}' loaded successfully.")
        except Exception as e:
            model_load_error = str(e)
            logger.exception(f"Model loading failed for '{model_name}'")
            raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}") from e
|
|
def _build_chat_prompt(messages: List[Message]) -> List[Dict[str, Any]]:
    """Turn request messages into the plain-dict form passed to llama.cpp.

    ``role`` and ``content`` are always present; ``tool_calls``,
    ``tool_call_id`` and ``name`` are included only when truthy.
    """
    optional_fields = ("tool_calls", "tool_call_id", "name")
    prompt = []
    for message in messages:
        entry = {"role": message.role, "content": message.content}
        for field_name in optional_fields:
            value = getattr(message, field_name)
            if value:
                entry[field_name] = value
        prompt.append(entry)
    return prompt
|
|
def _convert_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]:
    """Dump each tool model to a plain dict; ``None`` or an empty list yields ``None``."""
    return [spec.model_dump() for spec in tools] if tools else None
|
|
async def _generate_full(
    prompt: List[Dict[str, Any]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[List[str]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
    response_format: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Run one non-streaming chat completion on the loaded model.

    The blocking ``create_chat_completion`` call is executed in a worker
    thread and its result dict is returned unchanged.

    Raises:
        HTTPException: 503 when no model is loaded.
    """
    if llm is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    call_args = dict(
        messages=prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        stop=stop_sequences,
        stream=False,
    )
    # The three optional arguments are forwarded only when truthy.
    extras = {"tools": tools, "tool_choice": tool_choice, "response_format": response_format}
    call_args.update({key: value for key, value in extras.items() if value})

    def _run_completion():
        return llm.create_chat_completion(**call_args)

    return await asyncio.to_thread(_run_completion)
|
|
async def _generate_stream(
    prompt: List[Dict[str, Any]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    stop_sequences: Optional[List[str]] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
    response_format: Optional[Dict[str, str]] = None,
):
    """Yield chat-completion chunks from the loaded model as they are produced.

    Each chunk is pulled from llama.cpp's blocking iterator in a worker thread
    and yielded immediately, so the consumer receives output incrementally
    instead of after the whole generation has finished.

    Raises:
        HTTPException: 503 when no model is loaded (raised on first iteration).
    """
    # Pin the instance for the lifetime of this stream: rebinding the global
    # `llm` elsewhere cannot swap the model out from under it mid-generation.
    model = llm
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    kwargs = {
        "messages": prompt,
        "max_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stop": stop_sequences,
        "stream": True,
    }
    if tools:
        kwargs["tools"] = tools
    if tool_choice:
        kwargs["tool_choice"] = tool_choice
    if response_format:
        kwargs["response_format"] = response_format

    # Creating the iterator may already do blocking work, so it runs off-loop too.
    chunk_iter = await asyncio.to_thread(lambda: iter(model.create_chat_completion(**kwargs)))

    # A sentinel default is used because StopIteration cannot be propagated
    # out of a worker thread through an awaited future.
    exhausted = object()
    while True:
        chunk = await asyncio.to_thread(next, chunk_iter, exhausted)
        if chunk is exhausted:
            break
        yield chunk
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: warm the model cache on startup, drop the model on shutdown.

    Startup failures are logged and swallowed so the app still starts.
    """
    global llm

    try:
        await _precache_all_models()
        await _ensure_model_loaded(DEFAULT_MODEL_NAME)
    except Exception as e:
        logger.error(f"Startup model load failed: {e}")
    else:
        logger.info(f"Default model '{DEFAULT_MODEL_NAME}' loaded successfully")

    yield

    # Shutdown: release the reference to the loaded model.
    llm = None
|
|
# The ASGI application. Interactive docs are served at /docs and /redoc, and
# `lifespan` supplies the startup/shutdown behaviour.
app = FastAPI(
    title="Bonsai Multi-Model Inference API",
    version="3.0.0",
    description="Lightning-fast inference for Bonsai LLMs with tool calling support.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)
|
|
# ALLOW_ORIGINS is a comma-separated list. Entries are stripped and blanks are
# dropped, so a value such as "https://a.example, https://b.example" matches
# both origins rather than silently ignoring the one with a leading space.
_cors_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allowed_origins,
    # NOTE(review): credentials are allowed even with the default wildcard
    # origin; confirm this permissive combination is intended.
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Enforce the API-key check on every request before it reaches a route.

    An authentication failure is converted into a JSON error response here,
    because an exception raised inside an HTTP middleware is not routed
    through the application's ``HTTPException`` handler and would otherwise
    surface as a 500 instead of the intended 401.
    """
    # CORS preflight requests cannot carry custom headers, so they are passed
    # through untouched and left to the CORS middleware.
    if request.method != "OPTIONS":
        try:
            _verify_api_key(request)
        except HTTPException as exc:
            message = str(exc.detail)
            return JSONResponse(
                status_code=exc.status_code,
                content=ErrorResponse(error=message, detail=message).model_dump(),
            )
    return await call_next(request)
|
|
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an ``HTTPException`` as an ``ErrorResponse`` JSON body.

    The exception's status code is preserved, and any headers attached to it
    are forwarded on the response.
    """
    # `detail` may be a non-string object; ErrorResponse fields are strings,
    # so coerce once instead of failing validation inside the error handler.
    message = str(exc.detail)
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(error=message, detail=message).model_dump(),
        headers=getattr(exc, "headers", None),
    )
|
|
@app.exception_handler(ValidationError)
async def validation_exception_handler(request, exc):
    """Answer a pydantic ``ValidationError`` with a 422 ``ErrorResponse`` body.

    NOTE(review): FastAPI reports request-body validation failures as
    ``RequestValidationError``, which this handler is not registered for.
    Confirm whether those were meant to use this format too.
    """
    body = ErrorResponse(error="Validation error", detail=str(exc))
    return JSONResponse(status_code=422, content=body.model_dump())
|
|
@app.exception_handler(Exception)
async def generic_exception_handler(request, exc):
    """Log the exception with its traceback and answer with a 500 ``ErrorResponse``.

    The exception text is included in the ``detail`` field of the body.
    """
    logger.exception("Unhandled exception")
    body = ErrorResponse(error="Internal server error", detail=str(exc))
    return JSONResponse(status_code=500, content=body.model_dump())
|
|
@app.get("/", summary="Root")
def root():
    # Banner pointing at the interactive docs. Documented with a comment
    # because FastAPI would publish a docstring as the endpoint description.
    banner = dict(message="Bonsai Multi-Model API is running", docs="/docs")
    return banner
|
|
@app.get("/health", summary="Health check")
def health():
    # Reports whether a model is loaded, which one, which models were
    # pre-cached, and the recorded load error (None when empty or unset).
    # Documented with a comment because FastAPI would publish a docstring
    # as the endpoint description.
    is_loaded = llm is not None
    status = "ok" if is_loaded else "degraded"
    return dict(
        status=status,
        model_loaded=is_loaded,
        current_model=current_model_name,
        cached_models=[*DOWNLOADED_MODELS],
        error=model_load_error or None,
    )
|
|
@app.get("/v1/models", response_model=ModelListResponse, summary="List available models")
def list_models():
    # One ModelInfo per registry key, each stamped with the current time.
    # Documented with a comment because FastAPI would publish a docstring
    # as the endpoint description.
    entries = [
        ModelInfo(id=model_id, created=int(time.time()))
        for model_id in MODEL_REGISTRY
    ]
    return ModelListResponse(data=entries)
|
|
@app.get("/v1/models/{model_name}", response_model=ModelInfo, summary="Get model information")
def get_model(model_name: str):
    # Returns the ModelInfo for a registry key, or 404 for an unknown name.
    # Documented with a comment because FastAPI would publish a docstring
    # as the endpoint description.
    if model_name in MODEL_REGISTRY:
        return ModelInfo(id=model_name, created=int(time.time()))
    raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
|
def _sse_chunk(
    completion_id: str,
    created: int,
    model_name: str,
    delta: Dict[str, Any],
    finish_reason: Optional[str],
) -> str:
    """Serialize one ``chat.completion.chunk`` payload as a server-sent event."""
    payload = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model_name,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }
    return f"data: {json.dumps(payload)}\n\n"


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(req: ChatCompletionRequest):
    """Run a chat completion, either as one JSON response or as an SSE stream.

    With ``stream`` set, the response is a ``text/event-stream`` of
    ``chat.completion.chunk`` events that all share one id and timestamp,
    ending with ``data: [DONE]``. Otherwise a ``ChatCompletionResponse`` is
    returned.

    Raises:
        HTTPException: 400 when the requested model is not in the registry,
            plus whatever the model-loading and generation helpers raise.
    """
    model_name = req.model or DEFAULT_MODEL_NAME
    # Fail fast on an unknown model, before any load/unload work is attempted.
    if model_name not in MODEL_REGISTRY:
        raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found in registry.")
    await _ensure_model_loaded(model_name)
    # NOTE(review): generation reads the module-level `llm`; a concurrent
    # request for a different model could switch it after the call above
    # returns. Confirm whether overlapping requests for different models
    # need to be serialized.

    prompt = _build_chat_prompt(req.messages)
    tools = _convert_tools(req.tools)
    if isinstance(req.stop, list):
        stop_seq = req.stop
    else:
        stop_seq = [req.stop] if req.stop else None

    completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created = int(time.time())

    if req.stream:
        async def stream_generator():
            # Opening event announces the assistant role.
            yield _sse_chunk(completion_id, created, model_name, {"role": "assistant"}, None)

            finish_seen = False
            async for chunk in _generate_stream(
                prompt, req.max_tokens, req.temperature, req.top_p,
                stop_seq, tools, req.tool_choice, req.response_format,
            ):
                choices = chunk.get("choices") or [{}]
                choice = choices[0]
                delta = choice.get("delta") or {}
                # Forward the model's own finish_reason so it is not lost.
                finish_reason = choice.get("finish_reason")
                if finish_reason is not None:
                    finish_seen = True
                yield _sse_chunk(completion_id, created, model_name, delta, finish_reason)

            # Synthesize a closing event only if the model never reported one.
            if not finish_seen:
                yield _sse_chunk(completion_id, created, model_name, {}, "stop")
            yield "data: [DONE]\n\n"

        return StreamingResponse(stream_generator(), media_type="text/event-stream")

    result = await _generate_full(
        prompt, req.max_tokens, req.temperature, req.top_p,
        stop_seq, tools, req.tool_choice, req.response_format,
    )
    choice = result["choices"][0]
    # `or` fallbacks also cover keys that are present but None, which would
    # otherwise fail validation of the response models.
    message_data = choice.get("message") or {}
    assistant_msg = Message(
        role=message_data.get("role") or "assistant",
        content=message_data.get("content"),
        tool_calls=message_data.get("tool_calls"),
    )
    finish_reason = choice.get("finish_reason") or "stop"
    usage_data = result.get("usage") or {}
    usage = Usage(
        prompt_tokens=usage_data.get("prompt_tokens") or 0,
        completion_tokens=usage_data.get("completion_tokens") or 0,
        total_tokens=usage_data.get("total_tokens") or 0,
    )
    return ChatCompletionResponse(
        id=completion_id,
        created=created,
        model=model_name,
        choices=[ChatCompletionResponseChoice(index=0, message=assistant_msg, finish_reason=finish_reason)],
        usage=usage,
    )
|
|
if __name__ == "__main__":
    # Direct execution: serve the application with uvicorn on all interfaces, port 7860.
    import uvicorn
    # NOTE(review): "app:app" is an import string and assumes this file is
    # importable as module `app`; confirm against the actual filename.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)