Spaces:
Build error
Build error
Tiffany Degbotse committed on
Commit ·
2ae10e0
1
Parent(s): 6e3858e
query with your model
Browse files- app/__pycache__/api_fastapi.cpython-313.pyc +0 -0
- app/api_fastapi.py +101 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-313.pyc +0 -0
- core/__pycache__/explain.cpython-313.pyc +0 -0
- core/__pycache__/model_loader.cpython-313.pyc +0 -0
- core/__pycache__/retrieval.cpython-313.pyc +0 -0
- core/__pycache__/schemas.cpython-313.pyc +0 -0
- core/__pycache__/storage.cpython-313.pyc +0 -0
- core/__pycache__/utils.cpython-313.pyc +0 -0
- core/explain.py +106 -0
- core/model_loader.py +19 -0
- core/retrieval.py +77 -0
- core/schemas.py +37 -0
- core/storage.py +95 -0
- core/utils.py +20 -0
- data/base_indices/iris_global/features.npy +0 -0
- data/base_indices/iris_global/index.jsonl +100 -0
- data/base_indices/iris_global/meta.jsonl +100 -0
- data/base_indices/iris_global/shap.npy +0 -0
- model_data/data.csv +151 -0
- model_data/model.pkl +0 -0
- requirements.txt +11 -0
- scripts/__pycache__/build_base_index.cpython-313.pyc +0 -0
- scripts/add_user_model.py +0 -0
- scripts/build_base_index.py +62 -0
- scripts/build_iris.bat +13 -0
- scripts/demo_predict.py +0 -0
- tests/__pycache__/test_similarity.cpython-313.pyc +0 -0
- tests/test_similarity.py +28 -0
app/__pycache__/api_fastapi.cpython-313.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
app/api_fastapi.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
from typing import Optional, List
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from Query_Your_Model.core.schemas import RetrievalConfig, ExplainResponse
|
| 7 |
+
from Query_Your_Model.core.model_loader import load_model
|
| 8 |
+
from Query_Your_Model.core.explain import explain_instance
|
| 9 |
+
from Query_Your_Model.core.retrieval import retrieve_topk
|
| 10 |
+
from Query_Your_Model.core.utils import safe_proba_to_scalar
|
| 11 |
+
|
# FastAPI application serving the SHAP-explanation endpoint.
app = FastAPI(title="Reasoning-RAG XAI API")

# Cached globals, reused across requests by explain(): the loaded model,
# the feature-name schema it was loaded for, and optional SHAP background data.
MODEL = None
FEATURE_NAMES: Optional[List[str]] = None
BACKGROUND = None
# Default retrieval namespace directory (a request may override it).
NAMESPACE = "Query_Your_Model/data/base_indices/iris_global"

# --- Target name mappings (extend per dataset/model) ---
# Maps a dataset key to human-readable class labels, ordered by class index.
TARGET_NAMES = {
    "iris": ["setosa", "versicolor", "virginica"],
    # add more datasets here if needed
}
| 25 |
+
|
| 26 |
+
|
class ExplainRequest(BaseModel):
    """Request payload for POST /explain.

    Attributes:
        model_path: Filesystem path to the pickled model to load.
        feature_names: Ordered feature names matching ``features``.
        features: Ordered feature vector for one instance.
        namespace: Optional retrieval namespace directory; defaults to NAMESPACE.
        retrieval: Optional retrieval configuration; retrieval is skipped when absent.
        background_path: Optional path to SHAP background data
            (not read by explain() in this module — presumably for future use).
    """
    model_path: str
    feature_names: List[str]
    features: List[float]
    namespace: Optional[str] = None
    retrieval: Optional[RetrievalConfig] = None
    background_path: Optional[str] = None
| 34 |
+
|
| 35 |
+
|
@app.post("/explain", response_model=ExplainResponse)
def explain(req: ExplainRequest):
    """Predict, explain via SHAP, and retrieve similar cases for one instance.

    The model is cached at module level and reloaded whenever the requested
    model path or feature schema changes.
    """
    global MODEL, FEATURE_NAMES, BACKGROUND, _MODEL_PATH

    # Reload when the model path OR the feature schema changes. The original
    # compared only feature names, which kept serving a stale model when a
    # different model_path arrived with the same feature names.
    if (
        MODEL is None
        or FEATURE_NAMES != req.feature_names
        or globals().get("_MODEL_PATH") != req.model_path
    ):
        MODEL = load_model(req.model_path)
        FEATURE_NAMES = req.feature_names
        _MODEL_PATH = req.model_path
        BACKGROUND = None  # optionally load background data

    # Convert input features to the (1, n_features) shape models expect.
    x = np.asarray(req.features, dtype="float32").reshape(1, -1)

    # Prediction & probability (best-effort: failures fall back to class 0).
    y_class = 0
    proba_scalar = None
    try:
        y_pred = MODEL.predict(x)
        y_class = int(y_pred[0])

        if hasattr(MODEL, "predict_proba"):
            proba = MODEL.predict_proba(x)
            proba_scalar = float(proba[0][y_class])
    except Exception as e:
        print("Prediction error:", e)

    # --- Map class ID -> human-readable label ---
    model_key = "iris" if "iris" in req.model_path.lower() else None
    if model_key and model_key in TARGET_NAMES:
        y_label = TARGET_NAMES[model_key][y_class]
    else:
        y_label = str(y_class)

    # SHAP explanation (the instance itself serves as background when none cached).
    exp = explain_instance(
        MODEL,
        x[0],
        FEATURE_NAMES,
        background_X=(BACKGROUND if BACKGROUND is not None else x),
    )

    # Retrieval of similar cases
    similar = None
    ns = req.namespace or NAMESPACE
    if req.retrieval and req.retrieval.use_retrieval:
        shap_q = np.array(exp["shap_values"], dtype="float32")
        similar = retrieve_topk(ns, shap_q, x[0], alpha=req.retrieval.alpha, k=req.retrieval.k)

        # Also map labels for retrieved cases; guard against an empty result
        # (the original iterated unconditionally and could hit a falsy value).
        if model_key and model_key in TARGET_NAMES and similar:
            for case in similar:
                if case.get("y_pred") is not None:
                    try:
                        case["y_pred"] = TARGET_NAMES[model_key][int(case["y_pred"])]
                    except Exception:
                        case["y_pred"] = str(case["y_pred"])

    return ExplainResponse(
        prediction={
            # NOTE(review): y_label is a str ("setosa", ...); PredictionResult
            # must accept str as well as float — verify schema agreement.
            "y_pred": y_label,
            "proba": proba_scalar,
        },
        explanation=exp,
        similar_cases=similar or [],
        ood_flag=False
    )
