# NOTE: Hugging Face Spaces page banner ("Spaces: Sleeping") was extraction
# residue, not part of the source file; removed.
"""FastAPI service that wraps the GreenRouting classifier behind the partner-
specific response schema.

Endpoints:
    POST /classify - classify a query and pick a model from the partner registry
    GET  /health   - liveness probe used by the partner edge function

Auth: none. Stateless. CORS open. Single-process. Designed for a HF Spaces
Docker deployment with periodic /health pings keeping the container warm.
"""
| from __future__ import annotations | |
| import logging | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from greenrouting.classifier.trained_predictor import TrainedPredictor | |
| from mapper import ( | |
| build_reason, | |
| fold_recent_context, | |
| energy_savings_pct, | |
| pick_category, | |
| pick_complexity, | |
| pick_difficulty_int, | |
| rebucket_capabilities, | |
| select_model, | |
| ) | |
| from partner_registry import load_registry | |
# Module-level logger; basicConfig at import time is acceptable here because
# this module is the single-process service entry point.
logger = logging.getLogger("router-api")
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")

# Filesystem location of the trained classifier artifacts; overridable via env.
ARTIFACT_DIR = os.environ.get("CLASSIFIER_ARTIFACT_DIR", "models/classifier_v1")
# Any value other than "0"/"false"/"False" enables the human-readable reason
# string in responses (default: enabled).
INCLUDE_REASON = os.environ.get("INCLUDE_REASON", "1") not in ("0", "false", "False")

app = FastAPI(title="GreenRouting Partner Router", version="0.1.0")
# CORS is deliberately wide open (see module docstring): no auth, and the
# partner edge function may call from arbitrary origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=3600,
)

# Lazily-initialized singletons, populated by _ensure_loaded() at startup or
# on the first request. Single-process deployment, so no locking is needed.
_predictor: Optional[TrainedPredictor] = None
_registry = None
class RecentMessage(BaseModel):
    """One prior chat turn supplied by the caller for context folding."""

    # Speaker role as sent by the partner client (e.g. "user"/"assistant" —
    # assumed; SOURCE does not validate the value).
    role: str
    content: str
class ClassifyRequest(BaseModel):
    """Request body for POST /classify."""

    # The query to classify; length-bounded to keep classifier cost predictable.
    message: str = Field(min_length=1, max_length=8000)
    # Optional recent conversation turns (camelCase to match the partner's
    # JSON schema); forwarded to fold_recent_context as plain dicts.
    recentMessages: Optional[list[RecentMessage]] = None
class ClassifyResponse(BaseModel):
    """Response body for POST /classify (partner-specific schema)."""

    category: str
    complexity: str
    model_id: str
    capability_weights: dict[str, float]
    difficulty: int
    # None when the query is out-of-distribution or the router escalated —
    # a savings estimate would not be meaningful in those cases.
    energy_savings_pct: Optional[float] = None
    # Routing method identifier; always "greenrouting" in this service.
    method: str
    # Human-readable explanation; omitted when INCLUDE_REASON is disabled.
    reason: Optional[str] = None
def _ensure_loaded() -> None:
    """Populate the module-level predictor and registry singletons.

    Idempotent: each singleton is constructed at most once per process.

    Raises:
        RuntimeError: when the trained classifier artifacts are missing.
    """
    global _predictor, _registry

    if _predictor is None:
        artifacts = Path(ARTIFACT_DIR)
        head_weights = artifacts / "head.pt"
        if not head_weights.exists():
            raise RuntimeError(f"trained classifier not found at {artifacts}")
        _predictor = TrainedPredictor(artifacts)
        # One throwaway prediction so the first real request pays no cold cost.
        _predictor.predict("warm up")
        logger.info("classifier loaded and warmed")

    if _registry is None:
        _registry = load_registry()
        logger.info("partner registry loaded with %d models", len(_registry))
@app.on_event("startup")
def _startup() -> None:
    """Warm-load the classifier and registry when the process starts.

    Registered as a FastAPI startup hook (the handler was previously never
    registered, so the documented warm load never ran). Failures are logged
    but swallowed so the container still binds its port; _ensure_loaded()
    is retried lazily on the first incoming request instead.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.warning("startup warm load failed: %s (will retry on first request)", exc)
@app.get("/health")
def health() -> dict:
    """Liveness probe used by the partner edge function (GET /health).

    Registered as a FastAPI route (previously the handler was never
    registered, so the documented endpoint did not exist). Also serves as
    the lazy-load trigger for HF Spaces keep-warm pings.

    Returns:
        {"status": "ok"} once the classifier and registry are loaded.

    Raises:
        HTTPException: 503 with the underlying error when loading fails.
    """
    try:
        _ensure_loaded()
    except Exception as exc:
        logger.exception("health check failed")
        # Chain the cause so the original traceback survives in logs.
        raise HTTPException(status_code=503, detail=f"unhealthy: {exc}") from exc
    return {"status": "ok"}
@app.post("/classify", response_model=ClassifyResponse)
def classify(req: ClassifyRequest) -> ClassifyResponse:
    """Classify ``req.message`` and select a model from the partner registry.

    Registered as a FastAPI route (previously the handler was never
    registered, so the documented POST /classify endpoint did not exist).
    Folds recent conversation turns into the query, runs the trained
    classifier, rebuckets its capability profile into partner categories,
    and selects a registry model by difficulty.

    Returns:
        ClassifyResponse with the chosen model and derived metadata.
        ``energy_savings_pct`` is None when the query is out-of-distribution
        or the selection escalated, since the savings estimate would not be
        meaningful there.
    """
    _ensure_loaded()
    # perf_counter is monotonic — immune to wall-clock adjustments, unlike
    # time.time(), which could previously yield negative/garbage latencies.
    started = time.perf_counter()

    # NOTE(review): .dict() is the pydantic v1 API (deprecated in v2; emits a
    # warning there). Switch to model_dump() once the service pins pydantic v2.
    recent = [m.dict() for m in req.recentMessages] if req.recentMessages else None
    folded = fold_recent_context(req.message, recent)

    profile = _predictor.predict(folded)
    weights = rebucket_capabilities(profile)
    category = pick_category(weights)
    complexity = pick_complexity(profile)
    difficulty = pick_difficulty_int(profile)
    chosen, escalated = select_model(_registry, weights, difficulty, is_ood=profile.is_ood)

    # No savings claim for OOD or escalated picks: the baseline comparison
    # only holds when the router confidently chose a smaller model.
    savings: Optional[float] = None
    if not (profile.is_ood or escalated):
        savings = round(energy_savings_pct(chosen), 1)

    reason = (
        build_reason(weights, complexity, chosen, escalated, is_ood=profile.is_ood)
        if INCLUDE_REASON
        else None
    )

    elapsed_ms = (time.perf_counter() - started) * 1000.0
    logger.info(
        "classify model=%s tier=%s difficulty=%d category=%s ood=%s escalated=%s elapsed_ms=%.1f",
        chosen.id, chosen.tier, difficulty, category, profile.is_ood, escalated, elapsed_ms,
    )
    return ClassifyResponse(
        category=category,
        complexity=complexity,
        model_id=chosen.id,
        capability_weights=weights,
        difficulty=difficulty,
        energy_savings_pct=savings,
        method="greenrouting",
        reason=reason,
    )
if __name__ == "__main__":
    # Local/dev entry point; HF Spaces injects PORT (7860 is the Spaces default).
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        log_level="info",
    )