"""
Precompute a 'global' reasoning space from a baseline model + dataset.
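Intended usage (NOTE: the flags below are not parsed by this file; running it
as a script uses the module-level defaults MODEL_PATH, CSV_PATH, FEATURES,
NAMESPACE and SAMPLE):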
python scripts/build_base_index.py \
--model_path path/to/model.pkl \
--csv path/to/data.csv \
--features col1,col2,col3 \
--target target_col \
--namespace data/base_indices/recidivism_global \
--sample 2000
"""
# Query_Your_Model/scripts/build_base_index.py
import sys, os
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import os
import pandas as pd
import numpy as np
from ..core.model_loader import load_model, predict
from ..core.explain import explain_instance
from ..core.storage import ensure_dir, init_matrix_files, append_case
from ..core.utils import case_id_from_vector
# Fixed configuration for the Iris demo; main() falls back to these values.

# Model file handed to load_model().
MODEL_PATH = "Query_Your_Model/model_data/model.pkl"

# CSV dataset read with pandas.
CSV_PATH = "Query_Your_Model/model_data/data.csv"

# Input columns; their order here is the column order of every feature vector.
FEATURES = [
    "sepal length (cm)",
    "sepal width (cm)",
    "petal length (cm)",
    "petal width (cm)",
]

# Name of the label column (not referenced by the rest of the visible code).
TARGET = "target"

# Location passed to ensure_dir() / init_matrix_files() / append_case().
NAMESPACE = "Query_Your_Model/data/base_indices/iris_global"

# Row cap: the dataset is down-sampled to this many rows when it has more.
SAMPLE = 100
def main(*, model_path=None, csv_path=None, features=None, namespace=None, sample=None):
    """Build the reasoning index: one stored case per (sampled) dataset row.

    For every selected row the model prediction and the explanation's
    ``"shap_values"`` vector are computed and appended to the index under
    ``namespace`` via ``append_case``.

    All parameters are optional and keyword-only; each one left as ``None``
    falls back to the matching module-level constant, so a bare ``main()``
    call behaves as before.

    Args:
        model_path: Path given to ``load_model`` (default ``MODEL_PATH``).
        csv_path: CSV file read with ``pandas.read_csv`` (default ``CSV_PATH``).
        features: Ordered feature column names (default ``FEATURES``).
        namespace: Index location passed to the storage helpers
            (default ``NAMESPACE``).
        sample: Maximum number of rows to index (default ``SAMPLE``). A falsy
            value such as ``0`` disables down-sampling.

    Raises:
        KeyError: If any requested feature column is absent from the CSV.
    """
    # Resolve defaults at call time so later changes to the constants are honoured.
    model_path = MODEL_PATH if model_path is None else model_path
    csv_path = CSV_PATH if csv_path is None else csv_path
    features = list(FEATURES if features is None else features)
    namespace = NAMESPACE if namespace is None else namespace
    sample = SAMPLE if sample is None else sample

    print("Building reasoning index...")
    df = pd.read_csv(csv_path)

    # Fail early with a readable message; pandas would raise KeyError here anyway.
    missing = [col for col in features if col not in df.columns]
    if missing:
        raise KeyError(f"Feature column(s) missing from {csv_path}: {missing}")

    # Fixed seed keeps the selected subset reproducible between runs.
    if sample and sample < len(df):
        df = df.sample(sample, random_state=42)

    model = load_model(model_path)
    ensure_dir(namespace)
    init_matrix_files(namespace, feature_dim=len(features), shap_dim=len(features))

    feature_df = df[features]
    # Background set handed to the explainer: at most 100 rows, fixed seed.
    bg = feature_df.sample(min(100, len(df)), random_state=0).values.astype("float32")

    # One float32 matrix for all rows instead of rebuilding a Series per row with
    # iterrows(); made C-contiguous so each row is a contiguous 1-D vector, like
    # the per-row arrays built previously.
    X = np.ascontiguousarray(feature_df.to_numpy(dtype="float32"))

    for x in X:
        # predict() returns a pair; only the first element (predictions) is used.
        y_pred, _ = predict(model, x.reshape(1, -1))
        exp = explain_instance(model, x, features, background_X=bg, top_k=8)
        shap_vec = np.array(exp["shap_values"], dtype="float32")
        cid = case_id_from_vector(x, prefix="iris")
        meta = {"y_pred": float(y_pred[0])}
        append_case(namespace, cid, x, shap_vec, meta)

    print(f"Done! Index saved to {namespace}")
# Script entry point: build the index only when executed directly, not on import.
if __name__ == "__main__":
    main()
|