Spaces:
Sleeping
Sleeping
File size: 8,050 Bytes
5256800 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
#!/usr/bin/env python3
# explain_model.py
import os
import json
import numpy as np
import pandas as pd
import torch
import joblib
import shap
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from network import PricePredictor
# --- 0. Config ---
# Paths to the trained-model artifacts and the labeled dataset.
MODEL_DIR = "model"
DATA_DIR = "data"
# Feature scaler fitted at training time (joblib pickle).
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
# CSV with the feature columns plus the binary target column.
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
# JSON written at training time; provides "feature_columns" and "input_size".
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
# Binary label column (name suggests: price rose >=30% within 6 months — confirm in training code).
TARGET_COLUMN = "price_will_rise_30_in_6m"
# --- 1. Load model & assets ---
# The training-time config pins the exact feature set/order the network expects.
with open(CONFIG_PATH, "r") as f:
    config = json.load(f)
feature_columns = config["feature_columns"]
input_size = config["input_size"]

# Rebuild the architecture, then load trained weights from safetensors.
model = PricePredictor(input_size=input_size)
model.load_state_dict(load_file(os.path.join(MODEL_DIR, "model.safetensors")))
model.eval()  # inference mode (disables dropout / batch-norm updates)

# Scaler fitted on the training features; must be applied before inference.
scaler = joblib.load(SCALER_PATH)
full_data = pd.read_csv(DATA_PATH)

# Sanity checks: fail fast if the CSV and the saved config have drifted apart.
missing_cols = [c for c in feature_columns if c not in full_data.columns]
if missing_cols:
    raise ValueError(f"Missing required feature columns in CSV: {missing_cols}")
features_df = full_data[feature_columns]
if features_df.shape[1] != input_size:
    raise ValueError(
        f"Config input_size={input_size}, but CSV provides {features_df.shape[1]} features. "
        f"Ensure config['feature_columns'] matches the trained model."
    )
# --- 2. Prepare Data for SHAP ---
# Background set (for the explainer's reference distribution) and a small
# sample of rows to explain; both capped by the dataset size.
bg_n = min(100, len(features_df))
explain_n = min(10, len(features_df))
# Fixed seeds so repeated runs explain the same rows.
background_idx = features_df.sample(n=bg_n, random_state=42).index
explain_idx = features_df.sample(n=explain_n, random_state=1).index
background_data = features_df.loc[background_idx]
explain_instances = features_df.loc[explain_idx]
# Use arrays for scaler to avoid feature-name warnings
background_data_scaled = scaler.transform(background_data.values)
explain_instances_scaled = scaler.transform(explain_instances.values)
background_tensor = torch.tensor(background_data_scaled, dtype=torch.float32)  # no grad
# Gradients are needed on the explained instances for gradient-based explainers.
explain_tensor = torch.tensor(explain_instances_scaled, dtype=torch.float32, requires_grad=True)
# --- Helpers ---
def get_shap_explanations(model, background_tensor, explain_tensor):
    """Try DeepExplainer then fall back to GradientExplainer. Return (explanation, explainer_used_name)."""
    try:
        print("Initializing SHAP DeepExplainer...")
        deep = shap.DeepExplainer(model, background_tensor)
        print("Calculating SHAP values for the sample...")
        explanation = deep(explain_tensor)
        # Stash the explainer's expected value so compute_base_value_safe can use it later.
        explanation._expected_value_hint = getattr(deep, "expected_value", None)
        return explanation, "deep"
    except Exception as err:
        # DeepExplainer can fail on some model internals / shap versions;
        # GradientExplainer only needs autograd, so retry with it.
        print(f"[DeepExplainer failed: {err}] Falling back to GradientExplainer...")
        explain_tensor.requires_grad_(True)
        fallback = shap.GradientExplainer(model, background_tensor)
        explanation = fallback(explain_tensor)
        explanation._expected_value_hint = getattr(fallback, "expected_value", None)
        return explanation, "grad"
def compute_base_value_safe(shap_explanation, instance_idx, model, background_tensor):
    """Return scalar base value robustly across SHAP versions."""
    # 1) Preferred: per-instance base_values attached to the Explanation.
    base_vals = getattr(shap_explanation, "base_values", None)
    if base_vals is not None:
        for candidate in (lambda: base_vals[instance_idx], lambda: base_vals):
            try:
                return float(np.squeeze(candidate()))
            except Exception:
                continue
    # 2) Next: the expected-value hint stashed by get_shap_explanations.
    hint = getattr(shap_explanation, "_expected_value_hint", None)
    if hint is not None:
        for candidate in (lambda: np.squeeze(hint), lambda: np.mean(hint)):
            try:
                return float(candidate())
            except Exception:
                continue
    # 3) Last resort: evaluate the model at the mean background point.
    with torch.no_grad():
        center = background_tensor.mean(dim=0, keepdim=True)
        output = model(center).detach().cpu().squeeze()
    return float(output.mean().item()) if output.numel() > 1 else float(output.item())
