Spaces:
Sleeping
Sleeping
File size: 5,161 Bytes
"""
LLMOpt FastAPI application.

Endpoints:
    GET  /          — welcome message with links to /docs and /health
    POST /generate  — full pipeline, returns response + metrics
    POST /explain   — returns routing decision without LLM call
    POST /stream    — streams the generated response as plain text
    GET  /models    — list all registered models
    GET  /health    — health check
"""
from __future__ import annotations
import os
import logging
from typing import Optional, Dict
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from llmopt.core import LLMOpt
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    version="0.1.0",
    title="LLMOpt",
    description="Adaptive LLM Inference Optimization Framework",
)

# One LLMOpt instance is built at import time and shared by every endpoint
# below; its log level comes from $LOG_LEVEL (default "WARNING").
# NOTE(review): the endpoints assume this client is stateless and safe for
# concurrent use — that is not visible from here; confirm in llmopt.core.
_client = LLMOpt(
    log_level=os.getenv("LOG_LEVEL", "WARNING"),
)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class GenerateRequest(BaseModel):
    """Request body shared by ``POST /generate`` and ``POST /stream``."""

    query: str = Field(..., min_length=1, max_length=32000, description="User query")
    budget_mode: str = Field("balanced", description="cheap | balanced | quality")
    # A negative cost cap is meaningless; reject it at validation time (HTTP 422)
    # instead of passing it on to the router. ``None`` still means "no cap".
    max_cost_per_request: Optional[float] = Field(
        None, ge=0.0, description="Hard cost cap in USD"
    )
    quality_threshold: float = Field(0.60, ge=0.0, le=1.0)
    exclude_providers: list[str] = Field(default_factory=list)
    only_providers: list[str] = Field(default_factory=list)
    prefer_local: bool = Field(False, description="Prefer Ollama local models")
    conversation_history: Optional[list[dict]] = Field(None)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    dry_run: bool = Field(False, description="Skip actual LLM call")
    # Bring Your Own Key (BYOK) support.
    # repr=False keeps these secrets out of repr(request), which is what shows up
    # when a request object is logged or appears in a traceback.
    api_keys: Optional[Dict[str, str]] = Field(
        None,
        repr=False,
        description="Optional provider API keys (e.g. {'openai': 'sk-...', 'anthropic': '...' })",
    )
class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``.

    ``generate()`` builds it as ``GenerateResponse(**result.to_dict())``, so the
    LLMOpt result's ``to_dict()`` must provide every field declared below.
    """

    # Generated text.
    response: str
    # Which model / provider the router selected for this request.
    model_used: str
    provider: str
    # Token accounting for the call.
    input_tokens: int
    output_tokens: int
    total_tokens: int
    # NOTE(review): presumably USD, like GenerateRequest.max_cost_per_request — confirm.
    estimated_cost: float
    # Savings attributed to the optimization step; exact definitions live in llmopt.core.
    tokens_saved: int
    cost_saved: float
    compression_ratio: float
    # Query-complexity assessment (presumably what drove routing — confirm in llmopt.core).
    complexity_score: float
    complexity_tier: str
    # Latency in milliseconds, per the field name; what exactly is timed is defined in llmopt.core.
    latency_ms: float
class ExplainRequest(BaseModel):
    """Request body for ``POST /explain`` (routing decision only, no LLM call)."""

    # Same limits and documentation as GenerateRequest so the two schemas agree
    # in the generated OpenAPI docs.
    query: str = Field(..., min_length=1, max_length=32000, description="User query")
    budget_mode: str = Field("balanced", description="cheap | balanced | quality")
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    # Landing payload: confirms the service is up and points at docs / health.
    payload = dict(
        message="LLMOpt V2 API is running!",
        docs="/docs",
        health="/health",
    )
    return payload
@app.get("/health")
def health():
    # Static liveness payload; does not touch the LLMOpt client.
    status = {"status": "ok"}
    status["version"] = "0.1.0"
    return status
@app.get("/models")
def list_models():
    """Return the registry's model summary table under the ``models`` key."""
    table = _client.registry.summary_table()
    return {"models": table}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """
    Full pipeline: analyze → optimize → route → return response + metrics.

    Error mapping:
      * ``KeyError`` anywhere in the pipeline (including ``result.to_dict()``)
        becomes HTTP 400 with ``detail="Model not found: ..."``.
      * Any other exception is logged with its traceback and becomes HTTP 500
        with the exception message as ``detail``.
    """
    try:
        result = _client.generate(
            query=req.query,
            budget_mode=req.budget_mode,
            max_cost_per_request=req.max_cost_per_request,
            quality_threshold=req.quality_threshold,
            exclude_providers=req.exclude_providers,
            only_providers=req.only_providers,
            prefer_local=req.prefer_local,
            conversation_history=req.conversation_history,
            temperature=req.temperature,
            dry_run=req.dry_run,
            api_keys=req.api_keys,  # Pass BYOK keys
        )
        return GenerateResponse(**result.to_dict())
    except KeyError as e:
        # Chain the cause so server-side tracebacks show the original KeyError.
        raise HTTPException(status_code=400, detail=f"Model not found: {e}") from e
    except Exception as e:
        logger.exception("generate() failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/explain")
def explain(req: ExplainRequest):
    """
    Returns the full routing decision for a query WITHOUT making an LLM API call.
    Useful for debugging, testing, and understanding optimization decisions.

    Any exception from the client is logged with its traceback and returned as
    HTTP 500 with the exception message as ``detail``.
    """
    try:
        return _client.explain(query=req.query, budget_mode=req.budget_mode)
    except Exception as e:
        logger.exception("explain() failed")
        # Chain the cause so the original error stays attached to the traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/stream")
def stream_generate(req: GenerateRequest):
    """Stream response chunks back to the client as plain text.

    This is a raw ``text/plain`` chunked stream, not Server-Sent Events
    (no ``text/event-stream`` media type and no ``data:`` framing).

    Only ``query``, ``budget_mode`` and ``api_keys`` are forwarded to the
    client; the remaining ``GenerateRequest`` fields are accepted but ignored.
    TODO(review): forward them if ``LLMOpt.stream`` supports them.

    Failures are reported in-band as a trailing ``[ERROR: ...]`` chunk rather
    than as an HTTP error status.
    """

    def token_generator():
        try:
            for chunk in _client.stream(
                query=req.query,
                budget_mode=req.budget_mode,
                api_keys=req.api_keys,  # Pass BYOK keys
            ):
                yield chunk
        except Exception as e:
            # Log server-side like /generate and /explain do; otherwise the
            # only trace of the failure is the text sent to the client.
            logger.exception("stream() failed")
            yield f"\n[ERROR: {e}]"

    return StreamingResponse(token_generator(), media_type="text/plain")
|