churn-api / api /main.py
BrejBala's picture
deploy: customer-churn-mlops API Space (tier3-deployment)
482a2b0
Raw
History Blame Contribute Delete
17.4 kB
"""Churn prediction API — serves the registered champion pipeline from the MLflow Model Registry."""
from __future__ import annotations

import json
import logging
import os
import sqlite3
import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional

import mlflow
import mlflow.sklearn
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException, Query, Request, Response
from mlflow.tracking import MlflowClient
from prometheus_fastapi_instrumentator import Instrumentator
from pydantic import BaseModel, Field

from churn.config import settings
from churn.data import ALL_FEATURES
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name and alias used to resolve the served model in the MLflow Model Registry.
CHAMPION_MODEL_NAME = "customer-churn-xgboost"
CHAMPION_ALIAS = "champion"

# JSON file consulted when the registered version has no "threshold" tag.
_THRESHOLD_FALLBACK_PATH = "reports/threshold.json"

# SQLite file holding prediction logs; overridable through the environment.
LOG_DB_PATH = os.environ.get("LOG_DB_PATH", "logs/predictions.db")

# Prometheus instrumentation stays on unless PROMETHEUS_ENABLED is set to
# something other than "1", "true" or "yes" (compared case-insensitively).
PROMETHEUS_ENABLED = os.environ.get("PROMETHEUS_ENABLED", "true").lower() in (
    "1",
    "true",
    "yes",
)
# ---------------------------------------------------------------------------
# Registry loader — injectable in tests via monkeypatch
# ---------------------------------------------------------------------------
def load_champion_model(tracking_uri: Optional[str] = None) -> tuple[Any, float, str]:
    """Load the champion pipeline, its decision threshold, and its version.

    The model version is resolved through the ``champion`` alias in the MLflow
    Model Registry. The threshold is read from the version's ``threshold`` tag;
    when the tag is absent it is read from ``reports/threshold.json`` instead
    and a warning is logged. Either way the threshold is returned as a float.

    ``tracking_uri`` overrides ``settings.mlflow_tracking_uri`` when given.

    Returns ``(model, threshold, version_str)``. Any failure (registry lookup,
    missing fallback file or key, non-numeric threshold, model download)
    propagates to the caller; the lifespan handler is responsible for it.
    """
    mlflow.set_tracking_uri(tracking_uri or settings.mlflow_tracking_uri)
    client = MlflowClient()
    mv = client.get_model_version_by_alias(CHAMPION_MODEL_NAME, CHAMPION_ALIAS)
    version = str(mv.version)

    tag_val = mv.tags.get("threshold")
    if tag_val is not None:
        threshold = float(tag_val)
    else:
        with open(_THRESHOLD_FALLBACK_PATH, encoding="utf-8") as f:
            # Coerce like the tag path does, so the declared float return type
            # holds even if the file stores an int or a numeric string.
            threshold = float(json.load(f)["threshold"])
        logger.warning(
            "threshold tag absent from %s v%s; using fallback %s",
            CHAMPION_MODEL_NAME, version, _THRESHOLD_FALLBACK_PATH,
        )

    model_uri = f"models:/{CHAMPION_MODEL_NAME}@{CHAMPION_ALIAS}"
    model = mlflow.sklearn.load_model(model_uri)
    print(
        f"Loaded {CHAMPION_MODEL_NAME} v{version}"
        f" (threshold={threshold:.4f}) from {model_uri}"
    )
    return model, threshold, version
# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``app.state`` with the champion model before requests are served.

    A load failure is logged as a warning instead of aborting startup; in that
    case ``model``, ``threshold`` and ``model_version`` are all stored as
    ``None``. No work is done after the ``yield``.
    """
    # These defaults describe the "no model" state and survive a failed load.
    model = threshold = version = None
    try:
        model, threshold, version = load_champion_model()
    except Exception as exc:  # noqa: BLE001
        logger.warning("Failed to load champion model at startup: %s", exc)

    app.state.model = model
    app.state.threshold = threshold
    app.state.model_version = version
    yield
# Application instance; `lifespan` runs around startup to load the champion model.
app = FastAPI(title="Customer Churn Prediction API", lifespan=lifespan)
# Instrument the app for Prometheus and expose the collected metrics, unless
# disabled through the PROMETHEUS_ENABLED environment variable.
if PROMETHEUS_ENABLED:
    Instrumentator().instrument(app).expose(app)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
# Request body accepted by /predict and /explain: the raw features of one
# customer. Numeric fields must be >= 0 and every categorical field is limited
# to a fixed vocabulary through Literal, so other values are rejected by
# validation before the handler runs.
class PredictRequest(BaseModel):
    # Numeric features — pipeline expects float64
    tenure: float = Field(ge=0.0, description="Months with the company")
    MonthlyCharges: float = Field(ge=0.0, description="Current monthly bill")
    TotalCharges: float = Field(ge=0.0, description="Total charges to date")
    # SeniorCitizen is 0/1 in the raw data; pipeline stores it as string "0"/"1"
    # (build_features performs the str() conversion).
    SeniorCitizen: Literal[0, 1]
    # Categorical features — values must match the Telco training vocabulary exactly
    gender: Literal["Female", "Male"]
    Partner: Literal["No", "Yes"]
    Dependents: Literal["No", "Yes"]
    PhoneService: Literal["No", "Yes"]
    MultipleLines: Literal["No", "No phone service", "Yes"]
    InternetService: Literal["DSL", "Fiber optic", "No"]
    OnlineSecurity: Literal["No", "No internet service", "Yes"]
    OnlineBackup: Literal["No", "No internet service", "Yes"]
    DeviceProtection: Literal["No", "No internet service", "Yes"]
    TechSupport: Literal["No", "No internet service", "Yes"]
    StreamingTV: Literal["No", "No internet service", "Yes"]
    StreamingMovies: Literal["No", "No internet service", "Yes"]
    Contract: Literal["Month-to-month", "One year", "Two year"]
    PaperlessBilling: Literal["No", "Yes"]
    PaymentMethod: Literal[
        "Bank transfer (automatic)",
        "Credit card (automatic)",
        "Electronic check",
        "Mailed check",
    ]
# Response body of /predict.
class PredictResponse(BaseModel):
    # Positive-class probability from predict_proba, clipped to [0, 1].
    churn_probability: float
    # True when churn_probability >= threshold.
    churn_prediction: bool
    # Decision threshold read from app.state at request time.
    threshold: float
    # Model version read from app.state, rendered as a string.
    model_version: str
# Response body of /explain: the champion model's score plus the narrative
# fields produced by the explanation layer (or by the endpoint's fallback).
class ExplainResponse(BaseModel):
    # Champion-consistent fields — identical to /predict for the same input
    churn_probability: float
    threshold: float
    model_version: str
    # Explanation fields from ChurnExplanation
    risk_level: str
    summary: str
    key_factors: List[str]
    recommended_action: str
    citations: List[str]
    # Observability
    # provider is "fallback" when the explanation metadata names no provider
    # or when explanation generation failed.
    provider: str
    ungrounded_factors: List[str]
# ---------------------------------------------------------------------------
# Feature builder
# ---------------------------------------------------------------------------
def build_features(payload: PredictRequest) -> pd.DataFrame:
    """Turn a validated request into the one-row frame the pipeline consumes.

    The three numeric fields are passed as floats, SeniorCitizen as a string,
    and every other field unchanged. Columns are selected and ordered by
    ``ALL_FEATURES``.
    """
    numeric_fields = ("tenure", "MonthlyCharges", "TotalCharges")
    passthrough_fields = (
        "gender",
        "Partner",
        "Dependents",
        "PhoneService",
        "MultipleLines",
        "InternetService",
        "OnlineSecurity",
        "OnlineBackup",
        "DeviceProtection",
        "TechSupport",
        "StreamingTV",
        "StreamingMovies",
        "Contract",
        "PaperlessBilling",
        "PaymentMethod",
    )

    record: Dict[str, Any] = {
        name: float(getattr(payload, name)) for name in numeric_fields
    }
    record.update((name, getattr(payload, name)) for name in passthrough_fields)
    # SeniorCitizen: pipeline was trained on "0"/"1" strings (clean_telco casts all cats)
    record["SeniorCitizen"] = str(payload.SeniorCitizen)

    return pd.DataFrame([record], columns=ALL_FEATURES)
# ---------------------------------------------------------------------------
# Prediction logging
# ---------------------------------------------------------------------------
def _get_db_connection() -> sqlite3.Connection:
    """Open the prediction-log database, creating its directory and table if missing.

    Returns an open connection that the caller is responsible for closing.
    Errors from ``os.makedirs`` or ``sqlite3`` propagate; if creating the table
    fails, the freshly opened connection is closed first so it does not leak.
    """
    db_dir = os.path.dirname(LOG_DB_PATH)
    if db_dir:  # empty when LOG_DB_PATH is a bare filename
        os.makedirs(db_dir, exist_ok=True)

    conn = sqlite3.connect(LOG_DB_PATH)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS prediction_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                request_payload TEXT NOT NULL,
                churn_probability REAL,
                model_uri TEXT,
                latency_ms REAL,
                status TEXT NOT NULL,
                error_message TEXT
            )
            """
        )
    except Exception:
        # The caller never receives the handle on this path, so release it here.
        conn.close()
        raise
    return conn
