from __future__ import annotations import asyncio import json import logging import os import time import uuid from contextlib import asynccontextmanager from typing import Dict, List, Optional, Union, Any from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from huggingface_hub import hf_hub_download from pydantic import BaseModel, Field, ValidationError from llama_cpp import Llama # ---------- Configuration ---------- DEFAULT_MODEL_NAME = os.getenv("DEFAULT_MODEL_NAME", "bonsai-1.7b") LOCAL_MODEL_DIR = os.getenv("LOCAL_MODEL_DIR", "/data/models") MAX_NEW_TOKENS_DEFAULT = int(os.getenv("MAX_NEW_TOKENS_DEFAULT", "256")) API_KEY = os.getenv("API_KEY", None) HF_TOKEN = os.getenv("HF_TOKEN") # Performance settings N_CTX = int(os.getenv("N_CTX", "4096")) N_THREADS = int(os.getenv("N_THREADS", "4")) N_BATCH = int(os.getenv("N_BATCH", "512")) # ---------- Model Registry ---------- MODEL_REGISTRY: Dict[str, Dict[str, str]] = { "bonsai-1.7b": { "repo_id": "lilyanatia/Bonsai-1.7B-requantized", "filename": "Bonsai-1.7B-IQ1_S.gguf", }, "bonsai-4b": { "repo_id": "lilyanatia/Bonsai-4B-requantized", "filename": "Bonsai-4B-IQ1_S.gguf", }, "bonsai-8b": { "repo_id": "lilyanatia/Bonsai-8B-requantized", "filename": "Bonsai-8B-IQ1_S.gguf", }, } logging.basicConfig(level=logging.INFO) logger = logging.getLogger("uvicorn.error") # ---------- Pydantic Models ---------- class Message(BaseModel): role: str = Field(..., pattern="^(system|user|assistant|tool)$") content: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = None tool_call_id: Optional[str] = None name: Optional[str] = None class ToolFunction(BaseModel): name: str description: Optional[str] = None parameters: Optional[Dict[str, Any]] = None class Tool(BaseModel): type: str = "function" function: ToolFunction class ChatCompletionRequest(BaseModel): messages: List[Message] model: str = Field(default=DEFAULT_MODEL_NAME) max_tokens: int = Field(default=MAX_NEW_TOKENS_DEFAULT, ge=1, le=2048) temperature: float = Field(default=0.7, ge=0.0, le=2.0) top_p: float = Field(default=0.95, gt=0.0, le=1.0) stream: bool = False stop: Optional[Union[str, List[str]]] = None tools: Optional[List[Tool]] = None tool_choice: Optional[Union[str, Dict[str, Any]]] = None response_format: Optional[Dict[str, str]] = None class ChatCompletionResponseChoice(BaseModel): index: int message: Message finish_reason: str class Usage(BaseModel): prompt_tokens: int completion_tokens: int total_tokens: int class ChatCompletionResponse(BaseModel): id: str object: str = "chat.completion" created: int model: str choices: List[ChatCompletionResponseChoice] usage: Usage class ModelInfo(BaseModel): id: str object: str = "model" created: int owned_by: str = "lilyanatia" class ModelListResponse(BaseModel): object: str = "list" data: List[ModelInfo] class ErrorResponse(BaseModel): error: str detail: Optional[str] = None # ---------- Global State ---------- current_model_name: Optional[str] = None llm: Optional[Llama] = None model_load_error: Optional[str] = None MODEL_LOCK = asyncio.Lock() DOWNLOADED_MODELS = set() # ---------- Helper Functions ---------- def _verify_api_key(request: Request) -> None: if API_KEY is None: return auth = request.headers.get("X-API-Key") if not auth or auth != API_KEY: raise HTTPException(status_code=401, detail="Invalid or missing API key") def _download_model(model_name: str) -> str: """Downloads a model if it's not already present.""" if model_name not in MODEL_REGISTRY: raise HTTPException(status_code=400, detail=f"Model '{model_name}' not found in registry.") model_info = MODEL_REGISTRY[model_name] repo_id = model_info["repo_id"] filename = model_info["filename"] os.makedirs(LOCAL_MODEL_DIR, exist_ok=True) local_path = os.path.join(LOCAL_MODEL_DIR, filename) if os.path.exists(local_path): logger.info(f"Model '{model_name}' already downloaded at {local_path}") return local_path logger.info(f"Downloading model '{model_name}' from {repo_id}/{filename}...") try: hf_hub_download( repo_id=repo_id, filename=filename, local_dir=LOCAL_MODEL_DIR, token=HF_TOKEN, ) logger.info(f"Model '{model_name}' downloaded successfully.") return local_path except Exception as e: logger.error(f"Model download failed for '{model_name}': {e}") raise HTTPException(status_code=500, detail=f"Failed to download model: {str(e)}") async def _precache_all_models(): """Downloads all models in the registry at startup.""" logger.info("Pre-caching all models in registry...") download_tasks = [] for model_name in MODEL_REGISTRY.keys(): download_tasks.append(asyncio.to_thread(_download_model, model_name)) results = await asyncio.gather(*download_tasks, return_exceptions=True) for model_name, result in zip(MODEL_REGISTRY.keys(), results): if isinstance(result, Exception): logger.error(f"Failed to pre-cache model '{model_name}': {result}") else: DOWNLOADED_MODELS.add(model_name) logger.info(f"Model '{model_name}' is ready.") logger.info(f"Pre-caching complete. {len(DOWNLOADED_MODELS)}/{len(MODEL_REGISTRY)} models cached.") async def _ensure_model_loaded(model_name: str): """Loads the specified model, downloading it first if necessary.""" global llm, current_model_name, model_load_error async with MODEL_LOCK: if current_model_name == model_name and llm is not None: return if llm is not None: logger.info(f"Unloading previous model '{current_model_name}'...") del llm llm = None current_model_name = None try: model_path = _download_model(model_name) llm = Llama( model_path=model_path, n_ctx=N_CTX, n_threads=N_THREADS, n_batch=N_BATCH, verbose=False, ) current_model_name = model_name logger.info(f"Model '{model_name}' loaded successfully.") except Exception as e: model_load_error = str(e) logger.exception(f"Model loading failed for '{model_name}'") raise HTTPException(status_code=503, detail=f"Model unavailable: {model_load_error}") def _build_chat_prompt(messages: List[Message]) -> List[Dict[str, Any]]: """Convert Pydantic messages to dict format for llama.cpp.""" formatted = [] for msg in messages: msg_dict = {"role": msg.role, "content": msg.content} if msg.tool_calls: msg_dict["tool_calls"] = msg.tool_calls if msg.tool_call_id: msg_dict["tool_call_id"] = msg.tool_call_id if msg.name: msg_dict["name"] = msg.name formatted.append(msg_dict) return formatted def _convert_tools(tools: Optional[List[Tool]]) -> Optional[List[Dict[str, Any]]]: """Convert Pydantic tools to dict format for llama.cpp.""" if not tools: return None return [tool.model_dump() for tool in tools] async def _generate_full( prompt: List[Dict[str, Any]], max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[str, Dict[str, Any]]] = None, response_format: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: if llm is None: raise HTTPException(status_code=503, detail="Model not loaded") kwargs = { "messages": prompt, "max_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "stop": stop_sequences, "stream": False, } if tools: kwargs["tools"] = tools if tool_choice: kwargs["tool_choice"] = tool_choice if response_format: kwargs["response_format"] = response_format result = await asyncio.to_thread(lambda: llm.create_chat_completion(**kwargs)) return result async def _generate_stream( prompt: List[Dict[str, Any]], max_new_tokens: int, temperature: float, top_p: float, stop_sequences: Optional[List[str]] = None, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[str, Dict[str, Any]]] = None, response_format: Optional[Dict[str, str]] = None, ): if llm is None: raise HTTPException(status_code=503, detail="Model not loaded") kwargs = { "messages": prompt, "max_tokens": max_new_tokens, "temperature": temperature, "top_p": top_p, "stop": stop_sequences, "stream": True, } if tools: kwargs["tools"] = tools if tool_choice: kwargs["tool_choice"] = tool_choice if response_format: kwargs["response_format"] = response_format def sync_gen(): for chunk in llm.create_chat_completion(**kwargs): yield chunk for chunk in await asyncio.to_thread(list, sync_gen()): yield chunk await asyncio.sleep(0) # ---------- FastAPI App ---------- @asynccontextmanager async def lifespan(app: FastAPI): try: await _precache_all_models() await _ensure_model_loaded(DEFAULT_MODEL_NAME) logger.info(f"Default model '{DEFAULT_MODEL_NAME}' loaded successfully") except Exception as e: logger.error(f"Startup model load failed: {e}") yield global llm llm = None app = FastAPI( title="Bonsai Multi-Model Inference API", version="3.0.0", description="Lightning-fast inference for Bonsai LLMs with tool calling support.", docs_url="/docs", redoc_url="/redoc", lifespan=lifespan, ) app.add_middleware( CORSMiddleware, allow_origins=os.getenv("ALLOW_ORIGINS", "*").split(","), allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.middleware("http") async def auth_middleware(request: Request, call_next): _verify_api_key(request) response = await call_next(request) return response @app.exception_handler(HTTPException) async def http_exception_handler(request, exc): return JSONResponse( status_code=exc.status_code, content=ErrorResponse(error=exc.detail, detail=str(exc.detail)).model_dump(), ) @app.exception_handler(ValidationError) async def validation_exception_handler(request, exc): return JSONResponse( status_code=422, content=ErrorResponse(error="Validation error", detail=str(exc)).model_dump(), ) @app.exception_handler(Exception) async def generic_exception_handler(request, exc): logger.exception("Unhandled exception") return JSONResponse( status_code=500, content=ErrorResponse(error="Internal server error", detail=str(exc)).model_dump(), ) @app.get("/", summary="Root") def root(): return {"message": "Bonsai Multi-Model API is running", "docs": "/docs"} @app.get("/health", summary="Health check") def health(): loaded = llm is not None return { "status": "ok" if loaded else "degraded", "model_loaded": loaded, "current_model": current_model_name, "cached_models": list(DOWNLOADED_MODELS), "error": model_load_error if model_load_error else None, } @app.get("/v1/models", response_model=ModelListResponse, summary="List available models") def list_models(): models = [] for name in MODEL_REGISTRY.keys(): models.append(ModelInfo(id=name, created=int(time.time()))) return ModelListResponse(data=models) @app.get("/v1/models/{model_name}", response_model=ModelInfo, summary="Get model information") def get_model(model_name: str): if model_name not in MODEL_REGISTRY: raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") return ModelInfo(id=model_name, created=int(time.time())) @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(req: ChatCompletionRequest): model_name = req.model or DEFAULT_MODEL_NAME await _ensure_model_loaded(model_name) prompt = _build_chat_prompt(req.messages) tools = _convert_tools(req.tools) stop_seq = req.stop if isinstance(req.stop, list) else ([req.stop] if req.stop else None) if req.stream: async def stream_generator(): yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': {'role': 'assistant'}, 'finish_reason': None}]})}\n\n" async for chunk in _generate_stream(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq, tools, req.tool_choice, req.response_format): delta = {} if "choices" in chunk and len(chunk["choices"]) > 0: choice = chunk["choices"][0] if "delta" in choice: delta = choice["delta"] yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': delta, 'finish_reason': None}]})}\n\n" await asyncio.sleep(0) yield f"data: {json.dumps({'id': f'chatcmpl-{uuid.uuid4().hex[:12]}', 'object': 'chat.completion.chunk', 'created': int(time.time()), 'model': model_name, 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'stop'}]})}\n\n" yield "data: [DONE]\n\n" return StreamingResponse(stream_generator(), media_type="text/event-stream") else: result = await _generate_full(prompt, req.max_tokens, req.temperature, req.top_p, stop_seq, tools, req.tool_choice, req.response_format) choice = result["choices"][0] message_data = choice.get("message", {}) assistant_msg = Message( role=message_data.get("role", "assistant"), content=message_data.get("content"), tool_calls=message_data.get("tool_calls"), ) finish_reason = choice.get("finish_reason", "stop") usage_data = result.get("usage", {}) usage = Usage( prompt_tokens=usage_data.get("prompt_tokens", 0), completion_tokens=usage_data.get("completion_tokens", 0), total_tokens=usage_data.get("total_tokens", 0), ) return ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex[:12]}", created=int(time.time()), model=model_name, choices=[ChatCompletionResponseChoice(index=0, message=assistant_msg, finish_reason=finish_reason)], usage=usage, ) if __name__ == "__main__": import uvicorn uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)