"""FastAPI service that wraps the GreenRouting classifier behind the
partner-specific response schema.

Endpoints:
    POST /classify  - classify a query and pick a model from the partner registry
    GET  /health    - liveness probe used by the partner edge function

Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
Docker deployment, with periodic /health pings keeping the container warm.

Configuration (environment variables):
    CLASSIFIER_ARTIFACT_DIR - directory holding the trained classifier
                              (default ``models/classifier_v1``; must contain
                              ``head.pt``)
    INCLUDE_REASON          - set to 0 / false / no / off (any case) to omit
                              the human-readable ``reason`` from responses
    PORT                    - listen port when this module is run directly
                              (default 7860)
"""
from __future__ import annotations

import logging
import os
import threading
import time
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from greenrouting.classifier.trained_predictor import TrainedPredictor
from mapper import (
    build_reason,
    fold_recent_context,
    energy_savings_pct,
    pick_category,
    pick_complexity,
    pick_difficulty_int,
    rebucket_capabilities,
    select_model,
)
from partner_registry import load_registry

logger = logging.getLogger("router-api")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")
# Case-insensitive opt-out; any other value (including unset) keeps reasons on.
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1").strip().lower() not in (
    "0",
    "false",
    "no",
    "off",
)

app = FastAPI(title="GreenRouting Partner Router", version="0.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)

# Lazily-initialised singletons, populated by _ensure_loaded().
_predictor: Optional[TrainedPredictor] = None
_registry = None
# Serialises loading. The endpoints are plain `def` functions, which FastAPI
# runs in a threadpool, so several requests can reach _ensure_loaded() at once.
_load_lock = threading.Lock()


class RecentMessage(BaseModel):
    """One prior conversation turn supplied as context for classification."""

    role: str
    content: str


class ClassifyRequest(BaseModel):
    """Request body for ``POST /classify``."""

    # The query to classify; 1-8000 characters.
    message: str = Field(min_length=1, max_length=8000)
    # Optional prior turns, folded into the classifier input by
    # fold_recent_context(). camelCase presumably matches the partner's JSON
    # schema -- do not rename without checking the client.
    recentMessages: Optional[list[RecentMessage]] = None


class ClassifyResponse(BaseModel):
    """Response body for ``POST /classify``."""

    category: str
    complexity: str
    # Id of the registry model selected for the query.
    model_id: str
    capability_weights: dict[str, float]
    difficulty: int
    # None when the query is out-of-distribution or the selection escalated.
    energy_savings_pct: Optional[float] = None
    # Always "greenrouting" for this service.
    method: str
    # Human-readable explanation; None when INCLUDE_REASON is disabled.
    reason: Optional[str] = None


def _ensure_loaded() -> None:
    """Load and warm the classifier, then load the partner registry, if needed.

    Safe to call on every request: once both singletons are set it returns
    immediately. Loading happens under ``_load_lock`` so concurrent callers
    never build the model twice. The predictor is published to the module
    global only after its warm-up call succeeds, so a failed warm-up is
    retried on the next call instead of leaving a half-initialised predictor
    cached.

    Raises:
        RuntimeError: if ``head.pt`` is missing from ``ARTIFACT_DIR``.
        Exception: anything raised by ``TrainedPredictor`` construction or
            warm-up, or by ``load_registry``, is propagated unchanged.
    """
    global _predictor, _registry
    # Fast path: nothing to do once everything is loaded.
    if _predictor is not None and _registry is not None:
        return
    with _load_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _predictor is None:
            artifact_path = Path(ARTIFACT_DIR)
            if not (artifact_path / "head.pt").exists():
                raise RuntimeError(f"trained classifier not found at {artifact_path}")
            predictor = TrainedPredictor(artifact_path)
            predictor.predict("warm up")
            _predictor = predictor  # publish only after a successful warm-up
            logger.info("classifier loaded and warmed")
        if _registry is None:
            _registry = load_registry()
            logger.info("partner registry loaded with %d models", len(_registry))


@app.on_event("startup")
def _startup() -> None:
    """Best-effort warm load at process start.

    A failure here is only logged; the endpoints call _ensure_loaded() again,
    so loading is retried on the next request.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.warning("startup warm load failed: %s (will retry on first request)", exc)


@app.get("/health")
def health() -> dict:
    """Liveness probe: return ``{"status": "ok"}`` once everything is loaded.

    Attempts a (re)load if the startup load failed.

    Raises:
        HTTPException: 503 when loading fails.
    """
    try:
        _ensure_loaded()
        return {"status": "ok"}
    except Exception as exc:
        logger.exception("health check failed")
        raise HTTPException(status_code=503, detail=f"unhealthy: {exc}") from exc


@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest) -> ClassifyResponse:
    """Classify ``req.message`` and select a model from the partner registry.

    Recent messages, if given, are folded into the classifier input. The
    capability profile is mapped to category, complexity and an integer
    difficulty, which drive ``select_model``.

    Raises:
        HTTPException: 503 when the classifier or registry cannot be loaded,
            matching what /health reports, instead of an opaque 500.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("classify: model load failed")
        raise HTTPException(status_code=503, detail=f"unavailable: {exc}") from exc
    # Snapshot the loaded singletons for the rest of the handler.
    predictor, registry = _predictor, _registry

    # perf_counter is monotonic, so wall-clock adjustments cannot skew timing.
    started = time.perf_counter()
    # Plain dicts instead of .dict(): identical for this two-field model and
    # free of pydantic v2 deprecation warnings.
    recent = (
        [{"role": m.role, "content": m.content} for m in req.recentMessages]
        if req.recentMessages
        else None
    )
    folded = fold_recent_context(req.message, recent)
    profile = predictor.predict(folded)

    weights = rebucket_capabilities(profile)
    category = pick_category(weights)
    complexity = pick_complexity(profile)
    difficulty = pick_difficulty_int(profile)
    chosen, escalated = select_model(registry, weights, difficulty, is_ood=profile.is_ood)

    # Savings are not reported for out-of-distribution queries or escalations.
    savings: Optional[float]
    if profile.is_ood or escalated:
        savings = None
    else:
        savings = round(energy_savings_pct(chosen), 1)

    reason = (
        build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)
        if INCLUDE_REASON
        else None
    )

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    logger.info(
        "classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
        chosen.id,
        chosen.tier,
        difficulty,
        category,
        profile.is_ood,
        escalated,
        elapsed_ms,
    )

    return ClassifyResponse(
        category=category,
        complexity=complexity,
        model_id=chosen.id,
        capability_weights=weights,
        difficulty=difficulty,
        energy_savings_pct=savings,
        method="greenrouting",
        reason=reason,
    )


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, log_level="info")