def log_prediction(
    *,
    request_payload: Dict[str, Any],
    churn_probability: Optional[float],
    model_uri: str,
    latency_ms: float,
    status: str,
    error_message: Optional[str],
) -> None:
    """Persist one prediction event to SQLite on a best-effort basis.

    The payload is stored as JSON text and the timestamp as a naive UTC ISO
    string with second precision. Any failure (connecting, serialising,
    inserting) is logged and swallowed so the API never crashes because of
    logging.
    """
    # Bound before the try so the finally clause can tell "never opened" apart
    # from "opened" without relying on a swallowed NameError.
    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _get_db_connection()
        with conn:  # commits on success, rolls back on error
            conn.execute(
                """
                INSERT INTO prediction_logs (
                    timestamp, request_payload, churn_probability,
                    model_uri, latency_ms, status, error_message
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    # Same text format datetime.utcnow() produced, without the
                    # deprecated call: aware "now" in UTC with tzinfo dropped.
                    datetime.now(timezone.utc)
                    .replace(tzinfo=None)
                    .isoformat(timespec="seconds"),
                    json.dumps(request_payload),
                    float(churn_probability) if churn_probability is not None else None,
                    model_uri,
                    float(latency_ms),
                    status,
                    error_message,
                ),
            )
    except Exception:  # noqa: BLE001
        logger.exception("Failed to log prediction event.")
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:  # noqa: BLE001
                logger.debug(
                    "Failed to close prediction log connection.", exc_info=True
                )
def fetch_recent_logs(limit: int) -> List[Dict[str, Any]]:
    """Return up to ``limit`` prediction log rows, newest first.

    Each row becomes a dict with the keys ``timestamp``, ``request_payload``
    (decoded from JSON; ``{}`` when it cannot be decoded), ``churn_probability``,
    ``model_uri``, ``latency_ms``, ``status`` and ``error_message``.

    Database failures are logged and yield an empty list instead of an error.
    """
    # Bound before the try so the finally clause can tell "never opened" apart
    # from "opened" without relying on a swallowed NameError.
    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _get_db_connection()
        conn.row_factory = sqlite3.Row  # enables access by column name below
        cursor = conn.execute(
            """
            SELECT timestamp, request_payload, churn_probability,
                   model_uri, latency_ms, status, error_message
            FROM prediction_logs
            ORDER BY id DESC LIMIT ?
            """,
            (int(limit),),
        )
        rows = cursor.fetchall()
    except Exception:  # noqa: BLE001
        logger.exception("Failed to fetch recent prediction logs.")
        return []
    finally:
        if conn is not None:
            try:
                conn.close()
            except Exception:  # noqa: BLE001
                logger.debug(
                    "Failed to close prediction log connection.", exc_info=True
                )

    logs: List[Dict[str, Any]] = []
    for row in rows:
        try:
            payload = json.loads(row["request_payload"])
        except Exception:  # noqa: BLE001
            # An undecodable payload should not hide the rest of the row.
            payload = {}
        logs.append(
            {
                "timestamp": row["timestamp"],
                "request_payload": payload,
                "churn_probability": row["churn_probability"],
                "model_uri": row["model_uri"],
                "latency_ms": row["latency_ms"],
                "status": row["status"],
                "error_message": row["error_message"],
            }
        )
    return logs
def compute_stats(logs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarise a list of prediction log entries.

    Reports the entry count, success/failure counts and rate, the p50, p95 and
    mean of the non-null latencies, the mean of the non-null churn
    probabilities, and the ``model_uri`` of the first entry. Every numeric
    aggregate is 0.0 when it has no data, and ``last_model_uri`` is ``None``
    for an empty list.
    """
    total = len(logs)
    success_count = sum(entry.get("status") == "success" for entry in logs)

    latency_samples = np.asarray(
        [
            float(entry["latency_ms"])
            for entry in logs
            if entry.get("latency_ms") is not None
        ],
        dtype="float64",
    )
    latency_p50 = latency_p95 = latency_avg = 0.0
    if latency_samples.size > 0:
        latency_p50 = float(np.percentile(latency_samples, 50))
        latency_p95 = float(np.percentile(latency_samples, 95))
        latency_avg = float(latency_samples.mean())

    probabilities = [
        float(entry["churn_probability"])
        for entry in logs
        if entry.get("churn_probability") is not None
    ]

    summary: Dict[str, Any] = {}
    summary["count"] = total
    summary["success_count"] = success_count
    summary["failure_count"] = total - success_count
    summary["success_rate"] = float(success_count / total) if total > 0 else 0.0
    summary["latency_p50_ms"] = latency_p50
    summary["latency_p95_ms"] = latency_p95
    summary["latency_avg_ms"] = latency_avg
    summary["avg_churn_probability"] = (
        float(np.mean(probabilities)) if probabilities else 0.0
    )
    # The first entry is reported as the "last" URI; callers pass newest-first.
    summary["last_model_uri"] = logs[0].get("model_uri") if logs else None
    return summary
# ---------------------------------------------------------------------------
# Explanation helper — defined at module level so tests can monkeypatch it
# ---------------------------------------------------------------------------
def _get_explanation(
    features: Dict[str, Any],
    calibrated_prob: float,
    top_k: int = 5,
) -> tuple:
    """Delegate to ``explain_prediction``, passing the champion's calibrated probability.

    Kept at module level so tests can monkeypatch it, and imported lazily so
    that importing this module does not pull in ``churn.genai`` at collection
    time. Returns whatever ``explain_prediction`` returns.
    """
    from churn.genai.explainer import explain_prediction  # noqa: PLC0415

    result = explain_prediction(
        features,
        top_k=top_k,
        _calibrated_probability=calibrated_prob,
    )
    return result
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health(request: Request) -> Dict[str, Any]:
    """Report liveness and whether a champion model is currently loaded."""
    state = request.app.state
    loaded_model = getattr(state, "model", None)
    return {
        "status": "ok",
        "model_loaded": loaded_model is not None,
        "model_version": getattr(state, "model_version", None),
    }
@app.get("/stats")
def stats(limit: int = Query(100, ge=1, le=10_000)) -> Dict[str, Any]:
    """Aggregate success, latency and probability figures over the newest ``limit`` log rows."""
    return compute_stats(fetch_recent_logs(limit))
@app.get("/recent")
def recent(limit: int = Query(20, ge=1, le=1_000)) -> List[Dict[str, Any]]:
    """Return the newest ``limit`` prediction log entries, most recent first."""
    entries = fetch_recent_logs(limit)
    return entries
@app.post("/predict", response_model=PredictResponse)
def predict(
    request_data: PredictRequest,
    request: Request,
    response: Response,
) -> PredictResponse:
    """Score one customer with the champion model and log the attempt.

    Responds 503 when no model is loaded and 500 when scoring fails. Every
    scoring attempt, successful or not, is timed and recorded through
    ``log_prediction``; the latency is also written to the
    ``X-Model-Latency-ms`` header of ``response``.
    """
    state = request.app.state
    model = getattr(state, "model", None)
    threshold = getattr(state, "threshold", 0.5)
    model_version = getattr(state, "model_version", "unknown")
    if model is None:
        raise HTTPException(status_code=503, detail="No model loaded.")

    started = time.perf_counter()
    probability: Optional[float] = None
    outcome = "success"
    failure_detail: Optional[str] = None
    try:
        frame = build_features(request_data)
        raw_score = model.predict_proba(frame)[0, 1]
        probability = float(np.clip(raw_score, 0.0, 1.0))
        churn_prediction = bool(probability >= threshold)
    except Exception as exc:  # noqa: BLE001
        outcome = "fail"
        failure_detail = str(exc)
        logger.exception("Unexpected error in /predict.")
        raise HTTPException(
            status_code=500, detail="Unexpected error during prediction."
        ) from exc
    finally:
        # Runs on both paths, so failed attempts are logged as well.
        elapsed_ms = 1000.0 * (time.perf_counter() - started)
        response.headers["X-Model-Latency-ms"] = f"{elapsed_ms:.3f}"
        log_prediction(
            request_payload=request_data.model_dump(),
            churn_probability=probability,
            model_uri=f"models:/{CHAMPION_MODEL_NAME}@{CHAMPION_ALIAS}",
            latency_ms=elapsed_ms,
            status=outcome,
            error_message=failure_detail,
        )

    return PredictResponse(
        churn_probability=probability,
        churn_prediction=churn_prediction,
        threshold=threshold,
        model_version=str(model_version),
    )
@app.post("/explain", response_model=ExplainResponse)
def explain(
    request_data: PredictRequest,
    request: Request,
) -> ExplainResponse:
    """Explain a churn prediction made by the champion model.

    The probability, threshold and model version are obtained the same way as
    in /predict. The narrative fields come from the explanation layer, which
    is handed the champion's probability.

    If the explanation layer fails for any reason, a deterministic fallback is
    returned instead of an error; its risk level is derived from the
    probability (above 0.5 high, above 0.3 medium, otherwise low). Responds
    503 when no champion model is loaded. Errors while building features or
    scoring are not caught here.
    """
    state = request.app.state
    model = getattr(state, "model", None)
    threshold = getattr(state, "threshold", 0.5)
    model_version = getattr(state, "model_version", "unknown")
    if model is None:
        raise HTTPException(status_code=503, detail="No model loaded.")

    features_df = build_features(request_data)
    raw_score = model.predict_proba(features_df)[0, 1]
    calibrated_prob = float(np.clip(raw_score, 0.0, 1.0))
    features_dict = features_df.iloc[0].to_dict()

    # Risk bucket used only by the fallback response below.
    if calibrated_prob > 0.5:
        fallback_risk = "high"
    elif calibrated_prob > 0.3:
        fallback_risk = "medium"
    else:
        fallback_risk = "low"

    def _respond(**explanation_fields: Any) -> ExplainResponse:
        # Attach the champion-derived fields shared by both response paths.
        return ExplainResponse(
            churn_probability=calibrated_prob,
            threshold=float(threshold),
            model_version=str(model_version),
            **explanation_fields,
        )

    try:
        expl, meta = _get_explanation(features_dict, calibrated_prob)
        return _respond(
            risk_level=expl.risk_level,
            summary=expl.summary,
            key_factors=expl.key_factors,
            recommended_action=expl.recommended_action,
            citations=expl.citations,
            provider=meta.get("provider", "fallback"),
            ungrounded_factors=meta.get("ungrounded_factors", []),
        )
    except Exception:
        logger.exception("/explain: explanation generation failed; returning deterministic fallback.")
        return _respond(
            risk_level=fallback_risk,
            summary="Detailed explanation unavailable.",
            key_factors=[],
            recommended_action="Contact the customer for a proactive check-in.",
            citations=[],
            provider="fallback",
            ungrounded_factors=[],
        )