Tiffany Degbotse committed on
Commit
2ae10e0
·
1 Parent(s): 6e3858e

query with your model

Browse files
app/__pycache__/api_fastapi.cpython-313.pyc ADDED
Binary file (4.17 kB). View file
 
app/api_fastapi.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Optional, List
4
+ import numpy as np
5
+
6
+ from Query_Your_Model.core.schemas import RetrievalConfig, ExplainResponse
7
+ from Query_Your_Model.core.model_loader import load_model
8
+ from Query_Your_Model.core.explain import explain_instance
9
+ from Query_Your_Model.core.retrieval import retrieve_topk
10
+ from Query_Your_Model.core.utils import safe_proba_to_scalar
11
+
12
app = FastAPI(title="Reasoning-RAG XAI API")

# Cached globals — populated lazily by /explain on the first request and
# reused until the requested feature schema changes.
MODEL = None                                # loaded sklearn-compatible model
FEATURE_NAMES: Optional[List[str]] = None   # schema the cached MODEL was loaded for
BACKGROUND = None                           # optional SHAP background data (never set here)
NAMESPACE = "Query_Your_Model/data/base_indices/iris_global"  # default retrieval index dir

# --- Target name mappings (extend per dataset/model) ---
TARGET_NAMES = {
    "iris": ["setosa", "versicolor", "virginica"],
    # add more datasets here if needed
}
25
+
26
+
27
class ExplainRequest(BaseModel):
    """Request payload for the /explain endpoint."""

    model_path: str                              # filesystem path to a joblib-pickled model
    feature_names: List[str]                     # ordered feature names matching `features`
    features: List[float]                        # one feature vector (single instance)
    namespace: Optional[str] = None              # retrieval index dir; defaults to NAMESPACE
    retrieval: Optional[RetrievalConfig] = None  # retrieval knobs; None disables retrieval
    background_path: Optional[str] = None        # not read by /explain — TODO confirm intent
34
+
35
+
36
@app.post("/explain", response_model=ExplainResponse)
def explain(req: ExplainRequest):
    """Predict, SHAP-explain, and (optionally) retrieve similar past cases.

    Returns an ExplainResponse with a human-readable predicted label, the
    predicted-class probability when available, the SHAP explanation, and the
    top-k retrieved neighbours when retrieval is enabled.
    """
    global MODEL, FEATURE_NAMES, BACKGROUND

    # (Re)load the model when nothing is cached yet or the feature schema changed.
    if (MODEL is None) or (FEATURE_NAMES != req.feature_names):
        MODEL = load_model(req.model_path)
        FEATURE_NAMES = req.feature_names
        BACKGROUND = None  # optionally load background data

    # Single instance as a (1, n_features) row vector.
    x = np.asarray(req.features, dtype="float32").reshape(1, -1)

    # Prediction & probability. Best effort: on failure we keep class 0 and no
    # probability (original behaviour).
    y_class = 0
    proba_scalar = None
    try:
        y_pred = MODEL.predict(x)
        y_class = int(y_pred[0])

        if hasattr(MODEL, "predict_proba"):
            proba = MODEL.predict_proba(x)
            proba_scalar = float(proba[0][y_class])
    except Exception as e:
        # NOTE(review): consider surfacing this as an HTTP error instead of
        # silently answering with the default class.
        print("Prediction error:", e)

    # --- Map class ID -> human-readable label ---
    # Bounds-checked lookup: an out-of-range class id can no longer raise
    # IndexError (the old code indexed the label list unconditionally).
    model_key = "iris" if "iris" in req.model_path.lower() else None
    names = TARGET_NAMES.get(model_key) if model_key else None
    if names and 0 <= y_class < len(names):
        y_label = names[y_class]
    else:
        y_label = str(y_class)

    # SHAP explanation (falls back to the instance itself as background data).
    exp = explain_instance(
        MODEL,
        x[0],
        FEATURE_NAMES,
        background_X=(BACKGROUND if BACKGROUND is not None else x),
    )

    # Retrieval of similar past cases.
    similar = None
    ns = req.namespace or NAMESPACE
    if req.retrieval and req.retrieval.use_retrieval:
        shap_q = np.array(exp["shap_values"], dtype="float32")
        similar = retrieve_topk(ns, shap_q, x[0], alpha=req.retrieval.alpha, k=req.retrieval.k)

        # Also map labels for retrieved cases, with the same bounds checking
        # and a narrowed except (int() raises TypeError/ValueError only).
        if names:
            for case in similar:
                raw = case.get("y_pred")
                if raw is None:
                    continue
                try:
                    idx = int(raw)
                    case["y_pred"] = names[idx] if 0 <= idx < len(names) else str(raw)
                except (TypeError, ValueError):
                    case["y_pred"] = str(raw)

    return ExplainResponse(
        prediction={
            "y_pred": y_label,  # e.g. "setosa", "versicolor", ...
            "proba": proba_scalar,
        },
        explanation=exp,
        similar_cases=similar or [],
        ood_flag=False
    )
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (183 Bytes). View file
 
core/__pycache__/explain.cpython-313.pyc ADDED
Binary file (4.5 kB). View file
 
core/__pycache__/model_loader.cpython-313.pyc ADDED
Binary file (1.24 kB). View file
 
core/__pycache__/retrieval.cpython-313.pyc ADDED
Binary file (3.52 kB). View file
 
core/__pycache__/schemas.cpython-313.pyc ADDED
Binary file (2.6 kB). View file
 
core/__pycache__/storage.cpython-313.pyc ADDED
Binary file (6.2 kB). View file
 
core/__pycache__/utils.cpython-313.pyc ADDED
Binary file (1.47 kB). View file
 
core/explain.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any
2
+ import numpy as np
3
+ import shap
4
+
5
+
6
def _pick_explainer(model, X_background: np.ndarray):
    """
    Choose an appropriate SHAP explainer.
    - TreeExplainer for tree-based models
    - LinearExplainer for linear models
    - KernelExplainer fallback (slow but general)
    """
    # Fix: the tree-model check used to run only when `import xgboost`
    # succeeded, so tree models were silently mis-detected (and fell through
    # to the slow KernelExplainer) on hosts without xgboost installed.
    # None of the checks below actually need the xgboost package.
    model_name = type(model).__name__.lower()
    is_tree = hasattr(model, "get_booster") or "xgb" in model_name

    is_tree = is_tree or any(
        s in model_name
        for s in ["randomforest", "gradientboost", "gbm", "lightgbm", "catboost"]
    )

    if is_tree:
        return shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")

    is_linear = "linear" in model_name or hasattr(model, "coef_")
    if is_linear:
        return shap.LinearExplainer(model, X_background)

    # Fallback for anything else
    return shap.KernelExplainer(model.predict, X_background)
