# router-api / app.py
# spectralman's picture
# Initial deploy: classifier + FastAPI router
# 6f0ff99 verified
"""FastAPI service that wraps the GreenRouting classifier behind the partner-
specific response schema.
Endpoints:
POST /classify - classify a query and pick a model from the partner registry
GET /health - liveness probe used by the partner edge function
Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
Docker deployment with periodic /health pings keeping the container warm.
"""
from __future__ import annotations
import logging
import os
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from greenrouting.classifier.trained_predictor import TrainedPredictor
from mapper import (
build_reason,
fold_recent_context,
energy_savings_pct,
pick_category,
pick_complexity,
pick_difficulty_int,
rebucket_capabilities,
select_model,
)
from partner_registry import load_registry
# Configure root logging once at import time so every record carries a
# timestamp, level and logger name, then grab this service's named logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
logger = logging.getLogger("router-api")

# Directory containing the trained classifier artifacts (env-overridable).
ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")

# The "reason" field is emitted unless INCLUDE_REASON is exactly one of
# "0", "false" or "False"; any other value (or no value) enables it.
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1") not in ("0", "false", "False")

app = FastAPI(title="GreenRouting Partner Router", version="0.1.0")

# CORS is fully open (any origin/method/header) with credentials disabled;
# preflight responses may be cached by the browser for an hour.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)

# Module-level caches populated lazily by _ensure_loaded().
_predictor: Optional[TrainedPredictor] = None
_registry = None
class RecentMessage(BaseModel):
    """One prior conversation turn supplied alongside the query."""

    # Speaker of the turn; any string is accepted (no enum is enforced here).
    role: str
    # Text of the turn.
    content: str
class ClassifyRequest(BaseModel):
    """Request body accepted by ``POST /classify``."""

    # The query to classify; must be between 1 and 8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Optional earlier turns that are folded into the query before prediction.
    recentMessages: Optional[list[RecentMessage]] = None
class ClassifyResponse(BaseModel):
    """Response body returned by ``POST /classify``."""

    # Category picked from the re-bucketed capability weights.
    category: str
    # Complexity label derived from the classifier profile.
    complexity: str
    # Identifier of the registry model selected for this query.
    model_id: str
    # Re-bucketed capability weights keyed by capability name.
    capability_weights: dict[str, float]
    # Integer difficulty derived from the classifier profile.
    difficulty: int
    # Estimated saving for the chosen model; None when the query was flagged
    # out-of-distribution or the selection was escalated.
    energy_savings_pct: Optional[float] = None
    # Name of the routing method that produced this answer.
    method: str
    # Human-readable explanation; None when INCLUDE_REASON is disabled.
    reason: Optional[str] = None
def _ensure_loaded() -> None:
    """Load the classifier and the partner registry on first use.

    Each resource is loaded only while its module-level cache is still
    ``None``, so calls after a successful load do no further work.

    Raises:
        RuntimeError: if ``head.pt`` is missing from ``ARTIFACT_DIR``.
        Exception: anything raised by ``TrainedPredictor``, its warm-up
            ``predict`` call or ``load_registry`` propagates unchanged.
    """
    global _predictor, _registry
    if _predictor is None:
        artifact_path = Path(ARTIFACT_DIR)
        if not (artifact_path / "head.pt").exists():
            raise RuntimeError(f"trained classifier not found at {artifact_path}")
        # Build and warm a local instance first and publish it only once the
        # warm-up succeeded. Assigning the global before the warm-up would
        # leave a predictor cached after a failed warm-up, so later calls
        # (including /health) would skip the load and report success.
        predictor = TrainedPredictor(artifact_path)
        predictor.predict("warm up")
        _predictor = predictor
        logger.info("classifier loaded and warmed")
    if _registry is None:
        _registry = load_registry()
        logger.info("partner registry loaded with %d models", len(_registry))
@app.on_event("startup")
def _startup() -> None:
    """Attempt a best-effort warm load when the application starts.

    A failure is only logged as a warning so the server still comes up; the
    request handlers call ``_ensure_loaded()`` again on their own.
    """
    # NOTE(review): recent FastAPI releases deprecate ``on_event`` in favour of
    # lifespan handlers - consider migrating when the app setup is revisited.
    try:
        _ensure_loaded()
    except Exception as load_error:
        logger.warning("startup warm load failed: %s (will retry on first request)", load_error)
@app.get("/health")
def health() -> dict:
    """Report whether the classifier and the registry are available.

    Calls ``_ensure_loaded()``, so a probe also retries a load that failed
    earlier.

    Returns:
        ``{"status": "ok"}`` when both resources are loaded.

    Raises:
        HTTPException: status 503 carrying the failure message when loading
            raises.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("health check failed")
        # Chain the cause so the original error stays attached to the 503.
        raise HTTPException(status_code=503, detail=f"unhealthy: {exc}") from exc
    # Kept outside the try so only the load itself is guarded.
    return {"status": "ok"}
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest) -> ClassifyResponse:
    """Classify a query and pick a model from the partner registry.

    The message is folded together with any recent turns, run through the
    trained predictor, and the resulting profile is mapped to a category,
    a complexity label, an integer difficulty and a registry model.

    Args:
        req: Validated request body with the message and optional history.

    Returns:
        The populated ``ClassifyResponse``. ``energy_savings_pct`` is ``None``
        when the profile is out-of-distribution or the selection escalated;
        ``reason`` is ``None`` when ``INCLUDE_REASON`` is disabled.

    Raises:
        HTTPException: status 503 when the classifier or registry cannot be
            loaded.
    """
    # Report a failed lazy load as 503 (as /health does) instead of letting it
    # escape as an unhandled 500.
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("classify rejected: resources not loaded")
        raise HTTPException(status_code=503, detail=f"unavailable: {exc}") from exc

    # perf_counter is monotonic, so the elapsed figure cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    started = time.perf_counter()

    # Build the history dicts explicitly from the two RecentMessage fields;
    # this avoids BaseModel.dict(), which pydantic v2 deprecates. An empty or
    # missing history is passed on as None.
    history = (
        [{"role": m.role, "content": m.content} for m in req.recentMessages]
        if req.recentMessages
        else None
    )
    folded = fold_recent_context(req.message, history)

    profile = _predictor.predict(folded)
    weights = rebucket_capabilities(profile)
    category = pick_category(weights)
    complexity = pick_complexity(profile)
    difficulty = pick_difficulty_int(profile)
    chosen, escalated = select_model(_registry, weights, difficulty, is_ood=profile.is_ood)

    # No savings figure is reported for out-of-distribution or escalated picks.
    savings: Optional[float]
    if profile.is_ood or escalated:
        savings = None
    else:
        savings = round(energy_savings_pct(chosen), 1)

    reason = (
        build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)
        if INCLUDE_REASON
        else None
    )

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    logger.info(
        "classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
        chosen.id, chosen.tier, difficulty, category, profile.is_ood, escalated, elapsed_ms,
    )
    return ClassifyResponse(
        category=category,
        complexity=complexity,
        model_id=chosen.id,
        capability_weights=weights,
        difficulty=difficulty,
        energy_savings_pct=savings,
        method="greenrouting",
        reason=reason,
    )
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, using
    # the PORT environment variable when set and 7860 otherwise.
    import uvicorn

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, log_level="info")