# Shrot101's picture
# fix: add root route to avoid 404
# 2c126c1
"""
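LLMOpt FastAPI application.

Endpoints:
    GET  /          — service banner with links to /docs and /health
    POST /generate  — full pipeline, returns response + metrics
    POST /explain   — returns routing decision without LLM call
    POST /stream    — streams the response as plain-text chunks
    GET  /models    — list all registered models
    GET  /health    — health check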
"""
from __future__ import annotations
import os
import logging
from typing import Optional, Dict
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from llmopt.core import LLMOpt
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# The FastAPI application object; the route decorators in this module
# register their handlers on it.
app = FastAPI(
    version="0.1.0",
    title="LLMOpt",
    description="Adaptive LLM Inference Optimization Framework",
)

# One LLMOpt client shared by every request handler in this module.
# The original author's note says it is stateless and therefore safe for
# concurrent use — TODO confirm against llmopt.core.
# Its log level is read from the LOG_LEVEL environment variable and falls
# back to "WARNING" when that variable is unset.
_client = LLMOpt(log_level=os.environ.get("LOG_LEVEL", "WARNING"))
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class GenerateRequest(BaseModel):
    """Request body shared by the ``/generate`` and ``/stream`` endpoints.

    Only ``query`` is required; every other field has a default.
    """

    # Prompt text; validated to be between 1 and 32000 characters long.
    query: str = Field(
        ...,
        min_length=1,
        max_length=32000,
        description="User query",
    )

    # Routing/budget knobs. `budget_mode` is a free-form string here; the
    # description lists the expected values but nothing in this model
    # enforces them.
    budget_mode: str = Field(
        default="balanced",
        description="cheap | balanced | quality",
    )
    max_cost_per_request: Optional[float] = Field(
        default=None,
        description="Hard cost cap in USD",
    )
    quality_threshold: float = Field(default=0.60, ge=0.0, le=1.0)

    # Provider selection filters; both default to a fresh empty list.
    exclude_providers: list[str] = Field(default_factory=list)
    only_providers: list[str] = Field(default_factory=list)
    prefer_local: bool = Field(
        default=False,
        description="Prefer Ollama local models",
    )

    # Generation inputs.
    conversation_history: Optional[list[dict]] = Field(default=None)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    dry_run: bool = Field(default=False, description="Skip actual LLM call")

    # Bring Your Own Key (BYOK) support
    api_keys: Optional[Dict[str, str]] = Field(
        default=None,
        description="Optional provider API keys (e.g. {'openai': 'sk-...', 'anthropic': '...' })",
    )
class GenerateResponse(BaseModel):
    """Response body of ``POST /generate``.

    The endpoint builds this model from ``result.to_dict()``, so every field
    declared here (all are required) must be present in that dict.
    """

    # Generated text and where it came from.
    response: str
    model_used: str
    provider: str

    # Token accounting.
    input_tokens: int
    output_tokens: int
    total_tokens: int

    # Cost and savings figures (units are not stated in this module).
    estimated_cost: float
    tokens_saved: int
    cost_saved: float
    compression_ratio: float

    # Query-complexity classification and request latency.
    complexity_score: float
    complexity_tier: str
    latency_ms: float
class ExplainRequest(BaseModel):
    """Request body of ``POST /explain``: a query plus the budget mode to use."""

    # Same length limits as the generation request's query.
    query: str = Field(
        ...,
        min_length=1,
        max_length=32000,
    )
    budget_mode: str = Field(default="balanced")
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/")
def root():
    """Return a small banner payload pointing at the docs and health routes."""
    banner = dict(
        message="LLMOpt V2 API is running!",
        docs="/docs",
        health="/health",
    )
    return banner
@app.get("/health")
def health():
    """Report a fixed ``ok`` status together with the service version string."""
    status_payload = dict(status="ok", version="0.1.0")
    return status_payload
@app.get("/models")
def list_models():
    """Return the model registry's summary table under the ``models`` key."""
    summary = _client.registry.summary_table()
    return {"models": summary}
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """
    Full pipeline: analyze → optimize → route → return response + metrics.

    Forwards every field of the request to ``_client.generate`` and wraps the
    result's ``to_dict()`` output in a ``GenerateResponse``.

    Args:
        req: Validated request body.

    Returns:
        GenerateResponse built from the client result.

    Raises:
        HTTPException: 400 ("Model not found: ...") when a ``KeyError``
            escapes the call or the response construction; 500 for any other
            exception, after the traceback has been logged.
    """
    try:
        result = _client.generate(
            query=req.query,
            budget_mode=req.budget_mode,
            max_cost_per_request=req.max_cost_per_request,
            quality_threshold=req.quality_threshold,
            exclude_providers=req.exclude_providers,
            only_providers=req.only_providers,
            prefer_local=req.prefer_local,
            conversation_history=req.conversation_history,
            temperature=req.temperature,
            dry_run=req.dry_run,
            api_keys=req.api_keys,  # Pass BYOK keys
        )
        return GenerateResponse(**result.to_dict())
    except KeyError as e:
        # Chain explicitly so the original lookup failure is kept as the cause.
        raise HTTPException(status_code=400, detail=f"Model not found: {e}") from e
    except Exception as e:
        logger.exception("generate() failed")
        # NOTE(review): the raw exception text is returned to the client; with
        # caller-supplied API keys in play, confirm upstream errors cannot
        # echo sensitive values.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/explain")
def explain(req: ExplainRequest):
    """
    Returns the full routing decision for a query WITHOUT making an LLM API call.
    Useful for debugging, testing, and understanding optimization decisions.

    Args:
        req: Validated request body; ``query`` and ``budget_mode`` are passed
            to ``_client.explain``.

    Returns:
        Whatever ``_client.explain`` returns, unmodified.

    Raises:
        HTTPException: 500 for any exception raised by the client, after the
            traceback has been logged.
    """
    try:
        return _client.explain(query=req.query, budget_mode=req.budget_mode)
    except Exception as e:
        logger.exception("explain() failed")
        # Chain explicitly so the original failure is kept as the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/stream")
def stream_generate(req: GenerateRequest):
    """
    Stream the response to the client as plain-text chunks (``text/plain``).

    Only ``query``, ``budget_mode`` and ``api_keys`` from the request are
    forwarded to ``_client.stream``; the other request fields are not used by
    this endpoint.

    If the client raises while streaming, the traceback is logged and the
    failure is reported in-band as a final ``[ERROR: ...]`` chunk.

    Args:
        req: Validated request body.

    Returns:
        StreamingResponse wrapping the chunk generator.
    """
    def token_generator():
        try:
            for chunk in _client.stream(
                query=req.query,
                budget_mode=req.budget_mode,
                api_keys=req.api_keys,  # Pass BYOK keys
            ):
                yield chunk
        except Exception as e:
            # Log like the other endpoints do; previously a mid-stream failure
            # was only visible to the client.
            logger.exception("stream() failed")
            yield f"\n[ERROR: {e}]"

    return StreamingResponse(token_generator(), media_type="text/plain")