core/__init__.py
ADDED
|
File without changes
|
core/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (183 Bytes). View file
|
|
|
core/__pycache__/explain.cpython-313.pyc
ADDED
|
Binary file (4.5 kB). View file
|
|
|
core/__pycache__/model_loader.cpython-313.pyc
ADDED
|
Binary file (1.24 kB). View file
|
|
|
core/__pycache__/retrieval.cpython-313.pyc
ADDED
|
Binary file (3.52 kB). View file
|
|
|
core/__pycache__/schemas.cpython-313.pyc
ADDED
|
Binary file (2.6 kB). View file
|
|
|
core/__pycache__/storage.cpython-313.pyc
ADDED
|
Binary file (6.2 kB). View file
|
|
|
core/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
core/explain.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any
|
| 2 |
+
import numpy as np
|
| 3 |
+
import shap
|
| 4 |
+
|
| 5 |
+
|
def _pick_explainer(model, X_background: np.ndarray):
    """Select a SHAP explainer suited to *model*.

    Preference order: TreeExplainer for tree ensembles, LinearExplainer for
    linear models, KernelExplainer (slow but model-agnostic) otherwise.
    """
    model_name = type(model).__name__.lower()

    # XGBoost detection needs the optional xgboost import to be guarded.
    try:
        import xgboost  # noqa: F401
        tree_like = hasattr(model, "get_booster") or "xgb" in model_name
    except Exception:
        tree_like = False

    # Other common tree-ensemble families, recognised by class name.
    tree_markers = ("randomforest", "gradientboost", "gbm", "lightgbm", "catboost")
    tree_like = tree_like or any(marker in model_name for marker in tree_markers)

    if tree_like:
        return shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")

    # Linear models: named "linear" or exposing fitted coefficients.
    if "linear" in model_name or hasattr(model, "coef_"):
        return shap.LinearExplainer(model, X_background)

    # Fallback for anything else
    return shap.KernelExplainer(model.predict, X_background)
| 33 |
+
|
| 34 |
+
|
def explain_instance(
    model,
    x: np.ndarray,
    feature_names: List[str],
    background_X: np.ndarray,
    top_k: int = 8,
) -> Dict[str, Any]:
    """
    Compute SHAP values for a single instance x (shape: (n_features,)).

    Always reduces the SHAP output to a vector of length n_features,
    averaging across classes for multiclass models.

    Returns a dict with keys "shap_values", "base_value" and "topk".
    """
    x = x.reshape(1, -1)
    explainer = _pick_explainer(model, background_X)

    values = explainer.shap_values(x)

    # SHAP's output shape varies with version and model type:
    #   - list of per-class arrays            (older multiclass API)
    #   - (n_samples, n_features, n_classes)  (newer multiclass API)
    #   - (n_samples, n_features)             (binary / regression)
    if isinstance(values, list):  # multiclass -> list of arrays
        # stack to (n_classes, n_samples, n_features), average across classes
        values_arr = np.mean(np.stack(values, axis=0), axis=0)
    else:
        values_arr = np.asarray(values)
        if values_arr.ndim == 3:
            # Average over the trailing class axis. Previously the 3-D case
            # fell through to a flat truncation, keeping a meaningless prefix
            # of features*classes values instead of per-feature attributions.
            values_arr = values_arr.mean(axis=2)

    # Always flatten to a 1-D float vector
    shap_vec = np.asarray(values_arr[0], dtype=float).reshape(-1)

    # Last-resort guard if the length still disagrees with feature_names
    n_features = len(feature_names)
    if len(shap_vec) != n_features:
        shap_vec = shap_vec[:n_features]

    base_value = explainer.expected_value
    if isinstance(base_value, (list, np.ndarray)):
        base_value = float(np.mean(base_value))

    # Top-k features by absolute impact
    abs_imp = np.abs(shap_vec)
    idx = np.argsort(-abs_imp)[:top_k].ravel()

    top = []
    for i in idx:
        i = int(i)
        if i >= n_features:  # safety check
            continue
        top.append({
            "feature": feature_names[i],
            "value": float(x[0, i]),
            "shap": float(shap_vec[i]),
            "abs_impact": float(abs_imp[i]),
        })

    return {
        "shap_values": shap_vec.tolist(),
        "base_value": float(base_value),
        "topk": top,
    }
|
core/model_loader.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from typing import Any, Tuple, Optional
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
def load_model(path: str) -> Any:
    """Load and return a pickled sklearn-compatible model from *path*."""
    return joblib.load(path)