def stack_sample_shap_values(exp, n_features_expected):
    """
    Some SHAP versions return exp.values with shape (n_samples, 1) or other oddities.
    However, exp[i].values is typically the correct 1D (n_features,) vector.
    We rebuild a full matrix by stacking per-sample slices.
    """
    vals = exp.values
    count = len(vals) if hasattr(vals, "__len__") else len(exp)
    # Slice sample-by-sample through __getitem__, flattening each row to 1-D.
    matrix = np.vstack(
        [np.asarray(exp[i].values).reshape(-1,) for i in range(count)]
    )  # (n_samples, n_features)
    if matrix.shape[1] != n_features_expected:
        raise RuntimeError(
            f"Rebuilt SHAP matrix has shape {matrix.shape}; expected n_features={n_features_expected}."
        )
    return matrix
# --- 3. Compute SHAP explanations ---
shap_explanation, _ = get_shap_explanations(model, background_tensor, explain_tensor)
print("Calculation complete.")
# Attach unscaled display data for pretty plotting
# (plots then show original feature values instead of scaled inputs).
shap_explanation.display_data = explain_instances.values
shap_explanation.feature_names = feature_columns

# --- 4a. Global Feature Importance (Bar / Summary) ---
print("\nGenerating global feature importance plot (summary_plot.png)...")
# Robustly build a (n_samples, n_features) matrix by stacking per-sample vectors
shap_vals_matrix = stack_sample_shap_values(shap_explanation, n_features_expected=len(feature_columns))
mean_abs_shap = np.abs(shap_vals_matrix).mean(axis=0)  # (n_features,)
# Build a fresh Explanation with values aligned to feature_names
plot_explanation = shap.Explanation(values=mean_abs_shap, feature_names=feature_columns)
plt.figure()
shap.plots.bar(plot_explanation, show=False)
plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)")
plt.savefig("summary_plot.png", bbox_inches="tight")
plt.close()
print("Saved: summary_plot.png")
# --- 4b. Local Explanation (Force Plot) ---
print("\nGenerating local explanation for one card (force_plot.html)...")
instance_to_explain_index = 0  # explain the first row of the explained sample
single_explanation = shap_explanation[instance_to_explain_index]
# Some SHAP versions drop display_data on slicing; pull directly if needed
if getattr(single_explanation, "display_data", None) is None:
    row_unscaled = explain_instances.values[instance_to_explain_index]
else:
    row_unscaled = single_explanation.display_data
# force_plot expects a 2-D (1, n_features) row of display values.
features_row = np.atleast_2d(np.asarray(row_unscaled, dtype=float))
base_val = compute_base_value_safe(shap_explanation, instance_to_explain_index, model, background_tensor)
phi = np.asarray(single_explanation.values).reshape(-1,)  # (n_features,)
force_plot = shap.force_plot(
    base_val,
    phi,
    features=features_row,
    feature_names=feature_columns
)
shap.save_html("force_plot.html", force_plot)
print("Saved: force_plot.html (open in a browser)")
# --- 4c. Optional: local waterfall PNG (often clearer) ---
# Waterfall rendering varies across shap versions; skip gracefully rather
# than aborting the whole script if it cannot draw this Explanation.
try:
    print("Generating local waterfall plot (waterfall_single.png)...")
    plt.figure()
    shap.plots.waterfall(single_explanation, show=False, max_display=20)
    plt.savefig("waterfall_single.png", bbox_inches="tight")
    plt.close()
    print("Saved: waterfall_single.png")
except Exception as e:
    print(f"Waterfall plot skipped (reason: {e})")
# --- 5. Print metadata for the explained card ---
# Map the explained row back to the original (unscaled) dataset row.
original_card_data = full_data.loc[explain_idx[instance_to_explain_index]]
name_val = original_card_data.get("name", "N/A")
tcgp_val = original_card_data.get("tcgplayer_id", "N/A")
label_val = original_card_data.get(TARGET_COLUMN, None)
# BUG FIX: a missing label read from the CSV arrives as NaN, which is truthy,
# so the old one-line conditional reported "RISE" for missing values.
# Treat None/NaN explicitly as "N/A" before interpreting the label.
if label_val is None or pd.isna(label_val):
    label_str = "N/A"
else:
    label_str = "RISE" if bool(label_val) else "NOT RISE"
print("\n--- Card Explained in force_plot.html / waterfall_single.png ---")
print(f"Name: {name_val}")
print(f"TCGPlayer ID: {tcgp_val}")
print(f"Actual Outcome in Dataset: {label_str}")
# TODO: convert the model into a format where i can share on hugging face as a model that can be pulled down and used
# TODO: include the SHAP charts force_plot.html and summary_plot.png explaining the model, as well as compute some other evaluation metrics for explanation in the card