33
+
34
+
35
def explain_instance(
    model,
    x: np.ndarray,
    feature_names: List[str],
    background_X: np.ndarray,
    top_k: int = 8,
) -> Dict[str, Any]:
    """
    Compute SHAP for a single instance x (shape: (n_features,)).
    Always reduces SHAP output to a vector of length = n_features.
    Handles multiclass by averaging across classes.

    Returns a dict with keys "shap_values" (list[float]), "base_value"
    (float), and "topk" (list of per-feature dicts sorted by |impact|).
    """
    x = x.reshape(1, -1)
    explainer = _pick_explainer(model, background_X)

    values = explainer.shap_values(x)

    # SHAP returns different shapes depending on model type
    if isinstance(values, list):  # multiclass -> list of arrays
        # stack into shape (n_classes, n_samples, n_features)
        values_arr = np.stack(values, axis=0)
        # average across classes -> shape (n_samples, n_features)
        values_arr = np.mean(values_arr, axis=0)
    else:
        values_arr = values  # already (n_samples, n_features)

    # Always flatten to 1D vector
    shap_vec = np.array(values_arr[0]).reshape(-1)

    # Ensure length matches feature_names
    # NOTE(review): silently truncating extra values may hide a class/feature
    # axis mix-up upstream — confirm shapes if this branch ever triggers.
    n_features = len(feature_names)
    if len(shap_vec) != n_features:
        shap_vec = shap_vec[:n_features]

    base_value = explainer.expected_value
    if isinstance(base_value, (list, np.ndarray)):
        # Multiclass explainers expose one expected value per class; collapse
        # to a single scalar to match the class-averaged SHAP vector above.
        base_value = float(np.mean(base_value))

    # Top-k by absolute impact
    abs_imp = np.abs(shap_vec)
    idx = np.argsort(-abs_imp)[:top_k].ravel()

    top = []
    for i in idx:
        i = int(i)
        if i >= n_features:  # safety check
            continue

        # Defensive scalar conversion: some explainers hand back 0-d arrays.
        shap_val = shap_vec[i]
        if isinstance(shap_val, (np.ndarray, list)):
            shap_val = float(np.mean(shap_val))
        else:
            shap_val = float(shap_val)

        abs_val = abs_imp[i]
        if isinstance(abs_val, (np.ndarray, list)):
            abs_val = float(np.mean(abs_val))
        else:
            abs_val = float(abs_val)

        top.append({
            "feature": feature_names[i],
            "value": float(x[0, i]),
            "shap": shap_val,
            "abs_impact": abs_val,
        })

    return {
        "shap_values": shap_vec.tolist(),
        "base_value": float(base_value),
        "topk": top,
    }
core/model_loader.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ from typing import Any, Tuple, Optional
3
+ import numpy as np
4
+
5
def load_model(path: str) -> Any:
    """Deserialize and return a joblib-pickled, sklearn-compatible model."""
    return joblib.load(path)
