"""api/routes/explain.py — SHAP explainability endpoints (Feature 3).""" from fastapi import APIRouter, HTTPException from pydantic import BaseModel from typing import Dict, Any import json from infra.database import get_db, JobModel from infra.result_contract import normalize_results router = APIRouter(prefix="/api", tags=["explainability"]) class ExplainRequest(BaseModel): features: Dict[str, Any] @router.get("/shap/{job_id}") def get_shap_summary(job_id: str): """Global SHAP feature importance for a completed job.""" with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: raise HTTPException(status_code=404, detail="Job not found or not completed") try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.explain_service import get_global_shap return get_global_shap(job_id, results) @router.post("/explain/{job_id}") def explain_prediction(job_id: str, req: ExplainRequest): """ Feature 3: Per-prediction local SHAP explanation. Returns feature contributions, base value, and the prediction. """ with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or job.status != "completed": raise HTTPException(status_code=404, detail="Job not completed") try: results = normalize_results(json.loads(job.results_json) if job.results_json else {}) except Exception: results = {} from services.explain_service import explain_local return explain_local(job_id, results, req.features) @router.get("/pipeline/{job_id}") def pipeline_graph(job_id: str): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} from infra.database import DatasetModel dataset = db.query(DatasetModel).filter(DatasetModel.id == job.dataset_id).first() try: profile = json.loads(dataset.profile_json) if dataset and dataset.profile_json else {} except Exception: profile = {} try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from core.debugger import generate_pipeline_graph return {"mermaid": generate_pipeline_graph(profile, results)} @router.get("/lineage/{job_id}") def feature_lineage(job_id: str): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.explain_service import get_feature_lineage return get_feature_lineage(job_id, results) @router.get("/calibration/{job_id}") def calibration_report(job_id: str): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.explain_service import get_calibration_report return get_calibration_report(job_id, results) @router.get("/thresholds/{job_id}") def threshold_report(job_id: str): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.explain_service import get_threshold_tuning return get_threshold_tuning(job_id, results) @router.post("/counterfactual/{job_id}") def counterfactual(job_id: str, req: ExplainRequest): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.explain_service import generate_counterfactual return generate_counterfactual(job_id, results, req.features) @router.get("/trust/{job_id}") def trust_heatmap(job_id: str): with get_db() as db: job = db.query(JobModel).filter(JobModel.id == job_id).first() if not job or not job.results_json: return {"error": "Job not completed"} dataset_id = job.dataset_id try: results = normalize_results(json.loads(job.results_json)) except Exception: results = {} from services.studio_service import build_trust_heatmap return build_trust_heatmap(dataset_id, results)