File size: 8,050 Bytes
5256800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#!/usr/bin/env python3
# explain_model.py

import os
import json
import numpy as np
import pandas as pd
import torch
import joblib
import shap
import matplotlib.pyplot as plt
from safetensors.torch import load_file
from network import PricePredictor

# --- 0. Config ---
# Paths to the trained model artifacts and the labeled dataset.
MODEL_DIR = "model"
DATA_DIR = "data"
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")  # fitted feature scaler (joblib pickle)
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")  # stores feature_columns and input_size
# Binary target column; presumably "price rose >=30% within 6 months" — TODO confirm against labeling code.
TARGET_COLUMN = "price_will_rise_30_in_6m"

# --- 1. Load model & assets ---
# config.json records the exact feature ordering and input width used at training time.
with open(CONFIG_PATH, "r") as f:
    config = json.load(f)

feature_columns = config["feature_columns"]
input_size = config["input_size"]

# Restore trained weights from safetensors; eval() switches off training-mode layers (dropout etc.).
model = PricePredictor(input_size=input_size)
model.load_state_dict(load_file(os.path.join(MODEL_DIR, "model.safetensors")))
model.eval()

scaler = joblib.load(SCALER_PATH)
full_data = pd.read_csv(DATA_PATH)

# Sanity checks
# Fail fast if the CSV no longer provides the training-time feature set.
missing_cols = [c for c in feature_columns if c not in full_data.columns]
if missing_cols:
    raise ValueError(f"Missing required feature columns in CSV: {missing_cols}")

features_df = full_data[feature_columns]
# Guard against a stale config whose width no longer matches the data.
if features_df.shape[1] != input_size:
    raise ValueError(
        f"Config input_size={input_size}, but CSV provides {features_df.shape[1]} features. "
        f"Ensure config['feature_columns'] matches the trained model."
    )

# --- 2. Prepare Data for SHAP ---
# Cap sample sizes so small datasets still work: the background set anchors the
# explainer's expected value; explain_n rows are the instances actually explained.
bg_n = min(100, len(features_df))
explain_n = min(10, len(features_df))

# Fixed seeds keep the sampled rows reproducible across runs.
background_idx = features_df.sample(n=bg_n, random_state=42).index
explain_idx = features_df.sample(n=explain_n, random_state=1).index

background_data = features_df.loc[background_idx]
explain_instances = features_df.loc[explain_idx]

# Use arrays for scaler to avoid feature-name warnings
background_data_scaled = scaler.transform(background_data.values)
explain_instances_scaled = scaler.transform(explain_instances.values)

background_tensor = torch.tensor(background_data_scaled, dtype=torch.float32)  # no grad
# Gradients are required on the inputs if the GradientExplainer fallback is used.
explain_tensor = torch.tensor(explain_instances_scaled, dtype=torch.float32, requires_grad=True)

# --- Helpers ---
def get_shap_explanations(model, background_tensor, explain_tensor):
    """Build a SHAP explanation for *explain_tensor*.

    DeepExplainer is attempted first; any failure falls back to
    GradientExplainer. Returns a tuple ``(explanation, explainer_used_name)``
    where the name is ``"deep"`` or ``"grad"``. The explainer's
    ``expected_value`` (if present) is stashed on the explanation as
    ``_expected_value_hint`` for later base-value recovery.
    """
    def _with_hint(explanation, source):
        # Carry the explainer's expected value along with the explanation object.
        setattr(explanation, "_expected_value_hint", getattr(source, "expected_value", None))
        return explanation

    try:
        print("Initializing SHAP DeepExplainer...")
        deep = shap.DeepExplainer(model, background_tensor)
        print("Calculating SHAP values for the sample...")
        return _with_hint(deep(explain_tensor), deep), "deep"
    except Exception as e:
        print(f"[DeepExplainer failed: {e}] Falling back to GradientExplainer...")
        # Gradient-based attribution needs gradients enabled on the inputs.
        explain_tensor.requires_grad_(True)
        fallback = shap.GradientExplainer(model, background_tensor)
        return _with_hint(fallback(explain_tensor), fallback), "grad"

def compute_base_value_safe(shap_explanation, instance_idx, model, background_tensor):
    """Return a scalar base (expected) value, robust across SHAP versions.

    Resolution order:
    1. the explanation's own ``base_values`` (per-instance, then squeezed whole);
    2. the ``_expected_value_hint`` attached by ``get_shap_explanations``;
    3. the model's output on the mean background row (averaged if multi-output).
    """
    base_values = getattr(shap_explanation, "base_values", None)
    if base_values is not None:
        # Prefer the per-instance entry; fall back to squeezing the whole array.
        for candidate in (lambda: base_values[instance_idx], lambda: base_values):
            try:
                return float(np.squeeze(candidate()))
            except Exception:
                continue
    hint = getattr(shap_explanation, "_expected_value_hint", None)
    if hint is not None:
        try:
            return float(np.squeeze(hint))
        except Exception:
            try:
                return float(np.mean(hint))
            except Exception:
                pass
    # Last resort: evaluate the model at the mean of the background sample.
    with torch.no_grad():
        mean_input = background_tensor.mean(dim=0, keepdim=True)
        output = model(mean_input).detach().cpu().squeeze()
        if output.numel() > 1:
            return float(output.mean().item())
        return float(output.item())