| 9 |
+
|
def predict(model: Any, X: np.ndarray) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Return (pred, proba_or_none). Handles regressors & classifiers."""
    predictions = model.predict(X)
    probabilities = None
    proba_fn = getattr(model, "predict_proba", None)
    if proba_fn is not None:
        try:
            probabilities = proba_fn(X)
        except Exception:
            # Some classifiers expose predict_proba but fail at call time
            # (e.g. not fitted for probabilities) — treat as "no proba".
            probabilities = None
    return predictions, probabilities
core/retrieval.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, List
|
| 2 |
+
import numpy as np
|
| 3 |
+
from .storage import load_matrices
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _cosine(a: np.ndarray, b: np.ndarray) -> float:
|
| 7 |
+
na = np.linalg.norm(a) + 1e-12
|
| 8 |
+
nb = np.linalg.norm(b) + 1e-12
|
| 9 |
+
return float(np.dot(a, b) / (na * nb))
|
| 10 |
+
|
| 11 |
+
|
def combined_similarity(
    shap_q: np.ndarray,
    feat_q: np.ndarray,
    shap_i: np.ndarray,
    feat_i: np.ndarray,
    alpha: float
) -> float:
    """Blend of SHAP-space and feature-space cosine similarities.

    similarity = alpha * cos(SHAP) + (1 - alpha) * cos(features)
    """
    shap_term = _cosine(shap_q, shap_i)
    feat_term = _cosine(feat_q, feat_i)
    return alpha * shap_term + (1.0 - alpha) * feat_term
| 21 |
+
|
| 22 |
+
|
def retrieve_topk(
    namespace: str,
    shap_q: np.ndarray,
    x_q: np.ndarray,
    alpha: float = 0.5,
    k: int = 5
) -> List[Dict[str, Any]]:
    """
    Retrieve the k stored cases most similar to the query.

    Returns dicts with case_id, similarity, y_pred, shap_values, features, meta.
    """
    # Stored matrices plus metadata for this namespace
    X, SHAP, metas, case_ids = load_matrices(namespace)

    # Each meta row is stored as {case_id: meta}; merge into one lookup table.
    meta_by_case: Dict[str, Dict[str, Any]] = {}
    for row in metas:
        if isinstance(row, dict):
            meta_by_case.update(row)

    scored: List[Dict[str, Any]] = []
    for row_idx, case_id in enumerate(case_ids):
        case_features = X[row_idx]
        case_shap = SHAP[row_idx]

        similarity = combined_similarity(shap_q, x_q, case_shap, case_features, alpha=alpha)
        meta = meta_by_case.get(case_id, {})  # safe fallback for missing meta

        scored.append({
            "case_id": case_id,
            "similarity": float(similarity),
            "y_pred": meta.get("y_pred"),
            "shap_values": case_shap.tolist(),
            "features": case_features.tolist(),
            "meta": meta
        })

    # Highest similarity first, keep the best k
    scored.sort(key=lambda entry: entry["similarity"], reverse=True)
    return scored[:k]
| 66 |
+
|
| 67 |
+
|
def ood_score(shap_query: np.ndarray, shaps_matrix: np.ndarray) -> float:
    """Simple OOD heuristic: 1 - max cosine similarity against corpus SHAPs.

    Returns 1.0 (maximally out-of-distribution) when the corpus is empty.
    """
    if shaps_matrix.size == 0:
        return 1.0
    # Vectorized cosine against every corpus row at once — replaces the
    # original per-row Python loop with one NumPy expression (same values).
    q = np.asarray(shap_query, dtype=float).reshape(-1)
    M = np.asarray(shaps_matrix, dtype=float)
    q_norm = np.linalg.norm(q) + 1e-12
    row_norms = np.linalg.norm(M, axis=1) + 1e-12
    cosines = (M @ q) / (row_norms * q_norm)
    return float(1.0 - cosines.max())
|
core/schemas.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel
| 3 |
+
|
class Instance(BaseModel):
    """A single model input: ordered feature vector plus its feature names."""
    features: List[float]
    feature_names: List[str]


class PredictionResult(BaseModel):
    """Model output for one instance.

    y_pred accepts either the raw numeric class/target or a human-readable
    label string — the API maps class ids to names such as "setosa", which
    the previous `float` annotation rejected at validation time.
    """
    y_pred: Union[float, str]
    proba: Optional[float] = None


class Explanation(BaseModel):
    """SHAP explanation for one instance."""
    shap_values: List[float]  # reasoning vector, one value per feature
    base_value: float
    topk: List[Dict[str, Any]]


class RetrievalConfig(BaseModel):
    """Knobs controlling similar-case retrieval."""
    alpha: float = 0.7  # weight for SHAP cosine vs feature cosine
    k: int = 5
    use_retrieval: bool = True
    namespace: str = "global_default"


class RetrievedCase(BaseModel):
    """One stored case returned by retrieval."""
    case_id: str
    similarity: float
    y_pred: Optional[Union[float, str]] = None  # numeric class or mapped label
    shap_values: Optional[List[float]] = None
    features: Optional[List[float]] = None
    meta: Optional[Dict[str, Any]] = None


class ExplainResponse(BaseModel):
    """Full response body of the /explain endpoint."""
    prediction: PredictionResult
    explanation: Explanation
    similar_cases: Optional[List[RetrievedCase]] = None
    ood_flag: bool = False
    ood_reason: Optional[str] = None
core/storage.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from typing import List, Dict, Any, Tuple
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
# Per-namespace store layout: a JSONL index mapping case_id -> matrix row,
# two dense .npy matrices (features and SHAP values, row-aligned), and a
# JSONL metadata file with one {case_id: meta} object per line.
INDEX_FILE = "index.jsonl"
FEATURE_FILE = "features.npy"
SHAP_FILE = "shap.npy"
META_FILE = "meta.jsonl"
| 10 |
+
|
| 11 |
+
|
def ensure_dir(path: str) -> None:
    """Create *path* (including parents) if it does not already exist."""
    os.makedirs(path, exist_ok=True)
| 14 |
+
|
| 15 |
+
|
def append_jsonl(path: str, row: Dict[str, Any]) -> None:
    """Append *row* to the JSONL file at *path* as one JSON-encoded line."""
    encoded = json.dumps(row)
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(encoded + "\n")
| 19 |
+
|
| 20 |
+
|
def load_index(path: str) -> List[Dict[str, Any]]:
    """Read every non-blank line of a JSONL file into a list of dicts."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(line) for line in handle if line.strip()]
| 28 |
+
|
| 29 |
+
|
def init_matrix_files(namespace_dir: str, feature_dim: int, shap_dim: int) -> None:
    """Create empty (0, dim) .npy matrices for a namespace if absent."""
    for filename, dim in ((FEATURE_FILE, feature_dim), (SHAP_FILE, shap_dim)):
        matrix_path = os.path.join(namespace_dir, filename)
        # Existing matrices are left untouched so stored cases survive re-init.
        if not os.path.exists(matrix_path):
            np.save(matrix_path, np.zeros((0, dim), dtype="float32"))
| 38 |
+
|
| 39 |
+
|
def append_case(namespace_dir: str, case_id: str, features: np.ndarray, shap_vec: np.ndarray, meta: Dict[str, Any]) -> None:
    """Append one case (features, SHAP vector, metadata) to the namespace store."""
    ensure_dir(namespace_dir)

    # Grow the row-aligned feature and SHAP matrices by one row each.
    feat_path = os.path.join(namespace_dir, FEATURE_FILE)
    shap_path = os.path.join(namespace_dir, SHAP_FILE)
    feats = np.vstack([np.load(feat_path), features.reshape(1, -1).astype("float32")])
    shaps = np.vstack([np.load(shap_path), shap_vec.reshape(1, -1).astype("float32")])
    np.save(feat_path, feats)
    np.save(shap_path, shaps)

    # Record the case_id -> matrix-row mapping and the case metadata.
    new_row = feats.shape[0] - 1
    append_jsonl(os.path.join(namespace_dir, INDEX_FILE), {"case_id": case_id, "row": new_row})
    append_jsonl(os.path.join(namespace_dir, META_FILE), {case_id: meta})
