| """Churn prediction API — serves the registered champion pipeline from the MLflow Model Registry.""" |
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import sqlite3 |
| import time |
| from contextlib import asynccontextmanager |
| from datetime import datetime |
| from typing import Any, Dict, List, Literal, Optional |
|
|
| import mlflow |
| import mlflow.sklearn |
| import numpy as np |
| import pandas as pd |
| from fastapi import FastAPI, HTTPException, Query, Request, Response |
| from mlflow.tracking import MlflowClient |
| from prometheus_fastapi_instrumentator import Instrumentator |
| from pydantic import BaseModel, Field |
|
|
| from churn.config import settings |
| from churn.data import ALL_FEATURES |
|
|
logger = logging.getLogger(__name__)

# Registry coordinates of the model this service serves.
CHAMPION_MODEL_NAME = "customer-churn-xgboost"
CHAMPION_ALIAS = "champion"
# Local JSON file consulted for the threshold when the registry tag is absent.
_THRESHOLD_FALLBACK_PATH = "reports/threshold.json"

# Runtime configuration, overridable through environment variables.
LOG_DB_PATH = os.environ.get("LOG_DB_PATH", "logs/predictions.db")
PROMETHEUS_ENABLED = os.environ.get("PROMETHEUS_ENABLED", "true").lower() in (
    "1",
    "true",
    "yes",
)
|
|
|
|
| |
| |
| |
|
|
|
|
def load_champion_model(tracking_uri: Optional[str] = None) -> tuple[Any, float, str]:
    """Load the champion pipeline, its threshold, and its version from MLflow.

    Args:
        tracking_uri: MLflow tracking URI to use. When falsy, the value from
            ``settings.mlflow_tracking_uri`` is used instead.

    Returns:
        A ``(model, threshold, version)`` tuple, where ``threshold`` is always
        a ``float`` and ``version`` is the registered version as a string.

    Raises:
        Any exception from the registry lookup, the fallback file read, the
        float conversion, or model loading. Nothing is caught here; the
        lifespan caller is responsible for error handling.
    """
    uri = tracking_uri or settings.mlflow_tracking_uri
    mlflow.set_tracking_uri(uri)
    client = MlflowClient()

    mv = client.get_model_version_by_alias(CHAMPION_MODEL_NAME, CHAMPION_ALIAS)
    version = str(mv.version)

    tag_val = mv.tags.get("threshold")
    if tag_val is not None:
        threshold = float(tag_val)
    else:
        # Fall back to the threshold file when the version carries no tag.
        # Coerce to float so both branches return the same type and a
        # string-valued JSON entry cannot break downstream formatting.
        with open(_THRESHOLD_FALLBACK_PATH, encoding="utf-8") as f:
            threshold = float(json.load(f)["threshold"])
        logger.warning(
            "threshold tag absent from %s v%s; using fallback %s",
            CHAMPION_MODEL_NAME, version, _THRESHOLD_FALLBACK_PATH,
        )

    # NOTE(review): the model is loaded through the alias URI, not the version
    # resolved above; if the alias moves between the two calls they can differ.
    model_uri = f"models:/{CHAMPION_MODEL_NAME}@{CHAMPION_ALIAS}"
    model = mlflow.sklearn.load_model(model_uri)
    # Use the module logger (not print) so this line follows logging config.
    logger.info(
        "Loaded %s v%s (threshold=%.4f) from %s",
        CHAMPION_MODEL_NAME, version, threshold, model_uri,
    )
    return model, threshold, version