def stack_sample_shap_values(exp, n_features_expected):
    """Rebuild a dense (n_samples, n_features) SHAP matrix from per-sample slices.

    Some SHAP versions expose ``exp.values`` with odd shapes such as
    ``(n_samples, 1)``; slicing ``exp[i].values`` typically yields the correct
    1-D ``(n_features,)`` vector, so the matrix is reassembled row by row.

    Raises RuntimeError when the rebuilt width differs from
    *n_features_expected*.
    """
    if hasattr(exp.values, "__len__"):
        sample_count = len(exp.values)
    else:
        sample_count = len(exp)
    # Go through __getitem__, the reliable per-sample API across versions.
    matrix = np.vstack(
        [np.asarray(exp[idx].values).reshape(-1,) for idx in range(sample_count)]
    )
    if matrix.shape[1] != n_features_expected:
        raise RuntimeError(
            f"Rebuilt SHAP matrix has shape {matrix.shape}; expected n_features={n_features_expected}."
        )
    return matrix

# --- 3. Compute SHAP explanations ---
shap_explanation, _ = get_shap_explanations(model, background_tensor, explain_tensor)
print("Calculation complete.")

# Attach unscaled display data for pretty plotting
# (plots then show raw feature values instead of scaler-transformed ones)
shap_explanation.display_data = explain_instances.values
shap_explanation.feature_names = feature_columns

# --- 4a. Global Feature Importance (Bar / Summary) ---
print("\nGenerating global feature importance plot (summary_plot.png)...")

# Robustly build a (n_samples, n_features) matrix by stacking per-sample vectors
shap_vals_matrix = stack_sample_shap_values(shap_explanation, n_features_expected=len(feature_columns))

# Average absolute attribution per feature across the explained sample.
mean_abs_shap = np.abs(shap_vals_matrix).mean(axis=0)  # (n_features,)

# Build a fresh Explanation with values aligned to feature_names
plot_explanation = shap.Explanation(values=mean_abs_shap, feature_names=feature_columns)

plt.figure()
shap.plots.bar(plot_explanation, show=False)
plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)")
plt.savefig("summary_plot.png", bbox_inches="tight")
plt.close()
print("Saved: summary_plot.png")

# --- 4b. Local Explanation (Force Plot) ---
print("\nGenerating local explanation for one card (force_plot.html)...")
# Explains the first sampled instance; change this index to explain a different card.
instance_to_explain_index = 0
single_explanation = shap_explanation[instance_to_explain_index]

# Some SHAP versions drop display_data on slicing; pull directly if needed
if getattr(single_explanation, "display_data", None) is None:
    row_unscaled = explain_instances.values[instance_to_explain_index]
else:
    row_unscaled = single_explanation.display_data
# force_plot expects a 2-D features row.
features_row = np.atleast_2d(np.asarray(row_unscaled, dtype=float))

base_val = compute_base_value_safe(shap_explanation, instance_to_explain_index, model, background_tensor)
phi = np.asarray(single_explanation.values).reshape(-1,)  # (n_features,)

force_plot = shap.force_plot(
    base_val,
    phi,
    features=features_row,
    feature_names=feature_columns
)
shap.save_html("force_plot.html", force_plot)
print("Saved: force_plot.html (open in a browser)")

# --- 4c. Optional: local waterfall PNG (often clearer) ---
# Best-effort: waterfall support varies by SHAP version, so failures only warn.
try:
    print("Generating local waterfall plot (waterfall_single.png)...")
    plt.figure()
    shap.plots.waterfall(single_explanation, show=False, max_display=20)
    plt.savefig("waterfall_single.png", bbox_inches="tight")
    plt.close()
    print("Saved: waterfall_single.png")
except Exception as e:
    print(f"Waterfall plot skipped (reason: {e})")

# --- 5. Print metadata for the explained card ---
# Look up the original (unscaled, full-column) row for the explained instance.
original_card_data = full_data.loc[explain_idx[instance_to_explain_index]]
name_val = original_card_data.get("name", "N/A")
tcgp_val = original_card_data.get("tcgplayer_id", "N/A")
label_val = original_card_data.get(TARGET_COLUMN, None)
label_str = "RISE" if bool(label_val) else "NOT RISE" if label_val is not None else "N/A"

print("\n--- Card Explained in force_plot.html / waterfall_single.png ---")
print(f"Name: {name_val}")
print(f"TCGPlayer ID: {tcgp_val}")
print(f"Actual Outcome in Dataset: {label_str}")

# TODO: export the model in a format that can be published on Hugging Face so others can pull it down and use it
# TODO: include the SHAP charts (force_plot.html and summary_plot.png) explaining the model in the model card, along with additional evaluation metrics