Spaces:
Sleeping
Sleeping
| """ | |
| LLMOpt FastAPI application. | |
| Endpoints: | |
| POST /generate — full pipeline, returns response + metrics | |
| POST /explain — returns routing decision without LLM call | |
| GET /models — list all registered models | |
| GET /health — health check | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import logging | |
| from typing import Optional, Dict | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from pydantic import BaseModel, Field | |
| from llmopt.core import LLMOpt | |
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="LLMOpt",
    version="0.1.0",
    description="Adaptive LLM Inference Optimization Framework",
)

# A single LLMOpt instance is built at import time and reused by every
# handler in this module. Its log level is taken from the LOG_LEVEL
# environment variable, falling back to "WARNING".
# NOTE(review): the handlers are plain `def`s, which FastAPI runs in a
# threadpool, so this relies on LLMOpt being safe for concurrent use —
# confirm in llmopt.core.
_client = LLMOpt(log_level=os.environ.get("LOG_LEVEL", "WARNING"))
| # --------------------------------------------------------------------------- | |
| # Request / Response schemas | |
| # --------------------------------------------------------------------------- | |
class GenerateRequest(BaseModel):
    """Request body shared by the generate and stream_generate endpoints."""

    query: str = Field(..., min_length=1, max_length=32000, description="User query")
    budget_mode: str = Field("balanced", description="cheap | balanced | quality")
    # Optional cap. A negative USD amount is meaningless, so it is rejected
    # at validation time (HTTP 422) rather than forwarded to the router.
    max_cost_per_request: Optional[float] = Field(
        None, ge=0.0, description="Hard cost cap in USD"
    )
    quality_threshold: float = Field(0.60, ge=0.0, le=1.0)
    exclude_providers: list[str] = Field(default_factory=list)
    only_providers: list[str] = Field(default_factory=list)
    prefer_local: bool = Field(False, description="Prefer Ollama local models")
    conversation_history: Optional[list[dict]] = Field(None)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    dry_run: bool = Field(False, description="Skip actual LLM call")
    # Bring Your Own Key (BYOK) support. These are secrets: repr=False keeps
    # them out of repr()/str() of the model, e.g. if a request is ever logged.
    api_keys: Optional[Dict[str, str]] = Field(
        None,
        repr=False,
        description="Optional provider API keys (e.g. {'openai': 'sk-...', 'anthropic': '...' })",
    )
class GenerateResponse(BaseModel):
    # Response body of the generate endpoint: the model's text plus routing,
    # token, cost and timing metrics. generate() builds it via
    # GenerateResponse(**result.to_dict()), so every field below must be
    # supplied by a key of that dict. Field order here is the JSON key order.
    response: str  # text returned by the selected model
    model_used: str  # model chosen by the router
    provider: str  # provider of that model
    input_tokens: int
    output_tokens: int
    total_tokens: int
    estimated_cost: float  # presumably USD like max_cost_per_request — confirm
    tokens_saved: int  # savings attributed to optimization; defined in llmopt.core
    cost_saved: float  # same caveat as tokens_saved
    compression_ratio: float  # NOTE(review): ratio direction defined in llmopt.core — confirm
    complexity_score: float  # query-analysis output
    complexity_tier: str  # query-analysis output (tier label)
    latency_ms: float  # latency in milliseconds, per the field name
class ExplainRequest(BaseModel):
    """Request body for the explain endpoint (routing decision only, no LLM call)."""

    # Same constraints and OpenAPI descriptions as the matching
    # GenerateRequest fields, so both endpoints document their inputs alike.
    query: str = Field(..., min_length=1, max_length=32000, description="User query")
    budget_mode: str = Field("balanced", description="cheap | balanced | quality")
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
@app.get("/")
def root():
    """Return a short banner pointing clients at the docs and health endpoints."""
    return {
        "message": "LLMOpt V2 API is running!",
        "docs": "/docs",
        "health": "/health",
    }
@app.get("/health")
def health():
    """Health check: report status "ok" together with the app's version."""
    # Read the version from the FastAPI app so it cannot drift from the
    # value declared when the app was constructed.
    return {"status": "ok", "version": app.version}
@app.get("/models")
def list_models():
    """List all models in the registry with their specs.

    Returns:
        ``{"models": ...}`` wrapping whatever
        ``_client.registry.summary_table()`` produces.
    """
    return {"models": _client.registry.summary_table()}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """
    Full pipeline: analyze → optimize → route → return response + metrics.

    Every field of ``req`` (including any BYOK ``api_keys``) is forwarded to
    ``LLMOpt.generate``; its result is converted into a ``GenerateResponse``.

    Raises:
        HTTPException: 400 when a ``KeyError`` escapes the pipeline (reported
            as an unknown model); 500 for any other failure, after logging it.
    """
    try:
        result = _client.generate(
            query=req.query,
            budget_mode=req.budget_mode,
            max_cost_per_request=req.max_cost_per_request,
            quality_threshold=req.quality_threshold,
            exclude_providers=req.exclude_providers,
            only_providers=req.only_providers,
            prefer_local=req.prefer_local,
            conversation_history=req.conversation_history,
            temperature=req.temperature,
            dry_run=req.dry_run,
            api_keys=req.api_keys,  # Pass BYOK keys
        )
        return GenerateResponse(**result.to_dict())
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Model not found: {e}") from e
    except Exception as e:
        logger.exception("generate() failed")
        # NOTE(review): str(e) is sent to the client verbatim; provider error
        # text may expose internal details — consider a generic message.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/explain")
def explain(req: ExplainRequest):
    """
    Returns the full routing decision for a query WITHOUT making an LLM API call.
    Useful for debugging, testing, and understanding optimization decisions.

    Raises:
        HTTPException: 500 if ``LLMOpt.explain`` fails; the error is logged and
            its message returned as ``detail``.
    """
    try:
        return _client.explain(query=req.query, budget_mode=req.budget_mode)
    except Exception as e:
        logger.exception("explain() failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible for this handler, so FastAPI does
# not serve it; register it (e.g. with @app.post(...)) once the intended path
# is confirmed.
def stream_generate(req: GenerateRequest):
    """
    Stream the model's response chunk by chunk as a ``text/plain`` body.

    This is plain chunked text, not Server-Sent Events. Only ``query``,
    ``budget_mode`` and ``api_keys`` are forwarded to ``LLMOpt.stream``; the
    other ``GenerateRequest`` fields are not used here.
    """
    def token_generator():
        try:
            for chunk in _client.stream(
                query=req.query,
                budget_mode=req.budget_mode,
                api_keys=req.api_keys,  # Pass BYOK keys
            ):
                yield chunk
        except Exception as e:
            # The 200 status and headers are already sent by the time this
            # generator runs, so an HTTP error status is impossible: log the
            # failure server-side and report it in-band as the last chunk.
            logger.exception("stream() failed")
            yield f"\n[ERROR: {e}]"

    return StreamingResponse(token_generator(), media_type="text/plain")