File size: 5,161 Bytes
3c1db6c
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8ea91
3c1db6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8ea91
 
 
 
 
 
3c1db6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c126c1
 
 
 
 
 
 
 
 
3c1db6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8ea91
3c1db6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd8ea91
3c1db6c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""
LLMOpt FastAPI application.

Endpoints:
  POST /generate       β€” full pipeline, returns response + metrics
  POST /explain        β€” returns routing decision without LLM call
  GET  /models         β€” list all registered models
  GET  /health         β€” health check
"""

from __future__ import annotations

import os
import logging
from typing import Optional, Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from llmopt.core import LLMOpt

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------

# FastAPI application object; every route below is registered on it via the
# @app.get / @app.post decorators.
app = FastAPI(
    title="LLMOpt",
    description="Adaptive LLM Inference Optimization Framework",
    version="0.1.0",
)

# Single shared client reused by every endpoint in this module. Its log level
# is read from the LOG_LEVEL environment variable (default "WARNING").
# NOTE(review): the original comment states the client is stateless and safe
# for concurrent use — that is a property of llmopt.core.LLMOpt and cannot be
# confirmed from this file; verify before relying on it.
_client = LLMOpt(log_level=os.getenv("LOG_LEVEL", "WARNING"))


# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------

class GenerateRequest(BaseModel):
    """Request body accepted by POST /generate and POST /stream.

    Only ``query`` is required; every other field has a default. The
    /generate handler forwards all fields to the shared client, while the
    /stream handler forwards only ``query``, ``budget_mode`` and ``api_keys``.
    """

    query: str = Field(..., min_length=1, max_length=32000, description="User query")
    budget_mode: str = Field("balanced", description="cheap | balanced | quality")
    # A negative cost cap can never be satisfied, so it is rejected at the API
    # boundary, consistent with the range checks on the other numeric fields.
    max_cost_per_request: Optional[float] = Field(
        None, ge=0.0, description="Hard cost cap in USD"
    )
    quality_threshold: float = Field(0.60, ge=0.0, le=1.0)
    exclude_providers: list[str] = Field(default_factory=list)
    only_providers: list[str] = Field(default_factory=list)
    prefer_local: bool = Field(False, description="Prefer Ollama local models")
    # NOTE(review): the schema of each history entry is not defined in this
    # file — confirm the expected keys against llmopt.core.
    conversation_history: Optional[list[dict]] = Field(None)
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    dry_run: bool = Field(False, description="Skip actual LLM call")

    # Bring Your Own Key (BYOK) support
    api_keys: Optional[Dict[str, str]] = Field(
        None,
        description="Optional provider API keys (e.g. {'openai': 'sk-...', 'anthropic': '...' })"
    )


class GenerateResponse(BaseModel):
    """Response body of POST /generate.

    Built in generate() from ``result.to_dict()``; all fields are required, so
    that dict has to supply a value for every name declared here.
    """

    # NOTE(review): units are not stated in this file. The cost fields are
    # presumably USD (GenerateRequest describes its cost cap in USD) and
    # latency_ms presumably milliseconds per its suffix — confirm against
    # llmopt.core.
    response: str
    model_used: str
    provider: str
    input_tokens: int
    output_tokens: int
    total_tokens: int
    estimated_cost: float
    tokens_saved: int
    cost_saved: float
    compression_ratio: float
    complexity_score: float
    complexity_tier: str
    latency_ms: float


class ExplainRequest(BaseModel):
    """Request body of POST /explain.

    Carries the two values that explain() passes to the shared client: the
    query text (1 to 32000 characters) and the budget mode (default
    "balanced").
    """

    query: str = Field(..., min_length=1, max_length=32000)
    budget_mode: str = Field("balanced")


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------

@app.get("/")
def root():
    return {
        "message": "LLMOpt V2 API is running!",
        "docs": "/docs",
        "health": "/health"
    }


@app.get("/health")
def health():
    return {"status": "ok", "version": "0.1.0"}


@app.get("/models")
def list_models():
    """List all models in the registry with their specs."""
    return {"models": _client.registry.summary_table()}


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """
    Full pipeline: analyze β†’ optimize β†’ route β†’ return response + metrics.
    """
    try:
        result = _client.generate(
            query=req.query,
            budget_mode=req.budget_mode,
            max_cost_per_request=req.max_cost_per_request,
            quality_threshold=req.quality_threshold,
            exclude_providers=req.exclude_providers,
            only_providers=req.only_providers,
            prefer_local=req.prefer_local,
            conversation_history=req.conversation_history,
            temperature=req.temperature,
            dry_run=req.dry_run,
            api_keys=req.api_keys,  # Pass BYOK keys
        )
        return GenerateResponse(**result.to_dict())
    except KeyError as e:
        raise HTTPException(status_code=400, detail=f"Model not found: {e}")
    except Exception as e:
        logger.exception("generate() failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/explain")
def explain(req: ExplainRequest):
    """
    Returns the full routing decision for a query WITHOUT making an LLM API call.
    Useful for debugging, testing, and understanding optimization decisions.
    """
    try:
        return _client.explain(query=req.query, budget_mode=req.budget_mode)
    except Exception as e:
        logger.exception("explain() failed")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/stream")
def stream_generate(req: GenerateRequest):
    """Server-sent stream of response tokens."""
    def token_generator():
        try:
            for chunk in _client.stream(
                query=req.query,
                budget_mode=req.budget_mode,
                api_keys=req.api_keys,  # Pass BYOK keys
            ):
                yield chunk
        except Exception as e:
            yield f"\n[ERROR: {e}]"

    return StreamingResponse(token_generator(), media_type="text/plain")