|
|
|
|
| |
| |
| |
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``app.state`` with the champion model before serving requests.

    A load failure is logged as a warning and leaves ``model``, ``threshold``
    and ``model_version`` set to ``None`` so the app still starts.
    """
    try:
        model, threshold, version = load_champion_model()
    except Exception as exc:
        logger.warning("Failed to load champion model at startup: %s", exc)
        model = threshold = version = None

    app.state.model = model
    app.state.threshold = threshold
    app.state.model_version = version
    yield
|
|
|
|
app = FastAPI(
    title="Customer Churn Prediction API",
    lifespan=lifespan,
)

# Prometheus instrumentation is attached only when PROMETHEUS_ENABLED is true.
if PROMETHEUS_ENABLED:
    _instrumentator = Instrumentator()
    _instrumentator.instrument(app)
    _instrumentator.expose(app)
|
|
|
|
| |
| |
| |
|
|
|
|
# Categorical vocabularies shared by several request fields.
_YesNo = Literal["No", "Yes"]
_PhoneAddOn = Literal["No", "No phone service", "Yes"]
_InternetAddOn = Literal["No", "No internet service", "Yes"]


class PredictRequest(BaseModel):
    """Raw customer attributes accepted by the /predict and /explain endpoints."""

    # Numeric features (all must be non-negative).
    tenure: float = Field(ge=0.0, description="Months with the company")
    MonthlyCharges: float = Field(ge=0.0, description="Current monthly bill")
    TotalCharges: float = Field(ge=0.0, description="Total charges to date")

    # Binary flag supplied as an integer.
    SeniorCitizen: Literal[0, 1]

    # Categorical features, restricted to a fixed set of string values.
    gender: Literal["Female", "Male"]
    Partner: _YesNo
    Dependents: _YesNo
    PhoneService: _YesNo
    MultipleLines: _PhoneAddOn
    InternetService: Literal["DSL", "Fiber optic", "No"]
    OnlineSecurity: _InternetAddOn
    OnlineBackup: _InternetAddOn
    DeviceProtection: _InternetAddOn
    TechSupport: _InternetAddOn
    StreamingTV: _InternetAddOn
    StreamingMovies: _InternetAddOn
    Contract: Literal["Month-to-month", "One year", "Two year"]
    PaperlessBilling: _YesNo
    PaymentMethod: Literal[
        "Bank transfer (automatic)",
        "Credit card (automatic)",
        "Electronic check",
        "Mailed check",
    ]
|
|
|
|
class PredictResponse(BaseModel):
    """Response body of the /predict endpoint."""

    # Positive-class probability, clipped to [0, 1] by the endpoint.
    churn_probability: float
    # True when churn_probability >= threshold.
    churn_prediction: bool
    # Decision threshold read from application state.
    threshold: float
    # Version string of the model held in application state.
    model_version: str
|
|
|
|
class ExplainResponse(BaseModel):
    """Response body of the /explain endpoint."""

    # Values taken from the loaded model and application state.
    churn_probability: float
    threshold: float
    model_version: str

    # Explanation content (from the explanation layer, or the fallback values).
    risk_level: str
    summary: str
    key_factors: List[str]
    recommended_action: str
    citations: List[str]

    # Explanation metadata; provider is "fallback" when none was reported.
    provider: str
    ungrounded_factors: List[str]
|
|
|
|
| |
| |
| |
|
|
|
|
def build_features(payload: PredictRequest) -> pd.DataFrame:
    """Turn a validated request into a one-row DataFrame ordered by ALL_FEATURES."""
    numeric_fields = ("tenure", "MonthlyCharges", "TotalCharges")
    passthrough_fields = (
        "gender",
        "Partner",
        "Dependents",
        "PhoneService",
        "MultipleLines",
        "InternetService",
        "OnlineSecurity",
        "OnlineBackup",
        "DeviceProtection",
        "TechSupport",
        "StreamingTV",
        "StreamingMovies",
        "Contract",
        "PaperlessBilling",
        "PaymentMethod",
    )

    record: Dict[str, Any] = {
        name: float(getattr(payload, name)) for name in numeric_fields
    }
    # SeniorCitizen arrives as 0/1 but is handed to the pipeline as a string.
    record["SeniorCitizen"] = str(payload.SeniorCitizen)
    for name in passthrough_fields:
        record[name] = getattr(payload, name)

    # Column order is dictated by ALL_FEATURES, not by dict insertion order.
    return pd.DataFrame([record], columns=ALL_FEATURES)
|
|
|
|
| |
| |
| |
|
|
|
|
def _get_db_connection() -> sqlite3.Connection:
    """Open the prediction-log database, creating directory and table if needed.

    Returns:
        An open connection to ``LOG_DB_PATH``. The caller must close it.

    Raises:
        Any error from directory creation, connecting, or schema creation.
        If schema creation fails, the new connection is closed before the
        error propagates so it is not leaked.
    """
    db_dir = os.path.dirname(LOG_DB_PATH)
    if db_dir:  # empty when LOG_DB_PATH has no directory component
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(LOG_DB_PATH)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS prediction_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                request_payload TEXT NOT NULL,
                churn_probability REAL,
                model_uri TEXT,
                latency_ms REAL,
                status TEXT NOT NULL,
                error_message TEXT
            )
            """
        )
    except Exception:
        conn.close()
        raise
    return conn
