# EchoML / app /api_fastapi.py
# tiffany101's picture
# Update app/api_fastapi.py
# e3ca372 verified
import logging
from typing import Optional, List

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from core.schemas import RetrievalConfig, ExplainResponse
from core.model_loader import load_model
from core.explain import explain_instance
from core.retrieval import retrieve_topk
from core.utils import safe_proba_to_scalar
# FastAPI application object; the /explain route in this file is registered on it.
app = FastAPI(title="Reasoning-RAG XAI API")
# Cached globals: module-level state that explain() reuses across requests so
# the model is not reloaded on every call.
# Model object returned by load_model(); None until the first request loads one.
MODEL = None
# Feature names supplied by the request that loaded MODEL.
FEATURE_NAMES: Optional[List[str]] = None
# Background data passed to explain_instance(); explain() only ever sets this
# to None, in which case the request's own input row is used instead.
BACKGROUND = None
# Default retrieval namespace, used when a request does not provide one.
NAMESPACE = "data/base_indices/iris_global"
# --- Target name mappings (extend per dataset/model) ---
# Maps a dataset key to its class names; the list is indexed by integer class id.
TARGET_NAMES = {
"iris": ["setosa", "versicolor", "virginica"],
# add more datasets here if needed
}
class ExplainRequest(BaseModel):
    """Request body accepted by the ``/explain`` endpoint."""

    # Path handed to ``load_model`` to obtain the model to explain.
    model_path: str
    # Names of the input features; passed through to ``explain_instance``.
    feature_names: List[str]
    # Feature values of the single instance to predict and explain.
    features: List[float]
    # Retrieval namespace; the endpoint falls back to the module-level
    # ``NAMESPACE`` when this is omitted or empty.
    namespace: Optional[str] = None
    # Retrieval settings (the endpoint reads ``use_retrieval``, ``alpha``, ``k``).
    # Retrieval is skipped when this is omitted.
    retrieval: Optional[RetrievalConfig] = None
    # Accepted in the request, but not read by the endpoint in this file.
    background_path: Optional[str] = None
logger = logging.getLogger(__name__)

# Path of the model currently held in MODEL. Tracked so that a request naming a
# different model file invalidates the cache even when the feature names match.
_CACHED_MODEL_PATH: Optional[str] = None


def _label_for(model_key, class_id):
    """Return the human-readable label for *class_id*, or ``str(class_id)``.

    The name from ``TARGET_NAMES[model_key]`` is used only when *model_key* is
    known and *class_id* converts to an integer inside the list's bounds.
    Negative ids are rejected rather than wrapping around to the end of the
    list.
    """
    names = TARGET_NAMES.get(model_key) if model_key else None
    if names is not None:
        try:
            idx = int(class_id)
        except (TypeError, ValueError):
            idx = -1
        if 0 <= idx < len(names):
            return names[idx]
    return str(class_id)


@app.post("/explain", response_model=ExplainResponse)
def explain(req: ExplainRequest):
    """Predict, explain and optionally retrieve similar cases for one instance.

    Args:
        req: Request body with the model path, feature names, feature values
            and optional retrieval settings.

    Returns:
        An ``ExplainResponse`` holding the predicted label and its probability
        (``None`` when it cannot be computed), the explanation returned by
        ``explain_instance``, and the retrieved similar cases (empty list when
        retrieval is disabled).

    Raises:
        HTTPException: with status 500 when the model's ``predict`` call fails,
            instead of silently reporting class 0.
    """
    global MODEL, FEATURE_NAMES, BACKGROUND, _CACHED_MODEL_PATH

    # NOTE(security): model_path comes straight from the request body. If
    # load_model deserialises with pickle/joblib (not visible from here), this
    # lets a client trigger arbitrary code execution -- restrict the path to an
    # allow-list or a fixed directory before exposing this endpoint.
    # NOTE(review): the cache globals are mutated without a lock; confirm
    # whether concurrent requests for different models need to be supported.
    cache_is_stale = (
        MODEL is None
        or _CACHED_MODEL_PATH != req.model_path
        or FEATURE_NAMES != req.feature_names
    )
    if cache_is_stale:
        # Load first so a failing load leaves the previous cache untouched.
        loaded = load_model(req.model_path)
        MODEL = loaded
        _CACHED_MODEL_PATH = req.model_path
        FEATURE_NAMES = req.feature_names
        BACKGROUND = None  # optionally load background data

    # Single-row 2-D array, the shape predict()/predict_proba() are called with.
    x = np.asarray(req.features, dtype="float32").reshape(1, -1)

    # Prediction: a failure here must not be reported as a valid class.
    try:
        raw_pred = MODEL.predict(x)[0]
    except Exception as exc:
        logger.exception("Prediction failed for model %s", req.model_path)
        raise HTTPException(
            status_code=500, detail="Model prediction failed"
        ) from exc

    # Models may emit non-integer labels (e.g. strings); keep the raw value
    # for the label in that case instead of defaulting to class 0.
    try:
        y_class = int(raw_pred)
    except (TypeError, ValueError):
        y_class = None

    # Probability is best-effort: any failure leaves it as None.
    proba_scalar = None
    if hasattr(MODEL, "predict_proba"):
        try:
            proba_row = MODEL.predict_proba(x)[0]
            # predict_proba columns follow the order of classes_ when the model
            # exposes it; otherwise assume the class id is the column index.
            classes = getattr(MODEL, "classes_", None)
            if classes is not None:
                column = list(classes).index(raw_pred)
            else:
                column = y_class
            if column is not None and 0 <= column < len(proba_row):
                proba_scalar = float(proba_row[column])
        except Exception:
            # Model backends differ widely, so the catch is deliberately broad.
            logger.exception("Could not compute prediction probability")

    # --- Map class ID -> human-readable label ---
    model_key = "iris" if "iris" in req.model_path.lower() else None
    y_label = _label_for(model_key, raw_pred if y_class is None else y_class)

    # SHAP explanation; without cached background data the instance itself is
    # used as background.
    exp = explain_instance(
        MODEL,
        x[0],
        FEATURE_NAMES,
        background_X=(BACKGROUND if BACKGROUND is not None else x),
    )

    # Retrieval
    similar = None
    if req.retrieval and req.retrieval.use_retrieval:
        ns = req.namespace or NAMESPACE
        shap_q = np.array(exp["shap_values"], dtype="float32")
        similar = retrieve_topk(
            ns, shap_q, x[0], alpha=req.retrieval.alpha, k=req.retrieval.k
        )
        # Also map labels for retrieved cases, using the same bounds-checked
        # lookup as the main prediction.
        if model_key in TARGET_NAMES:
            for case in similar or []:
                if case.get("y_pred") is not None:
                    case["y_pred"] = _label_for(model_key, case["y_pred"])

    return ExplainResponse(
        prediction={
            "y_pred": y_label,  # "setosa", "versicolor", etc. when mapped
            "proba": proba_scalar,
        },
        explanation=exp,
        similar_cases=similar or [],
        ood_flag=False,
    )