9
+
10
def predict(model: Any, X: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Return (pred, proba_or_none). Handles regressors & classifiers."""
    y_pred = model.predict(X)

    proba = None
    predict_proba = getattr(model, "predict_proba", None)
    if predict_proba is not None:
        # Best effort: a classifier whose predict_proba blows up is treated
        # like a model without probabilities.
        try:
            proba = predict_proba(X)
        except Exception:
            proba = None

    return y_pred, proba
core/retrieval.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+ import numpy as np
3
+ from .storage import load_matrices
4
+
5
+
6
+ def _cosine(a: np.ndarray, b: np.ndarray) -> float:
7
+ na = np.linalg.norm(a) + 1e-12
8
+ nb = np.linalg.norm(b) + 1e-12
9
+ return float(np.dot(a, b) / (na * nb))
10
+
11
+
12
def combined_similarity(
    shap_q: np.ndarray,
    feat_q: np.ndarray,
    shap_i: np.ndarray,
    feat_i: np.ndarray,
    alpha: float
) -> float:
    """similarity = alpha * cos(SHAP) + (1 - alpha) * cos(features)"""
    def _cos(u: np.ndarray, v: np.ndarray) -> float:
        # Epsilon-stabilised cosine, safe for all-zero vectors.
        denom = (np.linalg.norm(u) + 1e-12) * (np.linalg.norm(v) + 1e-12)
        return float(np.dot(u, v) / denom)

    shap_term = alpha * _cos(shap_q, shap_i)
    feat_term = (1.0 - alpha) * _cos(feat_q, feat_i)
    return shap_term + feat_term
21
+
22
+
23
def retrieve_topk(
    namespace: str,
    shap_q: np.ndarray,
    x_q: np.ndarray,
    alpha: float = 0.5,
    k: int = 5
) -> List[Dict[str, Any]]:
    """
    Retrieve top-k similar cases from a namespace.
    Returns dicts with case_id, similarity, y_pred, shap_values, features, meta.
    """
    # Stored corpus: feature matrix, SHAP matrix, per-case metadata, ids.
    X, SHAP, metas, case_ids = load_matrices(namespace)

    # Each meta row is a {case_id: {...}} mapping; merge them into one dict.
    meta_by_id: Dict[str, Dict[str, Any]] = {}
    for entry in metas:
        if isinstance(entry, dict):
            meta_by_id.update(entry)

    scored: List[Dict[str, Any]] = []
    for row, case_id in enumerate(case_ids):
        features_i = X[row]
        shap_i = SHAP[row]

        similarity = combined_similarity(shap_q, x_q, shap_i, features_i, alpha=alpha)
        case_meta = meta_by_id.get(case_id, {})

        scored.append({
            "case_id": case_id,
            "similarity": float(similarity),
            "y_pred": case_meta.get("y_pred"),
            "shap_values": shap_i.tolist(),
            "features": features_i.tolist(),
            "meta": case_meta
        })

    # Highest similarity first, then trim to k.
    scored.sort(key=lambda d: d["similarity"], reverse=True)
    return scored[:k]
66
+
67
+
68
def ood_score(shap_query: np.ndarray, shaps_matrix: np.ndarray) -> float:
    """Simple OOD heuristic: 1 - max cosine against corpus SHAPs."""
    # No corpus at all -> maximally out-of-distribution.
    if shaps_matrix.size == 0:
        return 1.0

    q_norm = np.linalg.norm(shap_query) + 1e-12
    best = max(
        float(np.dot(shap_query, row) / (q_norm * (np.linalg.norm(row) + 1e-12)))
        for row in shaps_matrix
    )
    return float(1.0 - best)
core/schemas.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Dict, Any
2
+ from pydantic import BaseModel
3
+
4
class Instance(BaseModel):
    """A single model input: ordered feature vector plus its column names."""
    features: List[float]
    feature_names: List[str]


class PredictionResult(BaseModel):
    """Model output for one instance."""
    # Fix: /explain fills this with a mapped string label ("setosa", ...),
    # so the former strict `float` annotation made response validation fail
    # under pydantic v2. `Any` accepts both the numeric class id and the
    # mapped label, and remains backward compatible for float producers.
    y_pred: Any
    proba: Optional[float] = None


class Explanation(BaseModel):
    """SHAP explanation for one instance."""
    shap_values: List[float]  # reasoning vector, one value per feature
    base_value: float
    topk: List[Dict[str, Any]]  # top features ranked by |shap|


class RetrievalConfig(BaseModel):
    """Knobs for reasoning-space retrieval."""
    alpha: float = 0.7  # weight for SHAP cosine vs feature cosine
    k: int = 5
    use_retrieval: bool = True
    namespace: str = "global_default"


class RetrievedCase(BaseModel):
    """One neighbour returned by retrieval."""
    case_id: str
    similarity: float
    # Fix: may hold either the stored float class id or a mapped string
    # label (see /explain), so `Optional[float]` was too strict.
    y_pred: Optional[Any] = None
    shap_values: Optional[List[float]] = None
    features: Optional[List[float]] = None
    meta: Optional[Dict[str, Any]] = None


class ExplainResponse(BaseModel):
    """Full /explain response payload."""
    prediction: PredictionResult
    explanation: Explanation
    similar_cases: Optional[List[RetrievedCase]] = None
    ood_flag: bool = False
    ood_reason: Optional[str] = None
core/storage.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import List, Dict, Any, Tuple
4
+ import numpy as np
5
+
6
# File-name conventions inside a namespace directory.
INDEX_FILE = "index.jsonl"     # one {"case_id": ..., "row": ...} object per line
FEATURE_FILE = "features.npy"  # (n_cases, n_features) float32 matrix
SHAP_FILE = "shap.npy"         # (n_cases, n_features) float32 matrix
META_FILE = "meta.jsonl"       # one {case_id: meta} object per line
10
+
11
+
12
def ensure_dir(path: str):
    """Create *path* (and any missing parents); a no-op when it exists."""
    os.makedirs(path, exist_ok=True)
14
+
15
+
16
def append_jsonl(path: str, row: Dict[str, Any]):
    """Append *row* as one JSON object per line (JSON-Lines format)."""
    encoded = json.dumps(row)
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(encoded + "\n")
19
+
20
+
21
def load_index(path: str) -> List[Dict[str, Any]]:
    """Parse a JSON-Lines file into a list of dicts, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as fh:
        return [json.loads(line) for line in fh if line.strip()]
28
+
29
+
30
def init_matrix_files(namespace_dir: str, feature_dim: int, shap_dim: int):
    """Create empty (0, dim) float32 .npy matrices when they are missing."""
    specs = (
        (os.path.join(namespace_dir, FEATURE_FILE), feature_dim),
        (os.path.join(namespace_dir, SHAP_FILE), shap_dim),
    )
    for target, dim in specs:
        if not os.path.exists(target):
            np.save(target, np.zeros((0, dim), dtype="float32"))
38
+
39
+
40
def append_case(namespace_dir: str, case_id: str, features: np.ndarray, shap_vec: np.ndarray, meta: Dict[str, Any]):
    """Append one case to the namespace store."""
    ensure_dir(namespace_dir)

    # grow matrices
    # NOTE(review): this reloads and rewrites both full .npy files on every
    # append, so bulk indexing is O(n^2) overall — fine for small demo sets.
    feat_path = os.path.join(namespace_dir, FEATURE_FILE)
    shap_path = os.path.join(namespace_dir, SHAP_FILE)
    feats = np.load(feat_path)
    shaps = np.load(shap_path)
    feats = np.vstack([feats, features.reshape(1, -1).astype("float32")])
    shaps = np.vstack([shaps, shap_vec.reshape(1, -1).astype("float32")])
    np.save(feat_path, feats)
    np.save(shap_path, shaps)

    # index & meta
    # The index maps case_id -> matrix row; meta is stored as {case_id: meta}
    # (load_matrices / retrieve_topk merge these rows back into one dict).
    idx_path = os.path.join(namespace_dir, INDEX_FILE)
    append_jsonl(idx_path, {"case_id": case_id, "row": feats.shape[0] - 1})
    meta_path = os.path.join(namespace_dir, META_FILE)
    append_jsonl(meta_path, {case_id: meta})
59
+
60
+
61
def load_matrices(namespace_dir: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]], List[str]]:
    """
    Load all stored matrices and metadata for retrieval.
    Returns:
        X (np.ndarray) : Features matrix
        SHAP (np.ndarray) : SHAP values matrix
        metas (list[dict]) : Metadata entries
        case_ids (list[str]): Case IDs
    """
    def _read_jsonl(path: str) -> List[Dict[str, Any]]:
        # A missing file yields no rows; blank lines are skipped.
        if not os.path.exists(path):
            return []
        with open(path, "r", encoding="utf-8") as fh:
            return [json.loads(line) for line in fh if line.strip()]

    # Matrices are required; metadata and index files are optional.
    X = np.load(os.path.join(namespace_dir, FEATURE_FILE))
    SHAP = np.load(os.path.join(namespace_dir, SHAP_FILE))

    metas = _read_jsonl(os.path.join(namespace_dir, META_FILE))
    case_ids = [
        entry.get("case_id")
        for entry in _read_jsonl(os.path.join(namespace_dir, INDEX_FILE))
    ]

    return X, SHAP, metas, case_ids
