Spaces:
Sleeping
Sleeping
Commit
·
5256800
1
Parent(s):
074528e
feat: Refine Gradio UI and improve model card
Browse files- .gitattributes +3 -0
- .idea/workspace.xml +8 -0
- README.md +10 -3
- __pycache__/network.cpython-312.pyc +0 -0
- app.py +52 -36
- explain_model.py +202 -0
- requirements.txt +2 -1
.gitattributes
CHANGED
|
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
model/model.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/pokemon_final_with_labels.csv filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
data/scaler.pkl filter=lfs diff=lfs merge=lfs -text
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="PropertiesComponent">{
|
| 4 |
+
"keyToString": {
|
| 5 |
+
"settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable"
|
| 6 |
+
}
|
| 7 |
+
}</component>
|
| 8 |
+
</project>
|
README.md
CHANGED
|
@@ -10,9 +10,16 @@ pinned: false
|
|
| 10 |
license: mit
|
| 11 |
tags:
|
| 12 |
- pytorch
|
|
|
|
|
|
|
| 13 |
- machine-learning
|
| 14 |
-
-
|
| 15 |
- price-prediction
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
---
|
| 17 |
|
| 18 |
## PokePrice: Pokémon Card Price Trend Predictor
|
|
@@ -20,10 +27,10 @@ tags:
|
|
| 20 |
This application uses a PyTorch-based neural network to predict whether the market price of a specific Pokémon card will rise by 30% or more over the next six months.
|
| 21 |
|
| 22 |
### How It Works
|
| 23 |
-
1. **
|
| 24 |
2. **Get Prediction:** The model analyzes various features of the selected card, such as its rarity, type, and historical price data, to make a prediction.
|
| 25 |
3. **View Results:** The application displays:
|
| 26 |
-
* The prediction (whether the price is expected to **RISE** or **NOT RISE**).
|
| 27 |
* The model's confidence level in the prediction.
|
| 28 |
* A direct link to view the card on TCGPlayer.com.
|
| 29 |
* The actual historical outcome if it exists in the dataset, for comparison.
|
|
|
|
| 10 |
license: mit
|
| 11 |
tags:
|
| 12 |
- pytorch
|
| 13 |
+
- scikit-learn
|
| 14 |
+
- gradio
|
| 15 |
- machine-learning
|
| 16 |
+
- tabular-classification
|
| 17 |
- price-prediction
|
| 18 |
+
- finance
|
| 19 |
+
- pokemon
|
| 20 |
+
- pokemon-cards
|
| 21 |
+
- tcg
|
| 22 |
+
- collectibles
|
| 23 |
---
|
| 24 |
|
| 25 |
## PokePrice: Pokémon Card Price Trend Predictor
|
|
|
|
| 27 |
This application uses a PyTorch-based neural network to predict whether the market price of a specific Pokémon card will rise by 30% or more over the next six months.
|
| 28 |
|
| 29 |
### How It Works
|
| 30 |
+
1. **Enter a Card ID:** Input the numeric TCGPlayer ID for a specific Pokémon card. You can find this ID in the URL of the card's page on the TCGPlayer website (e.g., `tcgplayer.com/product/84198/...`).
|
| 31 |
2. **Get Prediction:** The model analyzes various features of the selected card, such as its rarity, type, and historical price data, to make a prediction.
|
| 32 |
3. **View Results:** The application displays:
|
| 33 |
+
* The card's name and the prediction (whether the price is expected to **RISE** or **NOT RISE**).
|
| 34 |
* The model's confidence level in the prediction.
|
| 35 |
* A direct link to view the card on TCGPlayer.com.
|
| 36 |
* The actual historical outcome if it exists in the dataset, for comparison.
|
__pycache__/network.cpython-312.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import joblib
|
|
| 4 |
import pandas as pd
|
| 5 |
import os
|
| 6 |
import json
|
| 7 |
-
import re
|
| 8 |
from safetensors.torch import load_file
|
| 9 |
from typing import List, Tuple
|
| 10 |
from network import PricePredictor
|
|
@@ -16,6 +15,7 @@ DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
|
|
| 16 |
TARGET_COLUMN = 'price_will_rise_30_in_6m'
|
| 17 |
|
| 18 |
|
|
|
|
| 19 |
def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
|
| 20 |
config_path = os.path.join(model_dir, "config.json")
|
| 21 |
with open(config_path, "r") as f:
|
|
@@ -40,51 +40,52 @@ def perform_prediction(model: torch.nn.Module, scaler, input_features: pd.Series
|
|
| 40 |
|
| 41 |
return predicted_class, probability
|
| 42 |
|
|
|
|
| 43 |
try:
|
| 44 |
model, feature_columns = load_model_and_config(MODEL_DIR)
|
| 45 |
scaler = joblib.load(SCALER_PATH)
|
| 46 |
full_data = pd.read_csv(DATA_PATH)
|
| 47 |
-
|
| 48 |
-
full_data['display_name'] = full_data.apply(
|
| 49 |
-
lambda row: f"{row['name']} (ID: {row['tcgplayer_id']})", axis=1
|
| 50 |
-
)
|
| 51 |
-
card_choices = sorted(full_data['display_name'].unique().tolist())
|
| 52 |
ASSETS_LOADED = True
|
| 53 |
except FileNotFoundError as e:
|
| 54 |
print(f"Error loading necessary files: {e}")
|
| 55 |
print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
|
| 56 |
-
card_choices = ["Error: Model or data files not found. Check logs."]
|
| 57 |
ASSETS_LOADED = False
|
| 58 |
|
| 59 |
|
| 60 |
-
def predict_price_trend(
|
| 61 |
if not ASSETS_LOADED:
|
| 62 |
return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
return f"## Input Error\nCould not parse ID from '{card_display_name}'. Please select a valid card from the dropdown."
|
| 68 |
|
| 69 |
-
card_data = full_data[full_data['tcgplayer_id'] == tcgplayer_id]
|
| 70 |
if card_data.empty:
|
| 71 |
-
return f"##
|
| 72 |
|
|
|
|
| 73 |
card_sample = card_data.iloc[0]
|
| 74 |
sample_features = card_sample[feature_columns]
|
| 75 |
|
|
|
|
| 76 |
predicted_class, probability = perform_prediction(model, scaler, sample_features)
|
| 77 |
|
| 78 |
prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
|
| 79 |
confidence = probability if predicted_class else 1 - probability
|
| 80 |
-
|
| 81 |
-
# Construct the TCGPlayer link
|
| 82 |
tcgplayer_link = f"https://www.tcgplayer.com/product/{tcgplayer_id}?Language=English"
|
| 83 |
|
|
|
|
| 84 |
true_label_text = ""
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
output = f"""
|
| 90 |
## 🔮 Prediction Report for {card_sample['name']}
|
|
@@ -96,22 +97,37 @@ def predict_price_trend(card_display_name: str) -> str:
|
|
| 96 |
return output
|
| 97 |
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
| 117 |
-
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
import os
|
| 6 |
import json
|
|
|
|
| 7 |
from safetensors.torch import load_file
|
| 8 |
from typing import List, Tuple
|
| 9 |
from network import PricePredictor
|
|
|
|
| 15 |
TARGET_COLUMN = 'price_will_rise_30_in_6m'
|
| 16 |
|
| 17 |
|
| 18 |
+
|
| 19 |
def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
|
| 20 |
config_path = os.path.join(model_dir, "config.json")
|
| 21 |
with open(config_path, "r") as f:
|
|
|
|
| 40 |
|
| 41 |
return predicted_class, probability
|
| 42 |
|
| 43 |
+
# --- Asset Loading ---
|
| 44 |
try:
|
| 45 |
model, feature_columns = load_model_and_config(MODEL_DIR)
|
| 46 |
scaler = joblib.load(SCALER_PATH)
|
| 47 |
full_data = pd.read_csv(DATA_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
ASSETS_LOADED = True
|
| 49 |
except FileNotFoundError as e:
|
| 50 |
print(f"Error loading necessary files: {e}")
|
| 51 |
print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
|
|
|
|
| 52 |
ASSETS_LOADED = False
|
| 53 |
|
| 54 |
|
| 55 |
+
def predict_price_trend(card_identifier: str) -> str:
|
| 56 |
if not ASSETS_LOADED:
|
| 57 |
return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."
|
| 58 |
+
|
| 59 |
+
if not card_identifier or not card_identifier.strip().isdigit():
|
| 60 |
+
return "## Input Error\nPlease enter a valid, numeric TCGPlayer ID."
|
| 61 |
|
| 62 |
+
# --- Find Card Logic ---
|
| 63 |
+
card_id = int(card_identifier.strip())
|
| 64 |
+
card_data = full_data[full_data['tcgplayer_id'] == card_id]
|
|
|
|
| 65 |
|
|
|
|
| 66 |
if card_data.empty:
|
| 67 |
+
return f"## Card Not Found\nCould not find a card with TCGPlayer ID '{card_id}'. Please check the ID and try again."
|
| 68 |
|
| 69 |
+
# Since tcgplayer_id is unique, we can safely take the first (and only) row.
|
| 70 |
card_sample = card_data.iloc[0]
|
| 71 |
sample_features = card_sample[feature_columns]
|
| 72 |
|
| 73 |
+
# --- Prediction Logic ---
|
| 74 |
predicted_class, probability = perform_prediction(model, scaler, sample_features)
|
| 75 |
|
| 76 |
prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
|
| 77 |
confidence = probability if predicted_class else 1 - probability
|
| 78 |
+
tcgplayer_id = card_sample['tcgplayer_id']
|
|
|
|
| 79 |
tcgplayer_link = f"https://www.tcgplayer.com/product/{tcgplayer_id}?Language=English"
|
| 80 |
|
| 81 |
+
# --- Output Formatting ---
|
| 82 |
true_label_text = ""
|
| 83 |
+
try:
|
| 84 |
+
if TARGET_COLUMN in card_sample and pd.notna(card_sample[TARGET_COLUMN]):
|
| 85 |
+
true_label = bool(card_sample[TARGET_COLUMN])
|
| 86 |
+
true_label_text = f"\n- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**."
|
| 87 |
+
except (KeyError, TypeError):
|
| 88 |
+
pass # If target column is missing or value is invalid, just skip this part.
|
| 89 |
|
| 90 |
output = f"""
|
| 91 |
## 🔮 Prediction Report for {card_sample['name']}
|
|
|
|
| 97 |
return output
|
| 98 |
|
| 99 |
|
| 100 |
+
# --- Gradio UI ---
|
| 101 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="PricePoke Predictor") as demo:
|
| 102 |
+
gr.Markdown(
|
| 103 |
+
"""
|
| 104 |
+
# 📈 PricePoke: Pokémon Card Price Trend Predictor
|
| 105 |
+
Enter a Pokémon card's TCGPlayer ID to predict whether its market price will increase by 30% or more over the next 6 months.
|
| 106 |
+
This model was trained on historical TCGPlayer market data.
|
| 107 |
+
"""
|
| 108 |
+
)
|
| 109 |
+
with gr.Row():
|
| 110 |
+
with gr.Column(scale=1):
|
| 111 |
+
card_input = gr.Textbox(
|
| 112 |
+
label="TCGPlayer ID",
|
| 113 |
+
placeholder="e.g., '84198'",
|
| 114 |
+
info="Find the ID in the card's URL on TCGPlayer's website (e.g., tcgplayer.com/product/84198/... has ID 84198)."
|
| 115 |
+
)
|
| 116 |
+
predict_button = gr.Button("Predict Trend", variant="primary")
|
| 117 |
+
|
| 118 |
+
gr.Markdown("---")
|
| 119 |
+
gr.Markdown("### Example Cards")
|
| 120 |
+
if ASSETS_LOADED:
|
| 121 |
+
example_df = full_data.sample(5, random_state=42)[['name', 'tcgplayer_id']]
|
| 122 |
+
gr.Markdown(example_df.to_markdown(index=False))
|
| 123 |
+
else:
|
| 124 |
+
gr.Markdown("Could not load examples.")
|
| 125 |
+
|
| 126 |
+
with gr.Column(scale=2):
|
| 127 |
+
output_markdown = gr.Markdown()
|
| 128 |
+
|
| 129 |
+
predict_button.click(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
|
| 130 |
+
card_input.submit(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
|
| 131 |
|
| 132 |
if __name__ == "__main__":
|
| 133 |
+
demo.launch()
|
explain_model.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# explain_model.py
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import torch
|
| 9 |
+
import joblib
|
| 10 |
+
import shap
|
| 11 |
+
import matplotlib.pyplot as plt
|
| 12 |
+
from safetensors.torch import load_file
|
| 13 |
+
from network import PricePredictor
|
| 14 |
+
|
| 15 |
+
# --- 0. Config ---
|
| 16 |
+
MODEL_DIR = "model"
|
| 17 |
+
DATA_DIR = "data"
|
| 18 |
+
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
|
| 19 |
+
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
|
| 20 |
+
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
|
| 21 |
+
TARGET_COLUMN = "price_will_rise_30_in_6m"
|
| 22 |
+
|
| 23 |
+
# --- 1. Load model & assets ---
|
| 24 |
+
with open(CONFIG_PATH, "r") as f:
|
| 25 |
+
config = json.load(f)
|
| 26 |
+
|
| 27 |
+
feature_columns = config["feature_columns"]
|
| 28 |
+
input_size = config["input_size"]
|
| 29 |
+
|
| 30 |
+
model = PricePredictor(input_size=input_size)
|
| 31 |
+
model.load_state_dict(load_file(os.path.join(MODEL_DIR, "model.safetensors")))
|
| 32 |
+
model.eval()
|
| 33 |
+
|
| 34 |
+
scaler = joblib.load(SCALER_PATH)
|
| 35 |
+
full_data = pd.read_csv(DATA_PATH)
|
| 36 |
+
|
| 37 |
+
# Sanity checks
|
| 38 |
+
missing_cols = [c for c in feature_columns if c not in full_data.columns]
|
| 39 |
+
if missing_cols:
|
| 40 |
+
raise ValueError(f"Missing required feature columns in CSV: {missing_cols}")
|
| 41 |
+
|
| 42 |
+
features_df = full_data[feature_columns]
|
| 43 |
+
if features_df.shape[1] != input_size:
|
| 44 |
+
raise ValueError(
|
| 45 |
+
f"Config input_size={input_size}, but CSV provides {features_df.shape[1]} features. "
|
| 46 |
+
f"Ensure config['feature_columns'] matches the trained model."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# --- 2. Prepare Data for SHAP ---
|
| 50 |
+
bg_n = min(100, len(features_df))
|
| 51 |
+
explain_n = min(10, len(features_df))
|
| 52 |
+
|
| 53 |
+
background_idx = features_df.sample(n=bg_n, random_state=42).index
|
| 54 |
+
explain_idx = features_df.sample(n=explain_n, random_state=1).index
|
| 55 |
+
|
| 56 |
+
background_data = features_df.loc[background_idx]
|
| 57 |
+
explain_instances = features_df.loc[explain_idx]
|
| 58 |
+
|
| 59 |
+
# Use arrays for scaler to avoid feature-name warnings
|
| 60 |
+
background_data_scaled = scaler.transform(background_data.values)
|
| 61 |
+
explain_instances_scaled = scaler.transform(explain_instances.values)
|
| 62 |
+
|
| 63 |
+
background_tensor = torch.tensor(background_data_scaled, dtype=torch.float32) # no grad
|
| 64 |
+
explain_tensor = torch.tensor(explain_instances_scaled, dtype=torch.float32, requires_grad=True)
|
| 65 |
+
|
| 66 |
+
# --- Helpers ---
|
| 67 |
+
def get_shap_explanations(model, background_tensor, explain_tensor):
    """Compute SHAP explanations, preferring DeepExplainer over GradientExplainer.

    DeepExplainer is tried first. If building or calling it raises any
    exception, the error is printed and GradientExplainer is used instead
    (after making sure ``explain_tensor`` tracks gradients). The explainer's
    ``expected_value`` attribute (``None`` when absent) is copied onto the
    result as ``_expected_value_hint``.

    Returns:
        ``(explanation, name)`` where ``name`` is ``"deep"`` or ``"grad"``.
    """
    try:
        print("Initializing SHAP DeepExplainer...")
        deep = shap.DeepExplainer(model, background_tensor)
        print("Calculating SHAP values for the sample...")
        deep_result = deep(explain_tensor)
        deep_result._expected_value_hint = getattr(deep, "expected_value", None)
        return deep_result, "deep"
    except Exception as e:
        # Broad on purpose: any DeepExplainer failure triggers the fallback.
        print(f"[DeepExplainer failed: {e}] Falling back to GradientExplainer...")

    explain_tensor.requires_grad_(True)
    grad = shap.GradientExplainer(model, background_tensor)
    grad_result = grad(explain_tensor)
    grad_result._expected_value_hint = getattr(grad, "expected_value", None)
    return grad_result, "grad"
|
| 83 |
+
|
| 84 |
+
def _to_float_or_none(produce):
    """Return ``float(produce())``, or ``None`` if producing/converting fails."""
    try:
        return float(produce())
    except Exception:
        # Deliberate best-effort: the objects SHAP hands back differ between
        # versions and explainers, so any failure just means "try the next
        # candidate" rather than aborting the whole explanation run.
        return None


def compute_base_value_safe(shap_explanation, instance_idx, model, background_tensor):
    """Return a scalar SHAP base value, trying several sources in order.

    Sources, first success wins:
      1. ``shap_explanation.base_values[instance_idx]``, then the whole
         ``base_values`` squeezed to a scalar.
      2. ``shap_explanation._expected_value_hint`` squeezed to a scalar, then
         its mean.
      3. The mean model output over ``background_tensor``.

    Args:
        shap_explanation: Object that may carry ``base_values`` and/or
            ``_expected_value_hint`` attributes.
        instance_idx: Index of the explained instance within ``base_values``.
        model: Callable applied to ``background_tensor`` in the final fallback.
        background_tensor: Background samples used for the final fallback.

    Returns:
        The base value as a Python float.
    """
    base_values = getattr(shap_explanation, "base_values", None)
    if base_values is not None:
        for candidate in (
            lambda: np.squeeze(base_values[instance_idx]),
            lambda: np.squeeze(base_values),
        ):
            value = _to_float_or_none(candidate)
            if value is not None:
                return value

    hint = getattr(shap_explanation, "_expected_value_hint", None)
    if hint is not None:
        for candidate in (lambda: np.squeeze(hint), lambda: np.mean(hint)):
            value = _to_float_or_none(candidate)
            if value is not None:
                return value

    # Last resort: the base value is the expected model output over the
    # background set, E[f(x)]. Evaluating the model at the mean input,
    # f(E[x]), gives a different number for a non-linear network, so average
    # the outputs instead of the inputs.
    with torch.no_grad():
        outputs = model(background_tensor).detach().cpu()
    return float(outputs.mean().item())
|
| 108 |
+
|
| 109 |
+
def stack_sample_shap_values(exp, n_features_expected):
    """Rebuild a ``(n_samples, n_features)`` SHAP matrix from per-sample slices.

    Some SHAP versions return ``exp.values`` with a trailing singleton axis or
    other unexpected shapes, while ``exp[i].values`` flattens cleanly to one
    value per feature. Each per-sample slice is flattened and stacked.

    Args:
        exp: Explanation-like object supporting ``exp.values`` and ``exp[i]``.
        n_features_expected: Number of feature columns each row must have.

    Returns:
        A 2-D NumPy array with one row per sample. With zero samples an empty
        ``(0, n_features_expected)`` array is returned.

    Raises:
        RuntimeError: If the rebuilt rows do not have ``n_features_expected``
            columns.
    """
    # hasattr(values, "__len__") is true for every ndarray, including 0-d ones
    # whose len() raises, so ask for the length and fall back on failure.
    try:
        n_samples = len(exp.values)
    except TypeError:
        n_samples = len(exp)

    if n_samples == 0:
        # np.vstack refuses an empty list; an empty matrix is the honest answer.
        return np.empty((0, n_features_expected), dtype=float)

    matrix = np.vstack(
        [np.asarray(exp[i].values).reshape(-1) for i in range(n_samples)]
    )  # (n_samples, n_features)
    if matrix.shape[1] != n_features_expected:
        raise RuntimeError(
            f"Rebuilt SHAP matrix has shape {matrix.shape}; expected n_features={n_features_expected}."
        )
    return matrix
|
| 127 |
+
|
| 128 |
+
# --- 3. Compute SHAP explanations ---
|
| 129 |
+
shap_explanation, _ = get_shap_explanations(model, background_tensor, explain_tensor)
|
| 130 |
+
print("Calculation complete.")
|
| 131 |
+
|
| 132 |
+
# Attach unscaled display data for pretty plotting
|
| 133 |
+
shap_explanation.display_data = explain_instances.values
|
| 134 |
+
shap_explanation.feature_names = feature_columns
|
| 135 |
+
|
| 136 |
+
# --- 4a. Global Feature Importance (Bar / Summary) ---
|
| 137 |
+
print("\nGenerating global feature importance plot (summary_plot.png)...")
|
| 138 |
+
|
| 139 |
+
# Robustly build a (n_samples, n_features) matrix by stacking per-sample vectors
|
| 140 |
+
shap_vals_matrix = stack_sample_shap_values(shap_explanation, n_features_expected=len(feature_columns))
|
| 141 |
+
|
| 142 |
+
mean_abs_shap = np.abs(shap_vals_matrix).mean(axis=0) # (n_features,)
|
| 143 |
+
|
| 144 |
+
# Build a fresh Explanation with values aligned to feature_names
|
| 145 |
+
plot_explanation = shap.Explanation(values=mean_abs_shap, feature_names=feature_columns)
|
| 146 |
+
|
| 147 |
+
plt.figure()
|
| 148 |
+
shap.plots.bar(plot_explanation, show=False)
|
| 149 |
+
plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)")
|
| 150 |
+
plt.savefig("summary_plot.png", bbox_inches="tight")
|
| 151 |
+
plt.close()
|
| 152 |
+
print("Saved: summary_plot.png")
|
| 153 |
+
|
| 154 |
+
# --- 4b. Local Explanation (Force Plot) ---
|
| 155 |
+
print("\nGenerating local explanation for one card (force_plot.html)...")
|
| 156 |
+
instance_to_explain_index = 0
|
| 157 |
+
single_explanation = shap_explanation[instance_to_explain_index]
|
| 158 |
+
|
| 159 |
+
# Some SHAP versions drop display_data on slicing; pull directly if needed
|
| 160 |
+
if getattr(single_explanation, "display_data", None) is None:
|
| 161 |
+
row_unscaled = explain_instances.values[instance_to_explain_index]
|
| 162 |
+
else:
|
| 163 |
+
row_unscaled = single_explanation.display_data
|
| 164 |
+
features_row = np.atleast_2d(np.asarray(row_unscaled, dtype=float))
|
| 165 |
+
|
| 166 |
+
base_val = compute_base_value_safe(shap_explanation, instance_to_explain_index, model, background_tensor)
|
| 167 |
+
phi = np.asarray(single_explanation.values).reshape(-1,) # (n_features,)
|
| 168 |
+
|
| 169 |
+
force_plot = shap.force_plot(
|
| 170 |
+
base_val,
|
| 171 |
+
phi,
|
| 172 |
+
features=features_row,
|
| 173 |
+
feature_names=feature_columns
|
| 174 |
+
)
|
| 175 |
+
shap.save_html("force_plot.html", force_plot)
|
| 176 |
+
print("Saved: force_plot.html (open in a browser)")
|
| 177 |
+
|
| 178 |
+
# --- 4c. Optional: local waterfall PNG (often clearer) ---
|
| 179 |
+
try:
|
| 180 |
+
print("Generating local waterfall plot (waterfall_single.png)...")
|
| 181 |
+
plt.figure()
|
| 182 |
+
shap.plots.waterfall(single_explanation, show=False, max_display=20)
|
| 183 |
+
plt.savefig("waterfall_single.png", bbox_inches="tight")
|
| 184 |
+
plt.close()
|
| 185 |
+
print("Saved: waterfall_single.png")
|
| 186 |
+
except Exception as e:
|
| 187 |
+
print(f"Waterfall plot skipped (reason: {e})")
|
| 188 |
+
|
| 189 |
+
# --- 5. Print metadata for the explained card ---
# Look up the original (unscaled) CSV row for the instance explained above.
original_card_data = full_data.loc[explain_idx[instance_to_explain_index]]
name_val = original_card_data.get("name", "N/A")
tcgp_val = original_card_data.get("tcgplayer_id", "N/A")
label_val = original_card_data.get(TARGET_COLUMN, None)

# A missing label read from CSV is NaN, not None, and bool(NaN) is True, so
# check for missing values explicitly before interpreting the label.
if label_val is None or pd.isna(label_val):
    label_str = "N/A"
elif bool(label_val):
    label_str = "RISE"
else:
    label_str = "NOT RISE"

print("\n--- Card Explained in force_plot.html / waterfall_single.png ---")
print(f"Name: {name_val}")
print(f"TCGPlayer ID: {tcgp_val}")
print(f"Actual Outcome in Dataset: {label_str}")
|
| 200 |
+
|
| 201 |
+
# TODO: convert the model into a format that can be published on Hugging Face as a model others can pull down and use
|
| 202 |
+
# TODO: include the SHAP charts (force_plot.html and summary_plot.png) in the model card, and compute additional evaluation metrics to report there
|
requirements.txt
CHANGED
|
@@ -3,4 +3,5 @@ pandas
|
|
| 3 |
numpy
|
| 4 |
scikit-learn
|
| 5 |
safetensors
|
| 6 |
-
gradio
|
|
|
|
|
|
| 3 |
numpy
|
| 4 |
scikit-learn
|
| 5 |
safetensors
|
| 6 |
+
gradio
|
| 7 |
+
tabulate
|