Spaces:
Sleeping
Sleeping
File size: 4,731 Bytes
"""FastAPI service that wraps the GreenRouting classifier behind the partner-
specific response schema.

Endpoints:
    POST /classify - classify a query and pick a model from the partner registry
    GET  /health   - liveness probe used by the partner edge function

Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
Docker deployment with periodic /health pings keeping the container warm.
"""
from __future__ import annotations
import logging
import os
import time
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from greenrouting.classifier.trained_predictor import TrainedPredictor
from mapper import (
build_reason,
fold_recent_context,
energy_savings_pct,
pick_category,
pick_complexity,
pick_difficulty_int,
rebucket_capabilities,
select_model,
)
from partner_registry import load_registry
# Module-level logger named "router-api"; basicConfig installs the root handler
# and line format the first time it runs in the process.
logger = logging.getLogger("router-api")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

# Directory holding the trained classifier artifacts; override it with the
# CLASSIFIER_ARTIFACT_DIR environment variable.
ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")

# Whether /classify responses carry the `reason` field. Enabled unless
# INCLUDE_REASON is "0" or "false". The value is trimmed and compared
# case-insensitively so that "FALSE" or "false " also switch it off instead of
# silently leaving it on.
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1").strip().lower() not in ("0", "false")
# ASGI application object; the route decorators further down register on it.
app = FastAPI(version="0.1.0", title="GreenRouting Partner Router")

# Fully open CORS policy: every origin, method and header is accepted and
# exposed, credentials are not allowed, and browsers may cache the preflight
# result for 3600 seconds.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
    expose_headers=["*"],
    max_age=3600,
)

# Process-wide caches, filled in by _ensure_loaded() and None until then.
_registry = None
_predictor: Optional[TrainedPredictor] = None
class RecentMessage(BaseModel):
    # One entry of the optional conversation history sent with a /classify
    # request. The handler converts each entry to a dict and hands the list to
    # fold_recent_context together with the new message.

    # Speaker label for the entry; no set of allowed values is enforced here.
    # NOTE(review): presumably "user" / "assistant" - confirm with the partner.
    role: str
    # Text of the entry; no length limit is applied at this level.
    content: str
class ClassifyRequest(BaseModel):
    # Request body accepted by POST /classify.

    # The query to classify; validated to be between 1 and 8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Optional earlier messages. When present and non-empty they are passed to
    # fold_recent_context along with `message`; otherwise None is passed.
    # The camelCase name is the key expected in the request JSON.
    recentMessages: Optional[list[RecentMessage]] = None
class ClassifyResponse(BaseModel):
    # Response body returned by POST /classify.

    # Result of pick_category over the re-bucketed capability weights.
    category: str
    # Result of pick_complexity over the classifier profile.
    complexity: str
    # `id` of the registry model chosen by select_model.
    model_id: str
    # Capability weights as produced by rebucket_capabilities.
    capability_weights: dict[str, float]
    # Result of pick_difficulty_int over the classifier profile.
    difficulty: int
    # Savings for the chosen model rounded to one decimal place; None when the
    # query was flagged out-of-distribution or the selection was escalated.
    energy_savings_pct: Optional[float] = None
    # Identifier of the routing method; the handler always sends "greenrouting".
    method: str
    # Text from build_reason, or None when INCLUDE_REASON is disabled.
    reason: Optional[str] = None
def _ensure_loaded() -> None:
    """Load the classifier and the partner registry on first use.

    Both objects are cached in module globals, so calls made after a
    successful load do nothing. If loading fails, the affected global stays
    ``None`` and the next call tries again.

    Raises:
        RuntimeError: if ``head.pt`` is missing from ``ARTIFACT_DIR``.
        Exception: anything raised by ``TrainedPredictor`` construction, the
            warm-up prediction or ``load_registry`` propagates unchanged.
    """
    global _predictor, _registry
    if _predictor is None:
        artifact_path = Path(ARTIFACT_DIR)
        if not (artifact_path / "head.pt").exists():
            raise RuntimeError(f"trained classifier not found at {artifact_path}")
        # Build and warm the predictor in a local first. Publishing it to the
        # global before the warm-up call would let a warm-up failure leave an
        # unwarmed predictor behind that every later call treats as loaded.
        predictor = TrainedPredictor(artifact_path)
        predictor.predict("warm up")
        _predictor = predictor
        logger.info("classifier loaded and warmed")
    if _registry is None:
        _registry = load_registry()
        logger.info("partner registry loaded with %d models", len(_registry))
@app.on_event("startup")
def _startup() -> None:
    """Try to load the classifier and registry when the application starts.

    A failure here is only logged as a warning, so the application still
    starts; the request handlers call ``_ensure_loaded`` again themselves.
    """
    try:
        _ensure_loaded()
    except Exception as load_error:
        logger.warning(
            "startup warm load failed: %s (will retry on first request)",
            load_error,
        )
        return
@app.get("/health")
def health() -> dict:
    """Report whether the classifier and registry are loaded.

    Calls ``_ensure_loaded``, so a probe also triggers a load attempt when a
    previous one has not succeeded.

    Returns:
        ``{"status": "ok"}`` once loading has succeeded.

    Raises:
        HTTPException: status 503 carrying the failure message when loading
            raises; the traceback is written to the log.
    """
    try:
        _ensure_loaded()
    except Exception as load_error:
        logger.exception("health check failed")
        raise HTTPException(status_code=503, detail=f"unhealthy: {load_error}")
    return {"status": "ok"}
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest) -> ClassifyResponse:
    """Classify a query and choose a model from the partner registry.

    The message is combined with any recent messages by
    ``fold_recent_context``, profiled by the classifier, and the resulting
    capability weights and difficulty are handed to ``select_model``.

    Args:
        req: validated request body.

    Returns:
        The routing decision. ``energy_savings_pct`` is ``None`` when the
        query is out-of-distribution or the selection was escalated, and
        ``reason`` is ``None`` when ``INCLUDE_REASON`` is disabled.

    Raises:
        HTTPException: status 503 when the classifier or registry cannot be
            loaded, matching what ``/health`` reports for the same failure.
    """
    # Translate load failures into 503, as /health does, rather than letting
    # them surface as an unexplained 500.
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("classify aborted: classifier or registry unavailable")
        raise HTTPException(status_code=503, detail=f"unhealthy: {exc}") from exc

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()

    history = [m.dict() for m in req.recentMessages] if req.recentMessages else None
    folded = fold_recent_context(req.message, history)

    profile = _predictor.predict(folded)
    weights = rebucket_capabilities(profile)
    category = pick_category(weights)
    complexity = pick_complexity(profile)
    difficulty = pick_difficulty_int(profile)
    chosen, escalated = select_model(_registry, weights, difficulty, is_ood=profile.is_ood)

    # No savings figure is reported for out-of-distribution or escalated picks.
    savings: Optional[float]
    if profile.is_ood or escalated:
        savings = None
    else:
        savings = round(energy_savings_pct(chosen), 1)

    reason = (
        build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)
        if INCLUDE_REASON
        else None
    )

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    logger.info(
        "classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
        chosen.id, chosen.tier, difficulty, category, profile.is_ood, escalated, elapsed_ms,
    )
    return ClassifyResponse(
        category=category,
        complexity=complexity,
        model_id=chosen.id,
        capability_weights=weights,
        difficulty=difficulty,
        energy_savings_pct=savings,
        method="greenrouting",
        reason=reason,
    )
if __name__ == "__main__":
    # Direct execution: serve "app:app" with uvicorn on every interface, on
    # the port named by PORT (7860 when the variable is unset).
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, log_level="info")
|