Auto_ML / backend /api /routes /explain.py
abhiraj12's picture
Streamline export bundle by removing auxiliary files
807485b
"""api/routes/explain.py — SHAP explainability endpoints (Feature 3)."""
import json
import logging
from typing import Dict, Any

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from infra.database import get_db, JobModel
from infra.result_contract import normalize_results
# Router for the explainability endpoints; every route below is mounted
# under the "/api" prefix and grouped under the "explainability" tag.
router = APIRouter(
    tags=["explainability"],
    prefix="/api",
)
class ExplainRequest(BaseModel):
    """Request body for POST /explain/{job_id} and POST /counterfactual/{job_id}.

    ``features`` is forwarded unchanged (as ``req.features``) to the
    explanation service.
    """

    # Presumably one input row keyed by feature/column name — TODO confirm
    # against explain_local / generate_counterfactual.
    features: Dict[str, Any]
@router.get("/shap/{job_id}")
def get_shap_summary(job_id: str):
    """Global SHAP feature importance for a completed job.

    Args:
        job_id: Primary key of the ``JobModel`` row to explain.

    Returns:
        The value produced by ``services.explain_service.get_global_shap``
        for this job and its normalized stored results.

    Raises:
        HTTPException: 404 if the job does not exist or has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            raise HTTPException(status_code=404, detail="Job not found or not completed")
        # Copy the column while the session is open; the ORM instance is not
        # touched after get_db() exits, and the session is not held during
        # the SHAP computation below.
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback to empty results, but leave a trace instead of
        # silently hiding a corrupt or unsupported payload.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import get_global_shap
    return get_global_shap(job_id, results)
@router.post("/explain/{job_id}")
def explain_prediction(job_id: str, req: ExplainRequest):
    """
    Feature 3: Per-prediction local SHAP explanation.
    Returns feature contributions, base value, and the prediction.

    Args:
        job_id: Primary key of the ``JobModel`` row.
        req: Request body; ``req.features`` is passed to ``explain_local``.

    Returns:
        The value produced by ``services.explain_service.explain_local``.

    Raises:
        HTTPException: 404 if the job does not exist or its status is not
            ``"completed"``.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or job.status != "completed":
            raise HTTPException(status_code=404, detail="Job not completed")
        # Read the column inside the session so the ORM instance is never
        # used after get_db() exits.
        results_json = job.results_json

    try:
        # A completed job may still lack stored results; treat that as {}.
        results = normalize_results(json.loads(results_json) if results_json else {})
    except Exception:
        # Keep the best-effort fallback, but make corrupt payloads visible.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import explain_local
    return explain_local(job_id, results, req.features)
@router.get("/pipeline/{job_id}")
def pipeline_graph(job_id: str):
    """Return a Mermaid pipeline graph for a completed job.

    Args:
        job_id: Primary key of the ``JobModel`` row.

    Returns:
        ``{"mermaid": <output of core.debugger.generate_pipeline_graph>}``,
        or ``{"error": "Job not completed"}`` as a regular response body
        (not an HTTP error, unlike /shap and /explain) when the job is
        missing or has no stored results.
    """
    from infra.database import DatasetModel

    log = logging.getLogger(__name__)

    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Copy every column we need while the session is open; nothing below
        # touches ORM instances after get_db() exits.
        results_json = job.results_json
        dataset = db.query(DatasetModel).filter(DatasetModel.id == job.dataset_id).first()
        profile_json = dataset.profile_json if dataset else None

    try:
        profile = json.loads(profile_json) if profile_json else {}
    except Exception:
        # Best-effort: a bad profile degrades the graph instead of failing it.
        log.warning(
            "Could not parse dataset profile for job %s; using empty profile",
            job_id,
            exc_info=True,
        )
        profile = {}

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        log.warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from core.debugger import generate_pipeline_graph
    return {"mermaid": generate_pipeline_graph(profile, results)}
@router.get("/lineage/{job_id}")
def feature_lineage(job_id: str):
    """Return the feature-lineage report produced by ``get_feature_lineage``.

    Args:
        job_id: Primary key of the ``JobModel`` row.

    Returns:
        The value of ``services.explain_service.get_feature_lineage`` for the
        job's normalized results, or ``{"error": "Job not completed"}`` as a
        regular response body (not an HTTP error) when the job is missing or
        has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Read the column while the session is still open.
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback kept, but logged so corrupt payloads do not
        # silently produce empty reports.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import get_feature_lineage
    return get_feature_lineage(job_id, results)
@router.get("/calibration/{job_id}")
def calibration_report(job_id: str):
    """Return the calibration report produced by ``get_calibration_report``.

    Args:
        job_id: Primary key of the ``JobModel`` row.

    Returns:
        The value of ``services.explain_service.get_calibration_report`` for
        the job's normalized results, or ``{"error": "Job not completed"}``
        as a regular response body (not an HTTP error) when the job is
        missing or has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Read the column while the session is still open.
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback kept, but logged instead of silently swallowed.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import get_calibration_report
    return get_calibration_report(job_id, results)
@router.get("/thresholds/{job_id}")
def threshold_report(job_id: str):
    """Return the threshold-tuning report produced by ``get_threshold_tuning``.

    Args:
        job_id: Primary key of the ``JobModel`` row.

    Returns:
        The value of ``services.explain_service.get_threshold_tuning`` for
        the job's normalized results, or ``{"error": "Job not completed"}``
        as a regular response body (not an HTTP error) when the job is
        missing or has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Read the column while the session is still open.
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback kept, but logged instead of silently swallowed.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import get_threshold_tuning
    return get_threshold_tuning(job_id, results)
@router.post("/counterfactual/{job_id}")
def counterfactual(job_id: str, req: ExplainRequest):
    """Return a counterfactual for the submitted features via ``generate_counterfactual``.

    Args:
        job_id: Primary key of the ``JobModel`` row.
        req: Request body; ``req.features`` is passed to the service.

    Returns:
        The value of ``services.explain_service.generate_counterfactual``,
        or ``{"error": "Job not completed"}`` as a regular response body
        (not an HTTP error) when the job is missing or has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Read the column while the session is still open.
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback kept, but logged instead of silently swallowed.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.explain_service import generate_counterfactual
    return generate_counterfactual(job_id, results, req.features)
@router.get("/trust/{job_id}")
def trust_heatmap(job_id: str):
    """Return the trust heatmap built by ``studio_service.build_trust_heatmap``.

    Args:
        job_id: Primary key of the ``JobModel`` row.

    Returns:
        The value of ``build_trust_heatmap(dataset_id, results)`` for the
        job's dataset and normalized results, or
        ``{"error": "Job not completed"}`` as a regular response body (not an
        HTTP error) when the job is missing or has no stored results.
    """
    with get_db() as db:
        job = db.query(JobModel).filter(JobModel.id == job_id).first()
        if not job or not job.results_json:
            return {"error": "Job not completed"}
        # Copy both columns while the session is open so the ORM instance is
        # never accessed after get_db() exits.
        dataset_id = job.dataset_id
        results_json = job.results_json

    try:
        results = normalize_results(json.loads(results_json))
    except Exception:
        # Best-effort fallback kept, but logged instead of silently swallowed.
        logging.getLogger(__name__).warning(
            "Could not parse stored results for job %s; using empty results",
            job_id,
            exc_info=True,
        )
        results = {}

    from services.studio_service import build_trust_heatmap
    return build_trust_heatmap(dataset_id, results)