core/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import List
3
+ import numpy as np
4
+
5
def case_id_from_vector(x: np.ndarray, prefix: str = "case") -> str:
    """Deterministic short case id derived from the raw bytes of *x*."""
    digest = hashlib.md5(x.tobytes()).hexdigest()
    return f"{prefix}_{digest[:10]}"
8
+
9
def to_numpy(lst, dtype="float32"):
    """Convert a sequence (or array-like) to an ndarray of *dtype*."""
    # np.asarray avoids a copy when *lst* already matches the target dtype.
    return np.asarray(lst, dtype=dtype)
11
+
12
def safe_proba_to_scalar(proba, positive_index: int = 1):
    """Return a single probability for binary classifiers when possible."""
    if proba is None:
        return None

    arr = np.asarray(proba)
    # A (n_samples, n_classes) matrix: report the requested class for the
    # first sample; anything else collapses to its mean.
    is_class_matrix = arr.ndim == 2 and arr.shape[1] >= 2
    value = arr[0, positive_index] if is_class_matrix else arr.mean()
    return float(value)
data/base_indices/iris_global/features.npy ADDED
Binary file (1.73 kB). View file
 
data/base_indices/iris_global/index.jsonl ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"case_id": "iris_00bbac633c", "row": 0}
2
+ {"case_id": "iris_b2a2c274fa", "row": 1}
3
+ {"case_id": "iris_f79fee902c", "row": 2}
4
+ {"case_id": "iris_9a1a194bc4", "row": 3}
5
+ {"case_id": "iris_d5f6d63eb7", "row": 4}
6
+ {"case_id": "iris_fb322465b8", "row": 5}
7
+ {"case_id": "iris_751758f9a1", "row": 6}
8
+ {"case_id": "iris_2a967bb0c8", "row": 7}
9
+ {"case_id": "iris_e7d76ce04e", "row": 8}
10
+ {"case_id": "iris_7be756b6c3", "row": 9}
11
+ {"case_id": "iris_2a8b6920ad", "row": 10}
12
+ {"case_id": "iris_95f30553cb", "row": 11}
13
+ {"case_id": "iris_190a3c83bf", "row": 12}
14
+ {"case_id": "iris_dddf43f88f", "row": 13}
15
+ {"case_id": "iris_be9c18be01", "row": 14}
16
+ {"case_id": "iris_0f92794e2b", "row": 15}
17
+ {"case_id": "iris_41d43e7e9a", "row": 16}
18
+ {"case_id": "iris_8f66265fc7", "row": 17}
19
+ {"case_id": "iris_d3945eb482", "row": 18}
20
+ {"case_id": "iris_4b1e78fdc5", "row": 19}
21
+ {"case_id": "iris_401ec0e4cd", "row": 20}
22
+ {"case_id": "iris_408e8d870d", "row": 21}
23
+ {"case_id": "iris_9cafb2d428", "row": 22}
24
+ {"case_id": "iris_e50ca52202", "row": 23}
25
+ {"case_id": "iris_2babca4f93", "row": 24}
26
+ {"case_id": "iris_306decaed1", "row": 25}
27
+ {"case_id": "iris_8925772bb8", "row": 26}
28
+ {"case_id": "iris_16f2c8a614", "row": 27}
29
+ {"case_id": "iris_affabb42bd", "row": 28}
30
+ {"case_id": "iris_cd147f78d3", "row": 29}
31
+ {"case_id": "iris_60ceafb3b7", "row": 30}
32
+ {"case_id": "iris_971bb14551", "row": 31}
33
+ {"case_id": "iris_3c46aadfa8", "row": 32}
34
+ {"case_id": "iris_8949d2093a", "row": 33}
35
+ {"case_id": "iris_54db69a5ef", "row": 34}
36
+ {"case_id": "iris_553603a759", "row": 35}
37
+ {"case_id": "iris_1fbd72f69e", "row": 36}
38
+ {"case_id": "iris_1aa1718647", "row": 37}
39
+ {"case_id": "iris_4e47b9e277", "row": 38}
40
+ {"case_id": "iris_0b3fb6e054", "row": 39}
41
+ {"case_id": "iris_afb9f3ce89", "row": 40}
42
+ {"case_id": "iris_d964678b78", "row": 41}
43
+ {"case_id": "iris_d5afa1ffc3", "row": 42}
44
+ {"case_id": "iris_8d176d6739", "row": 43}
45
+ {"case_id": "iris_b3b9231f82", "row": 44}
46
+ {"case_id": "iris_948f3351ef", "row": 45}
47
+ {"case_id": "iris_cf7d9336af", "row": 46}
48
+ {"case_id": "iris_1d9428989e", "row": 47}
49
+ {"case_id": "iris_ca1177d767", "row": 48}
50
+ {"case_id": "iris_7435ef9308", "row": 49}
51
+ {"case_id": "iris_187546a192", "row": 50}
52
+ {"case_id": "iris_f67c61b994", "row": 51}
53
+ {"case_id": "iris_12ca8c3bc8", "row": 52}
54
+ {"case_id": "iris_e883f0a96b", "row": 53}
55
+ {"case_id": "iris_5d30ef01ab", "row": 54}
56
+ {"case_id": "iris_06713bd1b2", "row": 55}
57
+ {"case_id": "iris_cdea50b849", "row": 56}
58
+ {"case_id": "iris_9a7d15fcb5", "row": 57}
59
+ {"case_id": "iris_aa4ec334d1", "row": 58}
60
+ {"case_id": "iris_1753b1a603", "row": 59}
61
+ {"case_id": "iris_bd16db5e4c", "row": 60}
62
+ {"case_id": "iris_45e9c6b8be", "row": 61}
63
+ {"case_id": "iris_90355b0853", "row": 62}
64
+ {"case_id": "iris_29f5ab1fcc", "row": 63}
65
+ {"case_id": "iris_ba49dde13f", "row": 64}
66
+ {"case_id": "iris_938819d7e3", "row": 65}
67
+ {"case_id": "iris_ced4f5a163", "row": 66}
68
+ {"case_id": "iris_a0555b0006", "row": 67}
69
+ {"case_id": "iris_245849f78c", "row": 68}
70
+ {"case_id": "iris_0315cdedea", "row": 69}
71
+ {"case_id": "iris_678b362b66", "row": 70}
72
+ {"case_id": "iris_495ee2afb0", "row": 71}
73
+ {"case_id": "iris_ab99322692", "row": 72}
74
+ {"case_id": "iris_afb9f3ce89", "row": 73}
75
+ {"case_id": "iris_f873cbc152", "row": 74}
76
+ {"case_id": "iris_63c413d1a9", "row": 75}
77
+ {"case_id": "iris_42ca7166cd", "row": 76}
78
+ {"case_id": "iris_31d4a40847", "row": 77}
79
+ {"case_id": "iris_d458f158e0", "row": 78}
80
+ {"case_id": "iris_373dfcf880", "row": 79}
81
+ {"case_id": "iris_2d037a30fc", "row": 80}
82
+ {"case_id": "iris_3954a347b6", "row": 81}
83
+ {"case_id": "iris_7c437b3319", "row": 82}
84
+ {"case_id": "iris_519f77cbe0", "row": 83}
85
+ {"case_id": "iris_01fe6fa830", "row": 84}
86
+ {"case_id": "iris_d0b253b7b8", "row": 85}
87
+ {"case_id": "iris_0a27e6c142", "row": 86}
88
+ {"case_id": "iris_46a04f9bf1", "row": 87}
89
+ {"case_id": "iris_ef474c84d7", "row": 88}
90
+ {"case_id": "iris_07a988927f", "row": 89}
91
+ {"case_id": "iris_c93d7a8c57", "row": 90}
92
+ {"case_id": "iris_3af34eed28", "row": 91}
93
+ {"case_id": "iris_be68e3ed79", "row": 92}
94
+ {"case_id": "iris_55c1ce18c8", "row": 93}
95
+ {"case_id": "iris_0393b6cfa4", "row": 94}
96
+ {"case_id": "iris_70d7d9f959", "row": 95}
97
+ {"case_id": "iris_da32c2c5cb", "row": 96}
98
+ {"case_id": "iris_44888f11a6", "row": 97}
99
+ {"case_id": "iris_feaabdd51f", "row": 98}
100
+ {"case_id": "iris_c6b9d16895", "row": 99}
data/base_indices/iris_global/meta.jsonl ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"iris_00bbac633c": {"y_pred": 1.0}}
2
+ {"iris_b2a2c274fa": {"y_pred": 0.0}}
3
+ {"iris_f79fee902c": {"y_pred": 2.0}}
4
+ {"iris_9a1a194bc4": {"y_pred": 1.0}}
5
+ {"iris_d5f6d63eb7": {"y_pred": 1.0}}
6
+ {"iris_fb322465b8": {"y_pred": 0.0}}
7
+ {"iris_751758f9a1": {"y_pred": 1.0}}
8
+ {"iris_2a967bb0c8": {"y_pred": 2.0}}
9
+ {"iris_e7d76ce04e": {"y_pred": 1.0}}
10
+ {"iris_7be756b6c3": {"y_pred": 1.0}}
11
+ {"iris_2a8b6920ad": {"y_pred": 2.0}}
12
+ {"iris_95f30553cb": {"y_pred": 0.0}}
13
+ {"iris_190a3c83bf": {"y_pred": 0.0}}
14
+ {"iris_dddf43f88f": {"y_pred": 0.0}}
15
+ {"iris_be9c18be01": {"y_pred": 0.0}}
16
+ {"iris_0f92794e2b": {"y_pred": 1.0}}
17
+ {"iris_41d43e7e9a": {"y_pred": 2.0}}
18
+ {"iris_8f66265fc7": {"y_pred": 1.0}}
19
+ {"iris_d3945eb482": {"y_pred": 1.0}}
20
+ {"iris_4b1e78fdc5": {"y_pred": 2.0}}
21
+ {"iris_401ec0e4cd": {"y_pred": 0.0}}
22
+ {"iris_408e8d870d": {"y_pred": 2.0}}
23
+ {"iris_9cafb2d428": {"y_pred": 0.0}}
24
+ {"iris_e50ca52202": {"y_pred": 2.0}}
25
+ {"iris_2babca4f93": {"y_pred": 2.0}}
26
+ {"iris_306decaed1": {"y_pred": 2.0}}
27
+ {"iris_8925772bb8": {"y_pred": 2.0}}
28
+ {"iris_16f2c8a614": {"y_pred": 2.0}}
29
+ {"iris_affabb42bd": {"y_pred": 0.0}}
30
+ {"iris_cd147f78d3": {"y_pred": 0.0}}
31
+ {"iris_60ceafb3b7": {"y_pred": 0.0}}
32
+ {"iris_971bb14551": {"y_pred": 0.0}}
33
+ {"iris_3c46aadfa8": {"y_pred": 1.0}}
34
+ {"iris_8949d2093a": {"y_pred": 0.0}}
35
+ {"iris_54db69a5ef": {"y_pred": 0.0}}
36
+ {"iris_553603a759": {"y_pred": 2.0}}
37
+ {"iris_1fbd72f69e": {"y_pred": 1.0}}
38
+ {"iris_1aa1718647": {"y_pred": 0.0}}
39
+ {"iris_4e47b9e277": {"y_pred": 0.0}}
40
+ {"iris_0b3fb6e054": {"y_pred": 0.0}}
41
+ {"iris_afb9f3ce89": {"y_pred": 2.0}}
42
+ {"iris_d964678b78": {"y_pred": 1.0}}
43
+ {"iris_d5afa1ffc3": {"y_pred": 1.0}}
44
+ {"iris_8d176d6739": {"y_pred": 0.0}}
45
+ {"iris_b3b9231f82": {"y_pred": 0.0}}
46
+ {"iris_948f3351ef": {"y_pred": 1.0}}
47
+ {"iris_cf7d9336af": {"y_pred": 2.0}}
48
+ {"iris_1d9428989e": {"y_pred": 2.0}}
49
+ {"iris_ca1177d767": {"y_pred": 1.0}}
50
+ {"iris_7435ef9308": {"y_pred": 2.0}}
51
+ {"iris_187546a192": {"y_pred": 1.0}}
52
+ {"iris_f67c61b994": {"y_pred": 2.0}}
53
+ {"iris_12ca8c3bc8": {"y_pred": 1.0}}
54
+ {"iris_e883f0a96b": {"y_pred": 0.0}}
55
+ {"iris_5d30ef01ab": {"y_pred": 2.0}}
56
+ {"iris_06713bd1b2": {"y_pred": 1.0}}
57
+ {"iris_cdea50b849": {"y_pred": 0.0}}
58
+ {"iris_9a7d15fcb5": {"y_pred": 0.0}}
59
+ {"iris_aa4ec334d1": {"y_pred": 0.0}}
60
+ {"iris_1753b1a603": {"y_pred": 1.0}}
61
+ {"iris_bd16db5e4c": {"y_pred": 2.0}}
62
+ {"iris_45e9c6b8be": {"y_pred": 0.0}}
63
+ {"iris_90355b0853": {"y_pred": 0.0}}
64
+ {"iris_29f5ab1fcc": {"y_pred": 0.0}}
65
+ {"iris_ba49dde13f": {"y_pred": 1.0}}
66
+ {"iris_938819d7e3": {"y_pred": 0.0}}
67
+ {"iris_ced4f5a163": {"y_pred": 1.0}}
68
+ {"iris_a0555b0006": {"y_pred": 2.0}}
69
+ {"iris_245849f78c": {"y_pred": 0.0}}
70
+ {"iris_0315cdedea": {"y_pred": 1.0}}
71
+ {"iris_678b362b66": {"y_pred": 2.0}}
72
+ {"iris_495ee2afb0": {"y_pred": 0.0}}
73
+ {"iris_ab99322692": {"y_pred": 2.0}}
74
+ {"iris_afb9f3ce89": {"y_pred": 2.0}}
75
+ {"iris_f873cbc152": {"y_pred": 1.0}}
76
+ {"iris_63c413d1a9": {"y_pred": 1.0}}
77
+ {"iris_42ca7166cd": {"y_pred": 2.0}}
78
+ {"iris_31d4a40847": {"y_pred": 1.0}}
79
+ {"iris_d458f158e0": {"y_pred": 0.0}}
80
+ {"iris_373dfcf880": {"y_pred": 1.0}}
81
+ {"iris_2d037a30fc": {"y_pred": 2.0}}
82
+ {"iris_3954a347b6": {"y_pred": 0.0}}
83
+ {"iris_7c437b3319": {"y_pred": 0.0}}
84
+ {"iris_519f77cbe0": {"y_pred": 1.0}}
85
+ {"iris_01fe6fa830": {"y_pred": 1.0}}
86
+ {"iris_d0b253b7b8": {"y_pred": 0.0}}
87
+ {"iris_0a27e6c142": {"y_pred": 2.0}}
88
+ {"iris_46a04f9bf1": {"y_pred": 0.0}}
89
+ {"iris_ef474c84d7": {"y_pred": 0.0}}
90
+ {"iris_07a988927f": {"y_pred": 1.0}}
91
+ {"iris_c93d7a8c57": {"y_pred": 1.0}}
92
+ {"iris_3af34eed28": {"y_pred": 2.0}}
93
+ {"iris_be68e3ed79": {"y_pred": 1.0}}
94
+ {"iris_55c1ce18c8": {"y_pred": 2.0}}
95
+ {"iris_0393b6cfa4": {"y_pred": 2.0}}
96
+ {"iris_70d7d9f959": {"y_pred": 1.0}}
97
+ {"iris_da32c2c5cb": {"y_pred": 0.0}}
98
+ {"iris_44888f11a6": {"y_pred": 0.0}}
99
+ {"iris_feaabdd51f": {"y_pred": 2.0}}
100
+ {"iris_c6b9d16895": {"y_pred": 2.0}}
data/base_indices/iris_global/shap.npy ADDED
Binary file (1.73 kB). View file
 