| 59 |
+
|
| 60 |
+
|
def load_matrices(namespace_dir: str) -> Tuple[np.ndarray, np.ndarray, List[Dict[str, Any]], List[str]]:
    """
    Load everything retrieval needs from a namespace directory.

    Returns:
        X (np.ndarray)     : features matrix, one row per stored case
        SHAP (np.ndarray)  : SHAP-values matrix, row-aligned with X
        metas (list[dict]) : metadata rows ({case_id: meta} dicts)
        case_ids (list[str]): case ids in matrix-row order
    """
    X = np.load(os.path.join(namespace_dir, FEATURE_FILE))
    SHAP = np.load(os.path.join(namespace_dir, SHAP_FILE))

    def _read_jsonl(path: str) -> List[Dict[str, Any]]:
        # A missing file is treated as empty (e.g. a freshly created namespace).
        if not os.path.exists(path):
            return []
        with open(path, "r", encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    metas = _read_jsonl(os.path.join(namespace_dir, META_FILE))
    index_rows = _read_jsonl(os.path.join(namespace_dir, INDEX_FILE))
    case_ids = [entry.get("case_id") for entry in index_rows]

    return X, SHAP, metas, case_ids
|
core/utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
from typing import List
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
def case_id_from_vector(x: np.ndarray, prefix: str = "case") -> str:
    """Derive a short, deterministic case id from a vector's raw bytes."""
    digest = hashlib.md5(x.tobytes()).hexdigest()
    return f"{prefix}_{digest[:10]}"
|
| 8 |
+
|
def to_numpy(lst, dtype="float32"):
    """Convert a (possibly nested) sequence to an ndarray of *dtype*."""
    return np.asarray(lst, dtype=dtype)
|
| 11 |
+
|
def safe_proba_to_scalar(proba, positive_index: int = 1):
    """Return a single probability for binary classifiers when possible.

    For the usual (n_samples, n_classes) matrix, returns the first sample's
    probability at *positive_index*; any other shape falls back to the mean.
    Returns None for None input.
    """
    if proba is None:
        return None
    arr = np.asarray(proba)
    looks_like_class_matrix = arr.ndim == 2 and arr.shape[1] >= 2
    if looks_like_class_matrix:
        return float(arr[0, positive_index])
    # fallback: average
    return float(arr.mean())
|
data/base_indices/iris_global/features.npy
ADDED
|
Binary file (1.73 kB). View file
|
|
|
data/base_indices/iris_global/index.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"case_id": "iris_00bbac633c", "row": 0}
|
| 2 |
+
{"case_id": "iris_b2a2c274fa", "row": 1}
|
| 3 |
+
{"case_id": "iris_f79fee902c", "row": 2}
|
| 4 |
+
{"case_id": "iris_9a1a194bc4", "row": 3}
|
| 5 |
+
{"case_id": "iris_d5f6d63eb7", "row": 4}
|
| 6 |
+
{"case_id": "iris_fb322465b8", "row": 5}
|
| 7 |
+
{"case_id": "iris_751758f9a1", "row": 6}
|
| 8 |
+
{"case_id": "iris_2a967bb0c8", "row": 7}
|
| 9 |
+
{"case_id": "iris_e7d76ce04e", "row": 8}
|
| 10 |
+
{"case_id": "iris_7be756b6c3", "row": 9}
|
| 11 |
+
{"case_id": "iris_2a8b6920ad", "row": 10}
|
| 12 |
+
{"case_id": "iris_95f30553cb", "row": 11}
|
| 13 |
+
{"case_id": "iris_190a3c83bf", "row": 12}
|
| 14 |
+
{"case_id": "iris_dddf43f88f", "row": 13}
|
| 15 |
+
{"case_id": "iris_be9c18be01", "row": 14}
|
| 16 |
+
{"case_id": "iris_0f92794e2b", "row": 15}
|
| 17 |
+
{"case_id": "iris_41d43e7e9a", "row": 16}
|
| 18 |
+
{"case_id": "iris_8f66265fc7", "row": 17}
|
| 19 |
+
{"case_id": "iris_d3945eb482", "row": 18}
|
| 20 |
+
{"case_id": "iris_4b1e78fdc5", "row": 19}
|
| 21 |
+
{"case_id": "iris_401ec0e4cd", "row": 20}
|
| 22 |
+
{"case_id": "iris_408e8d870d", "row": 21}
|
| 23 |
+
{"case_id": "iris_9cafb2d428", "row": 22}
|
| 24 |
+
{"case_id": "iris_e50ca52202", "row": 23}
|
| 25 |
+
{"case_id": "iris_2babca4f93", "row": 24}
|
| 26 |
+
{"case_id": "iris_306decaed1", "row": 25}
|
| 27 |
+
{"case_id": "iris_8925772bb8", "row": 26}
|
| 28 |
+
{"case_id": "iris_16f2c8a614", "row": 27}
|
| 29 |
+
{"case_id": "iris_affabb42bd", "row": 28}
|
| 30 |
+
{"case_id": "iris_cd147f78d3", "row": 29}
|
| 31 |
+
{"case_id": "iris_60ceafb3b7", "row": 30}
|
| 32 |
+
{"case_id": "iris_971bb14551", "row": 31}
|
| 33 |
+
{"case_id": "iris_3c46aadfa8", "row": 32}
|
| 34 |
+
{"case_id": "iris_8949d2093a", "row": 33}
|
| 35 |
+
{"case_id": "iris_54db69a5ef", "row": 34}
|
| 36 |
+
{"case_id": "iris_553603a759", "row": 35}
|
| 37 |
+
{"case_id": "iris_1fbd72f69e", "row": 36}
|
| 38 |
+
{"case_id": "iris_1aa1718647", "row": 37}
|
| 39 |
+
{"case_id": "iris_4e47b9e277", "row": 38}
|
| 40 |
+
{"case_id": "iris_0b3fb6e054", "row": 39}
|
| 41 |
+
{"case_id": "iris_afb9f3ce89", "row": 40}
|
| 42 |
+
{"case_id": "iris_d964678b78", "row": 41}
|
| 43 |
+
{"case_id": "iris_d5afa1ffc3", "row": 42}
|
| 44 |
+
{"case_id": "iris_8d176d6739", "row": 43}
|
| 45 |
+
{"case_id": "iris_b3b9231f82", "row": 44}
|
| 46 |
+
{"case_id": "iris_948f3351ef", "row": 45}
|
| 47 |
+
{"case_id": "iris_cf7d9336af", "row": 46}
|
| 48 |
+
{"case_id": "iris_1d9428989e", "row": 47}
|
| 49 |
+
{"case_id": "iris_ca1177d767", "row": 48}
|
| 50 |
+
{"case_id": "iris_7435ef9308", "row": 49}
|
| 51 |
+
{"case_id": "iris_187546a192", "row": 50}
|
| 52 |
+
{"case_id": "iris_f67c61b994", "row": 51}
|
| 53 |
+
{"case_id": "iris_12ca8c3bc8", "row": 52}
|
| 54 |
+
{"case_id": "iris_e883f0a96b", "row": 53}
|
| 55 |
+
{"case_id": "iris_5d30ef01ab", "row": 54}
|
| 56 |
+
{"case_id": "iris_06713bd1b2", "row": 55}
|
| 57 |
+
{"case_id": "iris_cdea50b849", "row": 56}
|
| 58 |
+
{"case_id": "iris_9a7d15fcb5", "row": 57}
|
| 59 |
+
{"case_id": "iris_aa4ec334d1", "row": 58}
|
| 60 |
+
{"case_id": "iris_1753b1a603", "row": 59}
|
| 61 |
+
{"case_id": "iris_bd16db5e4c", "row": 60}
|
| 62 |
+
{"case_id": "iris_45e9c6b8be", "row": 61}
|
| 63 |
+
{"case_id": "iris_90355b0853", "row": 62}
|
| 64 |
+
{"case_id": "iris_29f5ab1fcc", "row": 63}
|
| 65 |
+
{"case_id": "iris_ba49dde13f", "row": 64}
|
| 66 |
+
{"case_id": "iris_938819d7e3", "row": 65}
|
| 67 |
+
{"case_id": "iris_ced4f5a163", "row": 66}
|
| 68 |
+
{"case_id": "iris_a0555b0006", "row": 67}
|
| 69 |
+
{"case_id": "iris_245849f78c", "row": 68}
|
| 70 |
+
{"case_id": "iris_0315cdedea", "row": 69}
|
| 71 |
+
{"case_id": "iris_678b362b66", "row": 70}
|
| 72 |
+
{"case_id": "iris_495ee2afb0", "row": 71}
|
| 73 |
+
{"case_id": "iris_ab99322692", "row": 72}
|
| 74 |
+
{"case_id": "iris_afb9f3ce89", "row": 73}
|
| 75 |
+
{"case_id": "iris_f873cbc152", "row": 74}
|
| 76 |
+
{"case_id": "iris_63c413d1a9", "row": 75}
|
| 77 |
+
{"case_id": "iris_42ca7166cd", "row": 76}
|
| 78 |
+
{"case_id": "iris_31d4a40847", "row": 77}
|
| 79 |
+
{"case_id": "iris_d458f158e0", "row": 78}
|
| 80 |
+
{"case_id": "iris_373dfcf880", "row": 79}
|
| 81 |
+
{"case_id": "iris_2d037a30fc", "row": 80}
|
| 82 |
+
{"case_id": "iris_3954a347b6", "row": 81}
|
| 83 |
+
{"case_id": "iris_7c437b3319", "row": 82}
|
| 84 |
+
{"case_id": "iris_519f77cbe0", "row": 83}
|
| 85 |
+
{"case_id": "iris_01fe6fa830", "row": 84}
|
| 86 |
+
{"case_id": "iris_d0b253b7b8", "row": 85}
|
| 87 |
+
{"case_id": "iris_0a27e6c142", "row": 86}
|
| 88 |
+
{"case_id": "iris_46a04f9bf1", "row": 87}
|
| 89 |
+
{"case_id": "iris_ef474c84d7", "row": 88}
|
| 90 |
+
{"case_id": "iris_07a988927f", "row": 89}
|
| 91 |
+
{"case_id": "iris_c93d7a8c57", "row": 90}
|
| 92 |
+
{"case_id": "iris_3af34eed28", "row": 91}
|
| 93 |
+
{"case_id": "iris_be68e3ed79", "row": 92}
|
| 94 |
+
{"case_id": "iris_55c1ce18c8", "row": 93}
|
| 95 |
+
{"case_id": "iris_0393b6cfa4", "row": 94}
|
| 96 |
+
{"case_id": "iris_70d7d9f959", "row": 95}
|
| 97 |
+
{"case_id": "iris_da32c2c5cb", "row": 96}
|
| 98 |
+
{"case_id": "iris_44888f11a6", "row": 97}
|
| 99 |
+
{"case_id": "iris_feaabdd51f", "row": 98}
|
| 100 |
+
{"case_id": "iris_c6b9d16895", "row": 99}
|
data/base_indices/iris_global/meta.jsonl
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"iris_00bbac633c": {"y_pred": 1.0}}
|
| 2 |
+
{"iris_b2a2c274fa": {"y_pred": 0.0}}
|
| 3 |
+
{"iris_f79fee902c": {"y_pred": 2.0}}
|
| 4 |
+
{"iris_9a1a194bc4": {"y_pred": 1.0}}
|
| 5 |
+
{"iris_d5f6d63eb7": {"y_pred": 1.0}}
|
| 6 |
+
{"iris_fb322465b8": {"y_pred": 0.0}}
|
| 7 |
+
{"iris_751758f9a1": {"y_pred": 1.0}}
|
| 8 |
+
{"iris_2a967bb0c8": {"y_pred": 2.0}}
|
| 9 |
+
{"iris_e7d76ce04e": {"y_pred": 1.0}}
|
| 10 |
+
{"iris_7be756b6c3": {"y_pred": 1.0}}
|
| 11 |
+
{"iris_2a8b6920ad": {"y_pred": 2.0}}
|
| 12 |
+
{"iris_95f30553cb": {"y_pred": 0.0}}
|
| 13 |
+
{"iris_190a3c83bf": {"y_pred": 0.0}}
|
| 14 |
+
{"iris_dddf43f88f": {"y_pred": 0.0}}
|
| 15 |
+
{"iris_be9c18be01": {"y_pred": 0.0}}
|
| 16 |
+
{"iris_0f92794e2b": {"y_pred": 1.0}}
|
| 17 |
+
{"iris_41d43e7e9a": {"y_pred": 2.0}}
|
| 18 |
+
{"iris_8f66265fc7": {"y_pred": 1.0}}
|
| 19 |
+
{"iris_d3945eb482": {"y_pred": 1.0}}
|
| 20 |
+
{"iris_4b1e78fdc5": {"y_pred": 2.0}}
|
| 21 |
+
{"iris_401ec0e4cd": {"y_pred": 0.0}}
|
| 22 |
+
{"iris_408e8d870d": {"y_pred": 2.0}}
|
| 23 |
+
{"iris_9cafb2d428": {"y_pred": 0.0}}
|
| 24 |
+
{"iris_e50ca52202": {"y_pred": 2.0}}
|
| 25 |
+
{"iris_2babca4f93": {"y_pred": 2.0}}
|
| 26 |
+
{"iris_306decaed1": {"y_pred": 2.0}}
|
| 27 |
+
{"iris_8925772bb8": {"y_pred": 2.0}}
|
| 28 |
+
{"iris_16f2c8a614": {"y_pred": 2.0}}
|
| 29 |
+
{"iris_affabb42bd": {"y_pred": 0.0}}
|
| 30 |
+
{"iris_cd147f78d3": {"y_pred": 0.0}}
|
| 31 |
+
{"iris_60ceafb3b7": {"y_pred": 0.0}}
|
| 32 |
+
{"iris_971bb14551": {"y_pred": 0.0}}
|
| 33 |
+
{"iris_3c46aadfa8": {"y_pred": 1.0}}
|
| 34 |
+
{"iris_8949d2093a": {"y_pred": 0.0}}
|
| 35 |
+
{"iris_54db69a5ef": {"y_pred": 0.0}}
|
| 36 |
+
{"iris_553603a759": {"y_pred": 2.0}}
|
| 37 |
+
{"iris_1fbd72f69e": {"y_pred": 1.0}}
|
| 38 |
+
{"iris_1aa1718647": {"y_pred": 0.0}}
|
| 39 |
+
{"iris_4e47b9e277": {"y_pred": 0.0}}
|
| 40 |
+
{"iris_0b3fb6e054": {"y_pred": 0.0}}
|
| 41 |
+
{"iris_afb9f3ce89": {"y_pred": 2.0}}
|
| 42 |
+
{"iris_d964678b78": {"y_pred": 1.0}}
|
| 43 |
+
{"iris_d5afa1ffc3": {"y_pred": 1.0}}
|
| 44 |
+
{"iris_8d176d6739": {"y_pred": 0.0}}
|
| 45 |
+
{"iris_b3b9231f82": {"y_pred": 0.0}}
|
| 46 |
+
{"iris_948f3351ef": {"y_pred": 1.0}}
|
| 47 |
+
{"iris_cf7d9336af": {"y_pred": 2.0}}
|
| 48 |
+
{"iris_1d9428989e": {"y_pred": 2.0}}
|
| 49 |
+
{"iris_ca1177d767": {"y_pred": 1.0}}
|
| 50 |
+
{"iris_7435ef9308": {"y_pred": 2.0}}
|
| 51 |
+
{"iris_187546a192": {"y_pred": 1.0}}
|
| 52 |
+
{"iris_f67c61b994": {"y_pred": 2.0}}
|
| 53 |
+
{"iris_12ca8c3bc8": {"y_pred": 1.0}}
|
| 54 |
+
{"iris_e883f0a96b": {"y_pred": 0.0}}
|
| 55 |
+
{"iris_5d30ef01ab": {"y_pred": 2.0}}
|
| 56 |
+
{"iris_06713bd1b2": {"y_pred": 1.0}}
|
| 57 |
+
{"iris_cdea50b849": {"y_pred": 0.0}}
|
| 58 |
+
{"iris_9a7d15fcb5": {"y_pred": 0.0}}
|
| 59 |
+
{"iris_aa4ec334d1": {"y_pred": 0.0}}
|
| 60 |
+
{"iris_1753b1a603": {"y_pred": 1.0}}
|
| 61 |
+
{"iris_bd16db5e4c": {"y_pred": 2.0}}
|
| 62 |
+
{"iris_45e9c6b8be": {"y_pred": 0.0}}
|
| 63 |
+
{"iris_90355b0853": {"y_pred": 0.0}}
|
| 64 |
+
{"iris_29f5ab1fcc": {"y_pred": 0.0}}
|
| 65 |
+
{"iris_ba49dde13f": {"y_pred": 1.0}}
|
| 66 |
+
{"iris_938819d7e3": {"y_pred": 0.0}}
|
| 67 |
+
{"iris_ced4f5a163": {"y_pred": 1.0}}
|
| 68 |
+
{"iris_a0555b0006": {"y_pred": 2.0}}
|
| 69 |
+
{"iris_245849f78c": {"y_pred": 0.0}}
|
| 70 |
+
{"iris_0315cdedea": {"y_pred": 1.0}}
|
| 71 |
+
{"iris_678b362b66": {"y_pred": 2.0}}
|
| 72 |
+
{"iris_495ee2afb0": {"y_pred": 0.0}}
|
| 73 |
+
{"iris_ab99322692": {"y_pred": 2.0}}
|
| 74 |
+
{"iris_afb9f3ce89": {"y_pred": 2.0}}
|
| 75 |
+
{"iris_f873cbc152": {"y_pred": 1.0}}
|
| 76 |
+
{"iris_63c413d1a9": {"y_pred": 1.0}}
|
| 77 |
+
{"iris_42ca7166cd": {"y_pred": 2.0}}
|
| 78 |
+
{"iris_31d4a40847": {"y_pred": 1.0}}
|
| 79 |
+
{"iris_d458f158e0": {"y_pred": 0.0}}
|
| 80 |
+
{"iris_373dfcf880": {"y_pred": 1.0}}
|
| 81 |
+
{"iris_2d037a30fc": {"y_pred": 2.0}}
|
| 82 |
+
{"iris_3954a347b6": {"y_pred": 0.0}}
|
| 83 |
+
{"iris_7c437b3319": {"y_pred": 0.0}}
|
| 84 |
+
{"iris_519f77cbe0": {"y_pred": 1.0}}
|
| 85 |
+
{"iris_01fe6fa830": {"y_pred": 1.0}}
|
| 86 |
+
{"iris_d0b253b7b8": {"y_pred": 0.0}}
|
| 87 |
+
{"iris_0a27e6c142": {"y_pred": 2.0}}
|
| 88 |
+
{"iris_46a04f9bf1": {"y_pred": 0.0}}
|
| 89 |
+
{"iris_ef474c84d7": {"y_pred": 0.0}}
|
| 90 |
+
{"iris_07a988927f": {"y_pred": 1.0}}
|
| 91 |
+
{"iris_c93d7a8c57": {"y_pred": 1.0}}
|
| 92 |
+
{"iris_3af34eed28": {"y_pred": 2.0}}
|
| 93 |
+
{"iris_be68e3ed79": {"y_pred": 1.0}}
|
| 94 |
+
{"iris_55c1ce18c8": {"y_pred": 2.0}}
|
| 95 |
+
{"iris_0393b6cfa4": {"y_pred": 2.0}}
|
| 96 |
+
{"iris_70d7d9f959": {"y_pred": 1.0}}
|
| 97 |
+
{"iris_da32c2c5cb": {"y_pred": 0.0}}
|
| 98 |
+
{"iris_44888f11a6": {"y_pred": 0.0}}
|
| 99 |
+
{"iris_feaabdd51f": {"y_pred": 2.0}}
|
| 100 |
+
{"iris_c6b9d16895": {"y_pred": 2.0}}
|
data/base_indices/iris_global/shap.npy
ADDED
|
Binary file (1.73 kB). View file
|
|
|
model_data/data.csv
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sepal length (cm),sepal width (cm),petal length (cm),petal width (cm),target
|
| 2 |
+
5.1,3.5,1.4,0.2,0
|
| 3 |
+
4.9,3.0,1.4,0.2,0
|
| 4 |
+
4.7,3.2,1.3,0.2,0
|
| 5 |
+
4.6,3.1,1.5,0.2,0
|
| 6 |
+
5.0,3.6,1.4,0.2,0
|
| 7 |
+
5.4,3.9,1.7,0.4,0
|
| 8 |
+
4.6,3.4,1.4,0.3,0
|
| 9 |
+
5.0,3.4,1.5,0.2,0
|
| 10 |
+
4.4,2.9,1.4,0.2,0
|
| 11 |
+
4.9,3.1,1.5,0.1,0
|
| 12 |
+
5.4,3.7,1.5,0.2,0
|
| 13 |
+
4.8,3.4,1.6,0.2,0
|
| 14 |
+
4.8,3.0,1.4,0.1,0
|
| 15 |
+
4.3,3.0,1.1,0.1,0
|
| 16 |
+
5.8,4.0,1.2,0.2,0
|
| 17 |
+
5.7,4.4,1.5,0.4,0
|
| 18 |
+
5.4,3.9,1.3,0.4,0
|
| 19 |
+
5.1,3.5,1.4,0.3,0
|
| 20 |
+
5.7,3.8,1.7,0.3,0
|
| 21 |
+
5.1,3.8,1.5,0.3,0
|
| 22 |
+
5.4,3.4,1.7,0.2,0
|
| 23 |
+
5.1,3.7,1.5,0.4,0
|
| 24 |
+
4.6,3.6,1.0,0.2,0
|
| 25 |
+
5.1,3.3,1.7,0.5,0
|
| 26 |
+
4.8,3.4,1.9,0.2,0
|
| 27 |
+
5.0,3.0,1.6,0.2,0
|
| 28 |
+
5.0,3.4,1.6,0.4,0
|
| 29 |
+
5.2,3.5,1.5,0.2,0
|
| 30 |
+
5.2,3.4,1.4,0.2,0
|
| 31 |
+
4.7,3.2,1.6,0.2,0
|
| 32 |
+
4.8,3.1,1.6,0.2,0
|
| 33 |
+
5.4,3.4,1.5,0.4,0
|
| 34 |
+
5.2,4.1,1.5,0.1,0
|
| 35 |
+
5.5,4.2,1.4,0.2,0
|
| 36 |
+
4.9,3.1,1.5,0.2,0
|
| 37 |
+
5.0,3.2,1.2,0.2,0
|
| 38 |
+
5.5,3.5,1.3,0.2,0
|
| 39 |
+
4.9,3.6,1.4,0.1,0
|
| 40 |
+
4.4,3.0,1.3,0.2,0
|
| 41 |
+
5.1,3.4,1.5,0.2,0
|
| 42 |
+
5.0,3.5,1.3,0.3,0
|
| 43 |
+
4.5,2.3,1.3,0.3,0
|
| 44 |
+
4.4,3.2,1.3,0.2,0
|
| 45 |
+
5.0,3.5,1.6,0.6,0
|
| 46 |
+
5.1,3.8,1.9,0.4,0
|
| 47 |
+
4.8,3.0,1.4,0.3,0
|
| 48 |
+
5.1,3.8,1.6,0.2,0
|
| 49 |
+
4.6,3.2,1.4,0.2,0
|
| 50 |
+
5.3,3.7,1.5,0.2,0
|
| 51 |
+
5.0,3.3,1.4,0.2,0
|
| 52 |
+
7.0,3.2,4.7,1.4,1
|
| 53 |
+
6.4,3.2,4.5,1.5,1
|
| 54 |
+
6.9,3.1,4.9,1.5,1
|
| 55 |
+
5.5,2.3,4.0,1.3,1
|
| 56 |
+
6.5,2.8,4.6,1.5,1
|
| 57 |
+
5.7,2.8,4.5,1.3,1
|
| 58 |
+
6.3,3.3,4.7,1.6,1
|
| 59 |
+
4.9,2.4,3.3,1.0,1
|
| 60 |
+
6.6,2.9,4.6,1.3,1
|
| 61 |
+
5.2,2.7,3.9,1.4,1
|
| 62 |
+
5.0,2.0,3.5,1.0,1
|
| 63 |
+
5.9,3.0,4.2,1.5,1
|
| 64 |
+
6.0,2.2,4.0,1.0,1
|
| 65 |
+
6.1,2.9,4.7,1.4,1
|
| 66 |
+
5.6,2.9,3.6,1.3,1
|
| 67 |
+
6.7,3.1,4.4,1.4,1
|
| 68 |
+
5.6,3.0,4.5,1.5,1
|
| 69 |
+
5.8,2.7,4.1,1.0,1
|
| 70 |
+
6.2,2.2,4.5,1.5,1
|
| 71 |
+
5.6,2.5,3.9,1.1,1
|
| 72 |
+
5.9,3.2,4.8,1.8,1
|
| 73 |
+
6.1,2.8,4.0,1.3,1
|
| 74 |
+
6.3,2.5,4.9,1.5,1
|
| 75 |
+
6.1,2.8,4.7,1.2,1
|
| 76 |
+
6.4,2.9,4.3,1.3,1
|
| 77 |
+
6.6,3.0,4.4,1.4,1
|
| 78 |
+
6.8,2.8,4.8,1.4,1
|
| 79 |
+
6.7,3.0,5.0,1.7,1
|
| 80 |
+
6.0,2.9,4.5,1.5,1
|
| 81 |
+
5.7,2.6,3.5,1.0,1
|
| 82 |
+
5.5,2.4,3.8,1.1,1
|
| 83 |
+
5.5,2.4,3.7,1.0,1
|
| 84 |
+
5.8,2.7,3.9,1.2,1
|
| 85 |
+
6.0,2.7,5.1,1.6,1
|
| 86 |
+
5.4,3.0,4.5,1.5,1
|
| 87 |
+
6.0,3.4,4.5,1.6,1
|
| 88 |
+
6.7,3.1,4.7,1.5,1
|
| 89 |
+
6.3,2.3,4.4,1.3,1
|
| 90 |
+
5.6,3.0,4.1,1.3,1
|
| 91 |
+
5.5,2.5,4.0,1.3,1
|
| 92 |
+
5.5,2.6,4.4,1.2,1
|
| 93 |
+
6.1,3.0,4.6,1.4,1
|
| 94 |
+
5.8,2.6,4.0,1.2,1
|
| 95 |
+
5.0,2.3,3.3,1.0,1
|
| 96 |
+
5.6,2.7,4.2,1.3,1
|
| 97 |
+
5.7,3.0,4.2,1.2,1
|
| 98 |
+
5.7,2.9,4.2,1.3,1
|
| 99 |
+
6.2,2.9,4.3,1.3,1
|
| 100 |
+
5.1,2.5,3.0,1.1,1
|
| 101 |
+
5.7,2.8,4.1,1.3,1
|
| 102 |
+
6.3,3.3,6.0,2.5,2
|
| 103 |
+
5.8,2.7,5.1,1.9,2
|
| 104 |
+
7.1,3.0,5.9,2.1,2
|
| 105 |
+
6.3,2.9,5.6,1.8,2
|
| 106 |
+
6.5,3.0,5.8,2.2,2
|
| 107 |
+
7.6,3.0,6.6,2.1,2
|
| 108 |
+
4.9,2.5,4.5,1.7,2
|
| 109 |
+
7.3,2.9,6.3,1.8,2
|
| 110 |
+
6.7,2.5,5.8,1.8,2
|
| 111 |
+
7.2,3.6,6.1,2.5,2
|
| 112 |
+
6.5,3.2,5.1,2.0,2
|
| 113 |
+
6.4,2.7,5.3,1.9,2
|
| 114 |
+
6.8,3.0,5.5,2.1,2
|
| 115 |
+
5.7,2.5,5.0,2.0,2
|
| 116 |
+
5.8,2.8,5.1,2.4,2
|
| 117 |
+
6.4,3.2,5.3,2.3,2
|
| 118 |
+
6.5,3.0,5.5,1.8,2
|
| 119 |
+
7.7,3.8,6.7,2.2,2
|
| 120 |
+
7.7,2.6,6.9,2.3,2
|
| 121 |
+
6.0,2.2,5.0,1.5,2
|
| 122 |
+
6.9,3.2,5.7,2.3,2
|
| 123 |
+
5.6,2.8,4.9,2.0,2
|
| 124 |
+
7.7,2.8,6.7,2.0,2
|
| 125 |
+
6.3,2.7,4.9,1.8,2
|
| 126 |
+
6.7,3.3,5.7,2.1,2
|
| 127 |
+
7.2,3.2,6.0,1.8,2
|
| 128 |
+
6.2,2.8,4.8,1.8,2
|
| 129 |
+
6.1,3.0,4.9,1.8,2
|
| 130 |
+
6.4,2.8,5.6,2.1,2
|
| 131 |
+
7.2,3.0,5.8,1.6,2
|
| 132 |
+
7.4,2.8,6.1,1.9,2
|
| 133 |
+
7.9,3.8,6.4,2.0,2
|
| 134 |
+
6.4,2.8,5.6,2.2,2
|
| 135 |
+
6.3,2.8,5.1,1.5,2
|
| 136 |
+
6.1,2.6,5.6,1.4,2
|
| 137 |
+
7.7,3.0,6.1,2.3,2
|
| 138 |
+
6.3,3.4,5.6,2.4,2
|
| 139 |
+
6.4,3.1,5.5,1.8,2
|
| 140 |
+
6.0,3.0,4.8,1.8,2
|
| 141 |
+
6.9,3.1,5.4,2.1,2
|
| 142 |
+
6.7,3.1,5.6,2.4,2
|
| 143 |
+
6.9,3.1,5.1,2.3,2
|
| 144 |
+
5.8,2.7,5.1,1.9,2
|
| 145 |
+
6.8,3.2,5.9,2.3,2
|
| 146 |
+
6.7,3.3,5.7,2.5,2
|
| 147 |
+
6.7,3.0,5.2,2.3,2
|
| 148 |
+
6.3,2.5,5.0,1.9,2
|
| 149 |
+
6.5,3.0,5.2,2.0,2
|
| 150 |
+
6.2,3.4,5.4,2.3,2
|
| 151 |
+
5.9,3.0,5.1,1.8,2
|
model_data/model.pkl
ADDED
|
Binary file (94.9 kB). View file
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy==1.26.4
|
| 2 |
+
pandas==2.2.2
|
| 3 |
+
scikit-learn==1.4.2
|
| 4 |
+
shap==0.45.0
|
| 5 |
+
fastapi==0.115.0
|
| 6 |
+
uvicorn==0.30.6
|
| 7 |
+
python-multipart==0.0.9
|
| 8 |
+
streamlit==1.39.0
|
| 9 |
+
pydantic==2.9.2
|
| 10 |
+
joblib==1.4.2
|
| 11 |
+
matplotlib==3.9.2
|
scripts/__pycache__/build_base_index.cpython-313.pyc
ADDED
|
Binary file (3.03 kB). View file
|
|
|
scripts/add_user_model.py
ADDED
|
File without changes
|
scripts/build_base_index.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Precompute a 'global' reasoning space from a baseline model + dataset.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/build_base_index.py \
|
| 6 |
+
--model_path path/to/model.pkl \
|
| 7 |
+
--csv path/to/data.csv \
|
| 8 |
+
--features col1,col2,col3 \
|
| 9 |
+
--target target_col \
|
| 10 |
+
--namespace data/base_indices/recidivism_global \
|
| 11 |
+
--sample 2000
|
| 12 |
+
"""
|
| 13 |
+
# Query_Your_Model/scripts/build_base_index.py
|
| 14 |
+
import sys, os
|
| 15 |
+
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import numpy as np
|
| 21 |
+
from ..core.model_loader import load_model, predict
|
| 22 |
+
from ..core.explain import explain_instance
|
| 23 |
+
from ..core.storage import ensure_dir, init_matrix_files, append_case
|
| 24 |
+
from ..core.utils import case_id_from_vector
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Hardcoded defaults for Iris demo
|
| 28 |
+
MODEL_PATH = "Query_Your_Model/model_data/model.pkl"
|
| 29 |
+
CSV_PATH = "Query_Your_Model/model_data/data.csv"
|
| 30 |
+
FEATURES = ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"]
|
| 31 |
+
TARGET = "target"
|
| 32 |
+
NAMESPACE = "Query_Your_Model/data/base_indices/iris_global"
|
| 33 |
+
SAMPLE = 100 # how many rows to sample
|
| 34 |
+
|
| 35 |
+
def main():
|
| 36 |
+
print("Building reasoning index...")
|
| 37 |
+
df = pd.read_csv(CSV_PATH)
|
| 38 |
+
if SAMPLE and SAMPLE < len(df):
|
| 39 |
+
df = df.sample(SAMPLE, random_state=42)
|
| 40 |
+
|
| 41 |
+
X = df[FEATURES].values
|
| 42 |
+
model = load_model(MODEL_PATH)
|
| 43 |
+
|
| 44 |
+
ensure_dir(NAMESPACE)
|
| 45 |
+
init_matrix_files(NAMESPACE, feature_dim=len(FEATURES), shap_dim=len(FEATURES))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
bg = df[FEATURES].sample(min(100, len(df)), random_state=0).values.astype("float32")
|
| 49 |
+
|
| 50 |
+
for i, row in df.iterrows():
|
| 51 |
+
x = row[FEATURES].values.astype("float32")
|
| 52 |
+
y_pred, _ = predict(model, x.reshape(1, -1))
|
| 53 |
+
exp = explain_instance(model, x, FEATURES, background_X=bg, top_k=8)
|
| 54 |
+
shap_vec = np.array(exp["shap_values"], dtype="float32")
|
| 55 |
+
cid = case_id_from_vector(x, prefix="iris")
|
| 56 |
+
meta = {"y_pred": float(y_pred[0])}
|
| 57 |
+
append_case(NAMESPACE, cid, x, shap_vec, meta)
|
| 58 |
+
|
| 59 |
+
print(f"Done! Index saved to {NAMESPACE}")
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
scripts/build_iris.bat
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
REM Build the Iris reasoning index. Run from the repository root.
REM NOTE(review): the repo tree in this commit has scripts/, model_data/ and
REM data/ at the root (there is no Query_Your_Model/ folder), so all paths
REM below are root-relative — confirm against your checkout layout.
echo Building Iris reasoning index...

python scripts/build_base_index.py ^
  --model_path model_data/model.pkl ^
  --csv model_data/data.csv ^
  --features "sepal length (cm),sepal width (cm),petal length (cm),petal width (cm)" ^
  --target target ^
  --namespace data/base_indices/iris_global ^
  --sample 100

echo Done! Index saved to data/base_indices/iris_global
pause
|
scripts/demo_predict.py
ADDED
|
File without changes
|
tests/__pycache__/test_similarity.cpython-313.pyc
ADDED
|
Binary file (1.48 kB). View file
|
|
|
tests/test_similarity.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# tests/test_similarity.py
import os
import sys

import numpy as np

# Allow running this file directly (`python tests/test_similarity.py`):
# put the repo root on sys.path so `core.*` resolves as an absolute import.
# NOTE(review): the repo tree has core/ at the root (no Query_Your_Model
# package), so the old `from Query_Your_Model.core...` import cannot resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from core.retrieval import combined_similarity


def _similarities():
    """Return (identical-pair, orthogonal-pair) combined similarities.

    Uses a fixed 3-d feature/SHAP fixture: vector `a` compared with itself,
    and `a` compared with an orthogonal vector `b` whose SHAP sign flips.
    """
    a = np.array([1, 0, 0], dtype="float32")        # feature vector 1
    b = np.array([0, 1, 0], dtype="float32")        # feature vector 2 (orthogonal)
    shap_a = np.array([0.5, 0.2, 0.1], dtype="float32")   # shap for a
    shap_b = np.array([-0.5, 0.0, 0.0], dtype="float32")  # shap for b
    s_same = combined_similarity(a, shap_a, a, shap_a, alpha=0.5)
    s_diff = combined_similarity(a, shap_a, b, shap_b, alpha=0.5)
    return s_same, s_diff


def test_combined_similarity_basic():
    # pytest treats a non-None return from a test as an error in recent
    # versions, so this test only asserts; printing lives in __main__ below.
    s_same, s_diff = _similarities()
    assert s_same > 0.99, "Expected similarity close to 1 for identical vectors"
    assert s_diff < s_same, "Expected orthogonal similarity to be smaller"


if __name__ == "__main__":
    print("Running combined_similarity tests...\n")
    s1, s2 = _similarities()
    print(f"Similarity (identical): {s1:.4f}")
    print(f"Similarity (different/orthogonal): {s2:.4f}")
    test_combined_similarity_basic()
    print("\n Test passed!")