File size: 3,130 Bytes
2ae10e0
 
 
 
 
e3ca372
 
 
 
 
2ae10e0
 
 
 
 
 
 
e3ca372
2ae10e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import logging
from typing import Optional, List

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from core.schemas import RetrievalConfig, ExplainResponse
from core.model_loader import load_model
from core.explain import explain_instance
from core.retrieval import retrieve_topk
from core.utils import safe_proba_to_scalar

app = FastAPI(title="Reasoning-RAG XAI API")

# ---------------------------------------------------------------------------
# Per-process cache, filled lazily by the /explain handler.
# ---------------------------------------------------------------------------
MODEL = None  # object returned by load_model(); None until the first request
FEATURE_NAMES: Optional[List[str]] = None  # feature names MODEL was loaded with
BACKGROUND = None  # explainer background data; None -> the query row is used
NAMESPACE = "data/base_indices/iris_global"  # retrieval namespace default

# Class-id -> human-readable label, keyed by dataset name.
# Position i of each list holds the name for class id i; add one entry per
# dataset/model family as needed.
TARGET_NAMES = {
    "iris": [
        "setosa",
        "versicolor",
        "virginica",
    ],
}


class ExplainRequest(BaseModel):
    """Request body for ``POST /explain``: one feature row plus model/retrieval options."""

    # Path passed to load_model() when the cached model must be (re)loaded.
    model_path: str
    # Feature names in column order; also used as the explainer's feature labels.
    feature_names: List[str]
    # Values for a single row; converted to a float32 array of shape (1, n).
    features: List[float]
    # Retrieval index namespace; the handler falls back to NAMESPACE when omitted.
    namespace: Optional[str] = None
    # Retrieval settings; similar cases are fetched only when use_retrieval is truthy.
    retrieval: Optional[RetrievalConfig] = None
    # NOTE(review): accepted but not currently read by the /explain handler.
    background_path: Optional[str] = None


# Path of the model file currently held in MODEL, so the cache can tell when a
# request asks for a different model.
_LOADED_MODEL_PATH: Optional[str] = None

_logger = logging.getLogger(__name__)


def _dataset_key(model_path: str) -> Optional[str]:
    """Return the first TARGET_NAMES key contained in *model_path* (case-insensitive)."""
    path = model_path.lower()
    return next((key for key in TARGET_NAMES if key in path), None)


def _class_label(names: Optional[List[str]], value) -> str:
    """Map a class id to its name in *names*; otherwise return it as a string.

    Non-integer labels are returned as ``str(value)``; integer ids outside
    ``names`` (or when *names* is None) are returned as ``str(int(value))``.
    """
    try:
        idx = int(value)
    except (TypeError, ValueError):
        return str(value)
    # Explicit bounds check: a negative id must not wrap around the list.
    if names is not None and 0 <= idx < len(names):
        return names[idx]
    return str(idx)


def _proba_column(model, raw_label, n_cols: int) -> Optional[int]:
    """Return the predict_proba column for *raw_label*, or None if it can't be found."""
    classes = getattr(model, "classes_", None)
    if classes is not None:
        # Probability columns follow ``classes_`` order, which need not be
        # 0..n-1 (e.g. labels 1..3 or string labels).
        return next((i for i, c in enumerate(classes) if c == raw_label), None)
    try:
        col = int(raw_label)
    except (TypeError, ValueError):
        return None
    return col if 0 <= col < n_cols else None


def _predict(model, x):
    """Predict the single row in *x* and return ``(raw_label, proba_or_None)``.

    Raises:
        HTTPException: 500 if ``model.predict`` fails. A failing
            ``predict_proba`` is logged and only yields ``proba=None``.
    """
    try:
        raw_label = model.predict(x)[0]
    except Exception as err:
        _logger.exception("Prediction error")
        raise HTTPException(status_code=500, detail=f"Prediction error: {err}") from err

    proba = None
    if hasattr(model, "predict_proba"):
        try:
            row = np.asarray(model.predict_proba(x))[0]
            col = _proba_column(model, raw_label, row.shape[-1])
            if col is not None:
                proba = float(row[col])
        except Exception:
            # Probability is optional in the response; keep serving the label.
            _logger.exception("predict_proba error; returning proba=None")
    return raw_label, proba


@app.post("/explain", response_model=ExplainResponse)
def explain(req: ExplainRequest):
    """Predict, explain, and optionally retrieve similar cases for one feature row.

    The model is cached in module globals and reloaded when nothing is cached
    yet, when ``req.model_path`` changes, or when ``req.feature_names`` changes.

    Raises:
        HTTPException: 422 if ``features`` is empty or its length differs from
            ``feature_names``; 500 if the model's ``predict`` call fails.
    """
    global MODEL, FEATURE_NAMES, BACKGROUND, _LOADED_MODEL_PATH

    if not req.features or len(req.features) != len(req.feature_names):
        raise HTTPException(
            status_code=422,
            detail=(
                f"Expected a non-empty feature row matching feature_names; got "
                f"{len(req.features)} values for {len(req.feature_names)} names."
            ),
        )

    # NOTE(review): model_path comes straight from the request. If load_model
    # unpickles (pickle/joblib), this permits arbitrary code execution; consider
    # restricting it to an allow-listed model directory.
    if (
        MODEL is None
        or _LOADED_MODEL_PATH != req.model_path
        or FEATURE_NAMES != req.feature_names
    ):
        MODEL = load_model(req.model_path)
        _LOADED_MODEL_PATH = req.model_path
        FEATURE_NAMES = req.feature_names
        BACKGROUND = None  # TODO: req.background_path is accepted but not loaded yet

    x = np.asarray(req.features, dtype="float32").reshape(1, -1)

    raw_label, proba_scalar = _predict(MODEL, x)

    # Human-readable labels, when the model path identifies a known dataset.
    names = TARGET_NAMES.get(_dataset_key(req.model_path))
    y_label = _class_label(names, raw_label)

    exp = explain_instance(
        MODEL,
        x[0],
        FEATURE_NAMES,
        # Without cached background data, the query row itself is passed.
        background_X=(BACKGROUND if BACKGROUND is not None else x),
    )

    similar = None
    if req.retrieval and req.retrieval.use_retrieval:
        ns = req.namespace or NAMESPACE
        shap_q = np.array(exp["shap_values"], dtype="float32")
        similar = retrieve_topk(ns, shap_q, x[0], alpha=req.retrieval.alpha, k=req.retrieval.k)

        # Map retrieved cases' class ids to names as well (left as-is when the
        # dataset is unknown).
        if names is not None:
            for case in similar or []:
                if case.get("y_pred") is not None:
                    case["y_pred"] = _class_label(names, case["y_pred"])

    return ExplainResponse(
        prediction={
            "y_pred": y_label,
            "proba": proba_scalar,
        },
        explanation=exp,
        similar_cases=similar or [],
        ood_flag=False,
    )