|
|
|
|
def log_prediction(
    *,
    request_payload: Dict[str, Any],
    churn_probability: Optional[float],
    model_uri: str,
    latency_ms: float,
    status: str,
    error_message: Optional[str],
) -> None:
    """Persist one prediction event to SQLite on a best-effort basis.

    Args:
        request_payload: Request fields; stored as a JSON string.
        churn_probability: Predicted probability, or ``None`` if unavailable.
        model_uri: URI string identifying the model that served the request.
        latency_ms: Request latency in milliseconds.
        status: Outcome label stored with the event.
        error_message: Error text for failed requests, otherwise ``None``.

    Failures are logged and swallowed so that logging never breaks the API.
    """
    # Function-scope import: the module imports only `datetime` from datetime.
    from datetime import timezone

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _get_db_connection()
        # Naive UTC timestamp, same stored format as the deprecated utcnow().
        timestamp = (
            datetime.now(timezone.utc)
            .replace(tzinfo=None)
            .isoformat(timespec="seconds")
        )
        with conn:  # commits on success, rolls back on error; does not close
            conn.execute(
                """
                INSERT INTO prediction_logs (
                    timestamp, request_payload, churn_probability,
                    model_uri, latency_ms, status, error_message
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    timestamp,
                    json.dumps(request_payload),
                    float(churn_probability) if churn_probability is not None else None,
                    model_uri,
                    float(latency_ms),
                    status,
                    error_message,
                ),
            )
    except Exception:
        logger.exception("Failed to log prediction event.")
    finally:
        # conn stays None when opening failed, so there is nothing to close.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                logger.debug("Failed to close prediction log connection.", exc_info=True)
|
|
|
|
def fetch_recent_logs(limit: int) -> List[Dict[str, Any]]:
    """Return up to ``limit`` of the most recent prediction events, newest first.

    Args:
        limit: Maximum number of rows to read.

    Returns:
        One dict per row with keys ``timestamp``, ``request_payload``,
        ``churn_probability``, ``model_uri``, ``latency_ms``, ``status`` and
        ``error_message``. ``request_payload`` is decoded from JSON and becomes
        ``{}`` when it cannot be decoded. Any database failure is logged and
        yields an empty list.
    """
    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _get_db_connection()
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            """
            SELECT timestamp, request_payload, churn_probability,
                   model_uri, latency_ms, status, error_message
            FROM prediction_logs
            ORDER BY id DESC LIMIT ?
            """,
            (int(limit),),
        ).fetchall()
    except Exception:
        logger.exception("Failed to fetch recent prediction logs.")
        return []
    finally:
        # conn stays None when opening failed, so there is nothing to close.
        if conn is not None:
            try:
                conn.close()
            except Exception:
                logger.debug("Failed to close prediction log connection.", exc_info=True)

    logs: List[Dict[str, Any]] = []
    for row in rows:
        # Row keys follow the SELECT column order, which fixes the dict order.
        entry: Dict[str, Any] = {key: row[key] for key in row.keys()}
        try:
            entry["request_payload"] = json.loads(row["request_payload"])
        except (TypeError, ValueError):
            # A corrupt payload must not hide the rest of the record.
            entry["request_payload"] = {}
        logs.append(entry)
    return logs
|
|
|
|
def compute_stats(logs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarise prediction-log entries into counts, latency and probability stats.

    Entries whose ``latency_ms`` or ``churn_probability`` is ``None`` are left
    out of the corresponding aggregate. Aggregates with no samples are ``0.0``.
    ``last_model_uri`` is taken from the first entry, or is ``None`` when
    ``logs`` is empty.
    """
    total = len(logs)
    success_count = sum(log.get("status") == "success" for log in logs)

    latencies: List[float] = []
    probabilities: List[float] = []
    for log in logs:
        latency = log.get("latency_ms")
        if latency is not None:
            latencies.append(float(latency))
        probability = log.get("churn_probability")
        if probability is not None:
            probabilities.append(float(probability))

    # Start from the "no samples" values and overwrite when data exists.
    latency_p50 = latency_p95 = latency_avg = 0.0
    if latencies:
        samples = np.asarray(latencies, dtype="float64")
        latency_p50 = float(np.percentile(samples, 50))
        latency_p95 = float(np.percentile(samples, 95))
        latency_avg = float(samples.mean())

    avg_probability = 0.0
    if probabilities:
        avg_probability = float(np.mean(probabilities))

    return {
        "count": total,
        "success_count": success_count,
        "failure_count": total - success_count,
        "success_rate": float(success_count / total) if total > 0 else 0.0,
        "latency_p50_ms": latency_p50,
        "latency_p95_ms": latency_p95,
        "latency_avg_ms": latency_avg,
        "avg_churn_probability": avg_probability,
        "last_model_uri": logs[0].get("model_uri") if logs else None,
    }
|
|
|
|
| |
| |
| |
|
|
|
|
def _get_explanation(
    features: Dict[str, Any],
    calibrated_prob: float,
    top_k: int = 5,
) -> tuple:
    """Delegate to ``explain_prediction``, passing the champion's probability.

    Kept as a module-level function (not inlined in the endpoint) so tests can
    monkeypatch it; the import is deferred so ``churn.genai`` is not imported
    at collection time.
    """
    from churn.genai.explainer import explain_prediction

    explanation = explain_prediction(
        features,
        top_k=top_k,
        _calibrated_probability=calibrated_prob,
    )
    return explanation
|
|
|
|
| |
| |
| |
|
|
|
|
@app.get("/health")
def health(request: Request) -> Dict[str, Any]:
    """Report liveness, whether a model is loaded, and its version."""
    state = request.app.state
    loaded_model = getattr(state, "model", None)
    return {
        "status": "ok",
        "model_loaded": loaded_model is not None,
        "model_version": getattr(state, "model_version", None),
    }
|
|
|
|
@app.get("/stats")
def stats(limit: int = Query(100, ge=1, le=10_000)) -> Dict[str, Any]:
    """Return aggregate statistics over the ``limit`` most recent log entries."""
    return compute_stats(fetch_recent_logs(limit))
|
|
|
|
@app.get("/recent")
def recent(limit: int = Query(20, ge=1, le=1_000)) -> List[Dict[str, Any]]:
    """Return the ``limit`` most recent prediction log entries, newest first."""
    recent_logs = fetch_recent_logs(limit)
    return recent_logs
|
|
|
|
@app.post("/predict", response_model=PredictResponse)
def predict(
    request_data: PredictRequest,
    request: Request,
    response: Response,
) -> PredictResponse:
    """Score one customer with the loaded model and log the outcome.

    Returns 503 when no model is loaded (nothing is logged in that case) and
    500 when feature building or scoring raises. Every scored request, whether
    it succeeds or fails, is written to the prediction log with its latency.
    """
    state = request.app.state
    model = getattr(state, "model", None)
    threshold = getattr(state, "threshold", 0.5)
    model_version = getattr(state, "model_version", "unknown")

    if model is None:
        raise HTTPException(status_code=503, detail="No model loaded.")

    model_uri = f"models:/{CHAMPION_MODEL_NAME}@{CHAMPION_ALIAS}"
    probability: Optional[float] = None
    outcome = "success"
    failure_detail: Optional[str] = None
    started = time.perf_counter()

    try:
        frame = build_features(request_data)
        positive_score = model.predict_proba(frame)[0, 1]
        probability = float(np.clip(positive_score, 0.0, 1.0))
        churn_prediction = bool(probability >= threshold)
    except Exception as exc:
        outcome = "fail"
        failure_detail = str(exc)
        logger.exception("Unexpected error in /predict.")
        raise HTTPException(
            status_code=500, detail="Unexpected error during prediction."
        ) from exc
    finally:
        # Runs on both paths so failures are logged with their latency too.
        elapsed_ms = 1000.0 * (time.perf_counter() - started)
        response.headers["X-Model-Latency-ms"] = f"{elapsed_ms:.3f}"
        log_prediction(
            request_payload=request_data.model_dump(),
            churn_probability=probability,
            model_uri=model_uri,
            latency_ms=elapsed_ms,
            status=outcome,
            error_message=failure_detail,
        )

    return PredictResponse(
        churn_probability=probability,
        churn_prediction=churn_prediction,
        threshold=threshold,
        model_version=str(model_version),
    )
|
|
|
|
@app.post("/explain", response_model=ExplainResponse)
def explain(
    request_data: PredictRequest,
    request: Request,
) -> ExplainResponse:
    """Return the champion's churn probability with an explanation.

    The probability, threshold and model version come from the loaded model
    and application state, as in /predict. The narrative fields come from
    ``_get_explanation``, which receives the same clipped probability.

    If generating or assembling the explanation raises, the error is logged
    and a deterministic fallback is returned, with ``risk_level`` bucketed
    from the probability (> 0.5 high, > 0.3 medium, else low). Returns 503
    when no model is loaded. Errors from feature building or scoring are not
    caught here.
    """
    state = request.app.state
    model = getattr(state, "model", None)
    threshold = getattr(state, "threshold", 0.5)
    model_version = getattr(state, "model_version", "unknown")

    if model is None:
        raise HTTPException(status_code=503, detail="No model loaded.")

    features_df = build_features(request_data)
    positive_score = model.predict_proba(features_df)[0, 1]
    calibrated_prob = float(np.clip(positive_score, 0.0, 1.0))
    features_dict = features_df.iloc[0].to_dict()

    try:
        expl, meta = _get_explanation(features_dict, calibrated_prob)
        return ExplainResponse(
            churn_probability=calibrated_prob,
            threshold=float(threshold),
            model_version=str(model_version),
            risk_level=expl.risk_level,
            summary=expl.summary,
            key_factors=expl.key_factors,
            recommended_action=expl.recommended_action,
            citations=expl.citations,
            provider=meta.get("provider", "fallback"),
            ungrounded_factors=meta.get("ungrounded_factors", []),
        )
    except Exception:
        logger.exception("/explain: explanation generation failed; returning deterministic fallback.")

        # The fallback risk bucket depends only on the probability.
        if calibrated_prob > 0.5:
            fallback_risk = "high"
        elif calibrated_prob > 0.3:
            fallback_risk = "medium"
        else:
            fallback_risk = "low"

        return ExplainResponse(
            churn_probability=calibrated_prob,
            threshold=float(threshold),
            model_version=str(model_version),
            risk_level=fallback_risk,
            summary="Detailed explanation unavailable.",
            key_factors=[],
            recommended_action="Contact the customer for a proactive check-in.",
            citations=[],
            provider="fallback",
            ungrounded_factors=[],
        )
|
|