model_data/data.csv ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
2
+ 5.1,3.5,1.4,0.2,0
3
+ 4.9,3.0,1.4,0.2,0
4
+ 4.7,3.2,1.3,0.2,0
5
+ 4.6,3.1,1.5,0.2,0
6
+ 5.0,3.6,1.4,0.2,0
7
+ 5.4,3.9,1.7,0.4,0
8
+ 4.6,3.4,1.4,0.3,0
9
+ 5.0,3.4,1.5,0.2,0
10
+ 4.4,2.9,1.4,0.2,0
11
+ 4.9,3.1,1.5,0.1,0
12
+ 5.4,3.7,1.5,0.2,0
13
+ 4.8,3.4,1.6,0.2,0
14
+ 4.8,3.0,1.4,0.1,0
15
+ 4.3,3.0,1.1,0.1,0
16
+ 5.8,4.0,1.2,0.2,0
17
+ 5.7,4.4,1.5,0.4,0
18
+ 5.4,3.9,1.3,0.4,0
19
+ 5.1,3.5,1.4,0.3,0
20
+ 5.7,3.8,1.7,0.3,0
21
+ 5.1,3.8,1.5,0.3,0
22
+ 5.4,3.4,1.7,0.2,0
23
+ 5.1,3.7,1.5,0.4,0
24
+ 4.6,3.6,1.0,0.2,0
25
+ 5.1,3.3,1.7,0.5,0
26
+ 4.8,3.4,1.9,0.2,0
27
+ 5.0,3.0,1.6,0.2,0
28
+ 5.0,3.4,1.6,0.4,0
29
+ 5.2,3.5,1.5,0.2,0
30
+ 5.2,3.4,1.4,0.2,0
31
+ 4.7,3.2,1.6,0.2,0
32
+ 4.8,3.1,1.6,0.2,0
33
+ 5.4,3.4,1.5,0.4,0
34
+ 5.2,4.1,1.5,0.1,0
35
+ 5.5,4.2,1.4,0.2,0
36
+ 4.9,3.1,1.5,0.2,0
37
+ 5.0,3.2,1.2,0.2,0
38
+ 5.5,3.5,1.3,0.2,0
39
+ 4.9,3.6,1.4,0.1,0
40
+ 4.4,3.0,1.3,0.2,0
41
+ 5.1,3.4,1.5,0.2,0
42
+ 5.0,3.5,1.3,0.3,0
43
+ 4.5,2.3,1.3,0.3,0
44
+ 4.4,3.2,1.3,0.2,0
45
+ 5.0,3.5,1.6,0.6,0
46
+ 5.1,3.8,1.9,0.4,0
47
+ 4.8,3.0,1.4,0.3,0
48
+ 5.1,3.8,1.6,0.2,0
49
+ 4.6,3.2,1.4,0.2,0
50
+ 5.3,3.7,1.5,0.2,0
51
+ 5.0,3.3,1.4,0.2,0
52
+ 7.0,3.2,4.7,1.4,1
53
+ 6.4,3.2,4.5,1.5,1
54
+ 6.9,3.1,4.9,1.5,1
55
+ 5.5,2.3,4.0,1.3,1
56
+ 6.5,2.8,4.6,1.5,1
57
+ 5.7,2.8,4.5,1.3,1
58
+ 6.3,3.3,4.7,1.6,1
59
+ 4.9,2.4,3.3,1.0,1
60
+ 6.6,2.9,4.6,1.3,1
61
+ 5.2,2.7,3.9,1.4,1
62
+ 5.0,2.0,3.5,1.0,1
63
+ 5.9,3.0,4.2,1.5,1
64
+ 6.0,2.2,4.0,1.0,1
65
+ 6.1,2.9,4.7,1.4,1
66
+ 5.6,2.9,3.6,1.3,1
67
+ 6.7,3.1,4.4,1.4,1
68
+ 5.6,3.0,4.5,1.5,1
69
+ 5.8,2.7,4.1,1.0,1
70
+ 6.2,2.2,4.5,1.5,1
71
+ 5.6,2.5,3.9,1.1,1
72
+ 5.9,3.2,4.8,1.8,1
73
+ 6.1,2.8,4.0,1.3,1
74
+ 6.3,2.5,4.9,1.5,1
75
+ 6.1,2.8,4.7,1.2,1
76
+ 6.4,2.9,4.3,1.3,1
77
+ 6.6,3.0,4.4,1.4,1
78
+ 6.8,2.8,4.8,1.4,1
79
+ 6.7,3.0,5.0,1.7,1
80
+ 6.0,2.9,4.5,1.5,1
81
+ 5.7,2.6,3.5,1.0,1
82
+ 5.5,2.4,3.8,1.1,1
83
+ 5.5,2.4,3.7,1.0,1
84
+ 5.8,2.7,3.9,1.2,1
85
+ 6.0,2.7,5.1,1.6,1
86
+ 5.4,3.0,4.5,1.5,1
87
+ 6.0,3.4,4.5,1.6,1
88
+ 6.7,3.1,4.7,1.5,1
89
+ 6.3,2.3,4.4,1.3,1
90
+ 5.6,3.0,4.1,1.3,1
91
+ 5.5,2.5,4.0,1.3,1
92
+ 5.5,2.6,4.4,1.2,1
93
+ 6.1,3.0,4.6,1.4,1
94
+ 5.8,2.6,4.0,1.2,1
95
+ 5.0,2.3,3.3,1.0,1
96
+ 5.6,2.7,4.2,1.3,1
97
+ 5.7,3.0,4.2,1.2,1
98
+ 5.7,2.9,4.2,1.3,1
99
+ 6.2,2.9,4.3,1.3,1
100
+ 5.1,2.5,3.0,1.1,1
101
+ 5.7,2.8,4.1,1.3,1
102
+ 6.3,3.3,6.0,2.5,2
103
+ 5.8,2.7,5.1,1.9,2
104
+ 7.1,3.0,5.9,2.1,2
105
+ 6.3,2.9,5.6,1.8,2
106
+ 6.5,3.0,5.8,2.2,2
107
+ 7.6,3.0,6.6,2.1,2
108
+ 4.9,2.5,4.5,1.7,2
109
+ 7.3,2.9,6.3,1.8,2
110
+ 6.7,2.5,5.8,1.8,2
111
+ 7.2,3.6,6.1,2.5,2
112
+ 6.5,3.2,5.1,2.0,2
113
+ 6.4,2.7,5.3,1.9,2
114
+ 6.8,3.0,5.5,2.1,2
115
+ 5.7,2.5,5.0,2.0,2
116
+ 5.8,2.8,5.1,2.4,2
117
+ 6.4,3.2,5.3,2.3,2
118
+ 6.5,3.0,5.5,1.8,2
119
+ 7.7,3.8,6.7,2.2,2
120
+ 7.7,2.6,6.9,2.3,2
121
+ 6.0,2.2,5.0,1.5,2
122
+ 6.9,3.2,5.7,2.3,2
123
+ 5.6,2.8,4.9,2.0,2
124
+ 7.7,2.8,6.7,2.0,2
125
+ 6.3,2.7,4.9,1.8,2
126
+ 6.7,3.3,5.7,2.1,2
127
+ 7.2,3.2,6.0,1.8,2
128
+ 6.2,2.8,4.8,1.8,2
129
+ 6.1,3.0,4.9,1.8,2
130
+ 6.4,2.8,5.6,2.1,2
131
+ 7.2,3.0,5.8,1.6,2
132
+ 7.4,2.8,6.1,1.9,2
133
+ 7.9,3.8,6.4,2.0,2
134
+ 6.4,2.8,5.6,2.2,2
135
+ 6.3,2.8,5.1,1.5,2
136
+ 6.1,2.6,5.6,1.4,2
137
+ 7.7,3.0,6.1,2.3,2
138
+ 6.3,3.4,5.6,2.4,2
139
+ 6.4,3.1,5.5,1.8,2
140
+ 6.0,3.0,4.8,1.8,2
141
+ 6.9,3.1,5.4,2.1,2
142
+ 6.7,3.1,5.6,2.4,2
143
+ 6.9,3.1,5.1,2.3,2
144
+ 5.8,2.7,5.1,1.9,2
145
+ 6.8,3.2,5.9,2.3,2
146
+ 6.7,3.3,5.7,2.5,2
147
+ 6.7,3.0,5.2,2.3,2
148
+ 6.3,2.5,5.0,1.9,2
149
+ 6.5,3.0,5.2,2.0,2
150
+ 6.2,3.4,5.4,2.3,2
151
+ 5.9,3.0,5.1,1.8,2
model_data/model.pkl ADDED
Binary file (94.9 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.26.4
2
+ pandas==2.2.2
3
+ scikit-learn==1.4.2
4
+ shap==0.45.0
5
+ fastapi==0.115.0
6
+ uvicorn==0.30.6
7
+ python-multipart==0.0.9
8
+ streamlit==1.39.0
9
+ pydantic==2.9.2
10
+ joblib==1.4.2
11
+ matplotlib==3.9.2
scripts/__pycache__/build_base_index.cpython-313.pyc ADDED
Binary file (3.03 kB). View file
 
scripts/add_user_model.py ADDED
File without changes
scripts/build_base_index.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Precompute a 'global' reasoning space from a baseline model + dataset.
3
+
4
+ Usage:
5
+ python scripts/build_base_index.py \
6
+ --model_path path/to/model.pkl \
7
+ --csv path/to/data.csv \
8
+ --features col1,col2,col3 \
9
+ --target target_col \
10
+ --namespace data/base_indices/recidivism_global \
11
+ --sample 2000
12
+ """
13
+ # Query_Your_Model/scripts/build_base_index.py
14
+ import sys, os
15
+ #sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
16
+
17
+
18
+ import os
19
+ import pandas as pd
20
+ import numpy as np
21
+ from ..core.model_loader import load_model, predict
22
+ from ..core.explain import explain_instance
23
+ from ..core.storage import ensure_dir, init_matrix_files, append_case
24
+ from ..core.utils import case_id_from_vector
25
+
26
+
27
+ # Hardcoded defaults for Iris demo
28
+ MODEL_PATH = "Query_Your_Model/model_data/model.pkl"
29
+ CSV_PATH = "Query_Your_Model/model_data/data.csv"
30
+ FEATURES = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
31
+ TARGET = "target"
32
+ NAMESPACE = "Query_Your_Model/data/base_indices/iris_global"
33
+ SAMPLE = 100 # how many rows to sample
34
+
35
def main():
    """Build the demo Iris reasoning index from the hardcoded paths above."""
    print("Building reasoning index...")
    df = pd.read_csv(CSV_PATH)
    if SAMPLE and SAMPLE < len(df):
        # Deterministic subsample so repeated builds produce the same index.
        df = df.sample(SAMPLE, random_state=42)

    # NOTE(review): X is computed but never used below.
    X = df[FEATURES].values
    model = load_model(MODEL_PATH)

    ensure_dir(NAMESPACE)
    init_matrix_files(NAMESPACE, feature_dim=len(FEATURES), shap_dim=len(FEATURES))


    # Shared SHAP background sample (at most 100 rows, fixed seed).
    bg = df[FEATURES].sample(min(100, len(df)), random_state=0).values.astype("float32")

    # One append_case call per row: predict, explain, derive id, store.
    for i, row in df.iterrows():
        x = row[FEATURES].values.astype("float32")
        y_pred, _ = predict(model, x.reshape(1, -1))
        exp = explain_instance(model, x, FEATURES, background_X=bg, top_k=8)
        shap_vec = np.array(exp["shap_values"], dtype="float32")
        # Case ids hash the raw feature bytes, so duplicate rows yield
        # duplicate ids — NOTE(review): the shipped index contains one
        # repeated id; confirm downstream code tolerates collisions.
        cid = case_id_from_vector(x, prefix="iris")
        meta = {"y_pred": float(y_pred[0])}
        append_case(NAMESPACE, cid, x, shap_vec, meta)

    print(f"Done! Index saved to {NAMESPACE}")

if __name__ == "__main__":
    main()
scripts/build_iris.bat ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+ echo Building Iris reasoning index...
3
+
4
+ python Query_Your_Model/scripts/build_base_index.py ^
5
+ --model_path Query_Your_Model/model_data/model.pkl ^
6
+ --csv Query_Your_Model/model_data/data.csv ^
7
+ --features "sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)" ^
8
+ --target target ^
9
+ --namespace Query_Your_Model/data/base_indices/iris_global ^
10
+ --sample 100
11
+
12
+ echo Done! Index saved to Query_Your_Model/data/base_indices/iris_global
13
+ pause
scripts/demo_predict.py ADDED
File without changes
tests/__pycache__/test_similarity.cpython-313.pyc ADDED
Binary file (1.48 kB). View file
 
tests/test_similarity.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Query_Your_Model/tests/test_similarity.py
2
+ import numpy as np
3
+ from Query_Your_Model.core.retrieval import combined_similarity
4
+
5
+
6
def test_combined_similarity_basic():
    """Sanity-check combined_similarity on identical vs orthogonal pairs."""
    a = np.array([1, 0, 0], dtype="float32")  # feature vector 1
    b = np.array([0, 1, 0], dtype="float32")  # feature vector 2 (orthogonal)
    shap_a = np.array([0.5, 0.2, 0.1], dtype="float32")  # shap for a
    shap_b = np.array([-0.5, 0.0, 0.0], dtype="float32")  # shap for b

    # Similarity of identical pair (a,a)
    # NOTE(review): arguments are passed as (features, shap, ...) while the
    # signature is (shap_q, feat_q, shap_i, feat_i); at alpha=0.5 both
    # orderings produce the same value, but confirm the intended order.
    s1 = combined_similarity(a, shap_a, a, shap_a, alpha=0.5)
    print(f"Similarity (identical): {s1:.4f}")
    assert s1 > 0.99, "Expected similarity close to 1 for identical vectors"

    # Similarity of different pair (a,b)
    s2 = combined_similarity(a, shap_a, b, shap_b, alpha=0.5)
    print(f"Similarity (different/orthogonal): {s2:.4f}")
    assert s2 < s1, "Expected orthogonal similarity to be smaller"

    # NOTE(review): returning a value from a pytest test triggers a warning
    # in recent pytest; kept because the __main__ block below consumes it.
    return s1, s2


if __name__ == "__main__":
    print("Running combined_similarity tests...\n")
    s1, s2 = test_combined_similarity_basic()
    print("\n Test passed!")