OffWorldTensor committed on
Commit
5256800
·
1 Parent(s): 074528e

feat: Refine Gradio UI and improve model card

Browse files
.gitattributes CHANGED
@@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.csv filter=lfs diff=lfs merge=lfs -text
37
+ model/model.safetensors filter=lfs diff=lfs merge=lfs -text
38
+ data/pokemon_final_with_labels.csv filter=lfs diff=lfs merge=lfs -text
39
+ data/scaler.pkl filter=lfs diff=lfs merge=lfs -text
.idea/workspace.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="PropertiesComponent">{
4
+ &quot;keyToString&quot;: {
5
+ &quot;settings.editor.selected.configurable&quot;: &quot;com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable&quot;
6
+ }
7
+ }</component>
8
+ </project>
README.md CHANGED
@@ -10,9 +10,16 @@ pinned: false
10
  license: mit
11
  tags:
12
  - pytorch
 
 
13
  - machine-learning
14
- - pokemon
15
  - price-prediction
 
 
 
 
 
16
  ---
17
 
18
  ## PokePrice: Pokémon Card Price Trend Predictor
@@ -20,10 +27,10 @@ tags:
20
  This application uses a PyTorch-based neural network to predict whether the market price of a specific Pokémon card will rise by 30% or more over the next six months.
21
 
22
  ### How It Works
23
- 1. **Select a Card:** Choose a Pokémon card from the dropdown menu. The list is populated from a dataset containing historical price information.
24
  2. **Get Prediction:** The model analyzes various features of the selected card, such as its rarity, type, and historical price data, to make a prediction.
25
  3. **View Results:** The application displays:
26
- * The prediction (whether the price is expected to **RISE** or **NOT RISE**).
27
  * The model's confidence level in the prediction.
28
  * A direct link to view the card on TCGPlayer.com.
29
  * The actual historical outcome if it exists in the dataset, for comparison.
 
10
  license: mit
11
  tags:
12
  - pytorch
13
+ - scikit-learn
14
+ - gradio
15
  - machine-learning
16
+ - tabular-classification
17
  - price-prediction
18
+ - finance
19
+ - pokemon
20
+ - pokemon-cards
21
+ - tcg
22
+ - collectibles
23
  ---
24
 
25
  ## PokePrice: Pokémon Card Price Trend Predictor
 
27
  This application uses a PyTorch-based neural network to predict whether the market price of a specific Pokémon card will rise by 30% or more over the next six months.
28
 
29
  ### How It Works
30
+ 1. **Enter a Card ID:** Input the numeric TCGPlayer ID for a specific Pokémon card. You can find this ID in the URL of the card's page on the TCGPlayer website (e.g., `tcgplayer.com/product/84198/...`).
31
  2. **Get Prediction:** The model analyzes various features of the selected card, such as its rarity, type, and historical price data, to make a prediction.
32
  3. **View Results:** The application displays:
33
+ * The card's name and the prediction (whether the price is expected to **RISE** or **NOT RISE**).
34
  * The model's confidence level in the prediction.
35
  * A direct link to view the card on TCGPlayer.com.
36
  * The actual historical outcome if it exists in the dataset, for comparison.
__pycache__/network.cpython-312.pyc ADDED
Binary file (1.53 kB). View file
 
app.py CHANGED
@@ -4,7 +4,6 @@ import joblib
4
  import pandas as pd
5
  import os
6
  import json
7
- import re
8
  from safetensors.torch import load_file
9
  from typing import List, Tuple
10
  from network import PricePredictor
@@ -16,6 +15,7 @@ DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
16
  TARGET_COLUMN = 'price_will_rise_30_in_6m'
17
 
18
 
 
19
  def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
20
  config_path = os.path.join(model_dir, "config.json")
21
  with open(config_path, "r") as f:
@@ -40,51 +40,52 @@ def perform_prediction(model: torch.nn.Module, scaler, input_features: pd.Series
40
 
41
  return predicted_class, probability
42
 
 
43
  try:
44
  model, feature_columns = load_model_and_config(MODEL_DIR)
45
  scaler = joblib.load(SCALER_PATH)
46
  full_data = pd.read_csv(DATA_PATH)
47
-
48
- full_data['display_name'] = full_data.apply(
49
- lambda row: f"{row['name']} (ID: {row['tcgplayer_id']})", axis=1
50
- )
51
- card_choices = sorted(full_data['display_name'].unique().tolist())
52
  ASSETS_LOADED = True
53
  except FileNotFoundError as e:
54
  print(f"Error loading necessary files: {e}")
55
  print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
56
- card_choices = ["Error: Model or data files not found. Check logs."]
57
  ASSETS_LOADED = False
58
 
59
 
60
- def predict_price_trend(card_display_name: str) -> str:
61
  if not ASSETS_LOADED:
62
  return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."
 
 
 
63
 
64
- try:
65
- tcgplayer_id = int(re.search(r'\(ID: (\d+)\)', card_display_name).group(1))
66
- except (AttributeError, ValueError):
67
- return f"## Input Error\nCould not parse ID from '{card_display_name}'. Please select a valid card from the dropdown."
68
 
69
- card_data = full_data[full_data['tcgplayer_id'] == tcgplayer_id]
70
  if card_data.empty:
71
- return f"## Internal Error\nCould not find data for ID {tcgplayer_id}. Please restart the Space or select another card."
72
 
 
73
  card_sample = card_data.iloc[0]
74
  sample_features = card_sample[feature_columns]
75
 
 
76
  predicted_class, probability = perform_prediction(model, scaler, sample_features)
77
 
78
  prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
79
  confidence = probability if predicted_class else 1 - probability
80
-
81
- # Construct the TCGPlayer link
82
  tcgplayer_link = f"https://www.tcgplayer.com/product/{tcgplayer_id}?Language=English"
83
 
 
84
  true_label_text = ""
85
- if TARGET_COLUMN in card_sample and pd.notna(card_sample[TARGET_COLUMN]):
86
- true_label = bool(card_sample[TARGET_COLUMN])
87
- true_label_text = f"\n- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**."
 
 
 
88
 
89
  output = f"""
90
  ## 🔮 Prediction Report for {card_sample['name']}
@@ -96,22 +97,37 @@ def predict_price_trend(card_display_name: str) -> str:
96
  return output
97
 
98
 
99
- iface = gr.Interface(
100
- fn=predict_price_trend,
101
- inputs=gr.Dropdown(
102
- choices=card_choices,
103
- label="Select a Pokémon Card",
104
- info="Choose a card from the dataset to predict its price trend."
105
- ),
106
- outputs=gr.Markdown(),
107
- title="PricePoke: Pokémon Card Price Trend Predictor",
108
- description="""
109
- Select a Pokémon card to predict whether its market price will increase by 30% or more over the next 6 months.
110
- This model was trained on historical TCGPlayer market data.
111
- """,
112
- examples=[[card_choices[0]] if card_choices and ASSETS_LOADED else []],
113
- allow_flagging="never"
114
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
- iface.launch()
 
4
  import pandas as pd
5
  import os
6
  import json
 
7
  from safetensors.torch import load_file
8
  from typing import List, Tuple
9
  from network import PricePredictor
 
15
  TARGET_COLUMN = 'price_will_rise_30_in_6m'
16
 
17
 
18
+
19
  def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
20
  config_path = os.path.join(model_dir, "config.json")
21
  with open(config_path, "r") as f:
 
40
 
41
  return predicted_class, probability
42
 
43
+ # --- Asset Loading ---
44
  try:
45
  model, feature_columns = load_model_and_config(MODEL_DIR)
46
  scaler = joblib.load(SCALER_PATH)
47
  full_data = pd.read_csv(DATA_PATH)
 
 
 
 
 
48
  ASSETS_LOADED = True
49
  except FileNotFoundError as e:
50
  print(f"Error loading necessary files: {e}")
51
  print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
 
52
  ASSETS_LOADED = False
53
 
54
 
55
+ def predict_price_trend(card_identifier: str) -> str:
56
  if not ASSETS_LOADED:
57
  return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."
58
+
59
+ if not card_identifier or not card_identifier.strip().isdigit():
60
+ return "## Input Error\nPlease enter a valid, numeric TCGPlayer ID."
61
 
62
+ # --- Find Card Logic ---
63
+ card_id = int(card_identifier.strip())
64
+ card_data = full_data[full_data['tcgplayer_id'] == card_id]
 
65
 
 
66
  if card_data.empty:
67
+ return f"## Card Not Found\nCould not find a card with TCGPlayer ID '{card_id}'. Please check the ID and try again."
68
 
69
+ # Since tcgplayer_id is unique, we can safely take the first (and only) row.
70
  card_sample = card_data.iloc[0]
71
  sample_features = card_sample[feature_columns]
72
 
73
+ # --- Prediction Logic ---
74
  predicted_class, probability = perform_prediction(model, scaler, sample_features)
75
 
76
  prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
77
  confidence = probability if predicted_class else 1 - probability
78
+ tcgplayer_id = card_sample['tcgplayer_id']
 
79
  tcgplayer_link = f"https://www.tcgplayer.com/product/{tcgplayer_id}?Language=English"
80
 
81
+ # --- Output Formatting ---
82
  true_label_text = ""
83
+ try:
84
+ if TARGET_COLUMN in card_sample and pd.notna(card_sample[TARGET_COLUMN]):
85
+ true_label = bool(card_sample[TARGET_COLUMN])
86
+ true_label_text = f"\n- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**."
87
+ except (KeyError, TypeError):
88
+ pass # If target column is missing or value is invalid, just skip this part.
89
 
90
  output = f"""
91
  ## 🔮 Prediction Report for {card_sample['name']}
 
97
  return output
98
 
99
 
100
+ # --- Gradio UI ---
101
+ with gr.Blocks(theme=gr.themes.Soft(), title="PricePoke Predictor") as demo:
102
+ gr.Markdown(
103
+ """
104
+ # 📈 PricePoke: Pokémon Card Price Trend Predictor
105
+ Enter a Pokémon card's TCGPlayer ID to predict whether its market price will increase by 30% or more over the next 6 months.
106
+ This model was trained on historical TCGPlayer market data.
107
+ """
108
+ )
109
+ with gr.Row():
110
+ with gr.Column(scale=1):
111
+ card_input = gr.Textbox(
112
+ label="TCGPlayer ID",
113
+ placeholder="e.g., '84198'",
114
+ info="Find the ID in the card's URL on TCGPlayer's website (e.g., tcgplayer.com/product/84198/... has ID 84198)."
115
+ )
116
+ predict_button = gr.Button("Predict Trend", variant="primary")
117
+
118
+ gr.Markdown("---")
119
+ gr.Markdown("### Example Cards")
120
+ if ASSETS_LOADED:
121
+ example_df = full_data.sample(5, random_state=42)[['name', 'tcgplayer_id']]
122
+ gr.Markdown(example_df.to_markdown(index=False))
123
+ else:
124
+ gr.Markdown("Could not load examples.")
125
+
126
+ with gr.Column(scale=2):
127
+ output_markdown = gr.Markdown()
128
+
129
+ predict_button.click(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
130
+ card_input.submit(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
131
 
132
  if __name__ == "__main__":
133
+ demo.launch()
explain_model.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # explain_model.py
3
+
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import joblib
10
+ import shap
11
+ import matplotlib.pyplot as plt
12
+ from safetensors.torch import load_file
13
+ from network import PricePredictor
14
+
15
# --- 0. Config ---
# Locations of the trained model artifacts and the labeled dataset.
MODEL_DIR = "model"
DATA_DIR = "data"
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")
CONFIG_PATH = os.path.join(MODEL_DIR, "config.json")
# Binary label column: whether the card's price rose 30%+ within 6 months.
TARGET_COLUMN = "price_will_rise_30_in_6m"
22
+
23
# --- 1. Load model & assets ---
# Read the training-time config so the feature order matches the checkpoint.
with open(CONFIG_PATH, "r") as f:
    config = json.load(f)

feature_columns = config["feature_columns"]
input_size = config["input_size"]

# Rebuild the network and load weights from the safetensors checkpoint.
model = PricePredictor(input_size=input_size)
model.load_state_dict(load_file(os.path.join(MODEL_DIR, "model.safetensors")))
model.eval()  # inference mode: disables dropout/batchnorm updates

scaler = joblib.load(SCALER_PATH)
full_data = pd.read_csv(DATA_PATH)

# Sanity checks: fail fast with a clear message instead of a cryptic
# KeyError / shape mismatch deeper in the SHAP pipeline.
missing_cols = [c for c in feature_columns if c not in full_data.columns]
if missing_cols:
    raise ValueError(f"Missing required feature columns in CSV: {missing_cols}")

features_df = full_data[feature_columns]
if features_df.shape[1] != input_size:
    raise ValueError(
        f"Config input_size={input_size}, but CSV provides {features_df.shape[1]} features. "
        f"Ensure config['feature_columns'] matches the trained model."
    )
48
+
49
# --- 2. Prepare Data for SHAP ---
# Cap the sample sizes so small datasets still work without raising.
bg_n = min(100, len(features_df))
explain_n = min(10, len(features_df))

# Fixed seeds keep the background/explained samples reproducible run to run.
background_idx = features_df.sample(n=bg_n, random_state=42).index
explain_idx = features_df.sample(n=explain_n, random_state=1).index

background_data = features_df.loc[background_idx]
explain_instances = features_df.loc[explain_idx]

# Use arrays for scaler to avoid feature-name warnings
background_data_scaled = scaler.transform(background_data.values)
explain_instances_scaled = scaler.transform(explain_instances.values)

background_tensor = torch.tensor(background_data_scaled, dtype=torch.float32)  # no grad
# GradientExplainer (fallback path) needs gradients on the explained inputs.
explain_tensor = torch.tensor(explain_instances_scaled, dtype=torch.float32, requires_grad=True)
65
+
66
# --- Helpers ---
def get_shap_explanations(model, background_tensor, explain_tensor):
    """Try DeepExplainer then fall back to GradientExplainer.

    Returns:
        (explanation, explainer_used_name) where the name is "deep" or "grad".

    The explainer's ``expected_value`` (when present) is stashed on the
    Explanation as ``_expected_value_hint`` so downstream code can recover a
    base value even on SHAP versions whose Explanation lacks ``base_values``.
    """
    try:
        print("Initializing SHAP DeepExplainer...")
        explainer = shap.DeepExplainer(model, background_tensor)
        print("Calculating SHAP values for the sample...")
        exp = explainer(explain_tensor)
        setattr(exp, "_expected_value_hint", getattr(explainer, "expected_value", None))
        return exp, "deep"
    except Exception as e:
        # Broad catch is deliberate: DeepExplainer fails in version-specific
        # ways; any failure should trigger the gradient-based fallback.
        print(f"[DeepExplainer failed: {e}] Falling back to GradientExplainer...")
        explain_tensor.requires_grad_(True)
        grad_explainer = shap.GradientExplainer(model, background_tensor)
        exp = grad_explainer(explain_tensor)
        setattr(exp, "_expected_value_hint", getattr(grad_explainer, "expected_value", None))
        return exp, "grad"
83
+
84
def compute_base_value_safe(shap_explanation, instance_idx, model, background_tensor):
    """Return a scalar base value robustly across SHAP versions.

    Resolution order:
      1. ``shap_explanation.base_values`` (per-instance first, then global);
      2. the explainer's expected value stashed as ``_expected_value_hint``;
      3. the model's own output evaluated at the mean background row.

    ``model`` and ``background_tensor`` are only touched in case 3, so they
    may be None when the explanation already carries a base value.
    """
    bv = getattr(shap_explanation, "base_values", None)
    if bv is not None:
        try:
            return float(np.squeeze(bv[instance_idx]))
        except Exception:
            try:
                # base_values may be a single global scalar rather than per-row
                return float(np.squeeze(bv))
            except Exception:
                pass
    ev = getattr(shap_explanation, "_expected_value_hint", None)
    if ev is not None:
        try:
            return float(np.squeeze(ev))
        except Exception:
            try:
                # multi-output explainers expose a vector of expected values
                return float(np.mean(ev))
            except Exception:
                pass
    # Last resort: run the model on the centroid of the background data.
    with torch.no_grad():
        mu = background_tensor.mean(dim=0, keepdim=True)
        out = model(mu).detach().cpu().squeeze()
        return float(out.mean().item()) if out.numel() > 1 else float(out.item())
108
+
109
def stack_sample_shap_values(exp, n_features_expected):
    """Rebuild a (n_samples, n_features) SHAP value matrix.

    Some SHAP versions return ``exp.values`` with shape (n_samples, 1) or
    other oddities, but ``exp[i].values`` is typically the correct 1-D
    (n_features,) vector — so per-sample slices are stacked instead of
    trusting ``exp.values`` directly.

    Raises:
        RuntimeError: if the rebuilt matrix's width differs from
            ``n_features_expected``.
    """
    rows = []
    n_samples = len(exp.values) if hasattr(exp.values, "__len__") else len(exp)
    # Safer: iterate using the __getitem__ API
    for i in range(n_samples):
        v = np.asarray(exp[i].values).reshape(-1,)
        rows.append(v)
    M = np.vstack(rows)  # (n_samples, n_features)
    if M.shape[1] != n_features_expected:
        raise RuntimeError(
            f"Rebuilt SHAP matrix has shape {M.shape}; expected n_features={n_features_expected}."
        )
    return M
127
+
128
# --- 3. Compute SHAP explanations ---
shap_explanation, _ = get_shap_explanations(model, background_tensor, explain_tensor)
print("Calculation complete.")

# Attach unscaled display data for pretty plotting (plots show raw feature
# values while attributions were computed on the scaled inputs).
shap_explanation.display_data = explain_instances.values
shap_explanation.feature_names = feature_columns

# --- 4a. Global Feature Importance (Bar / Summary) ---
print("\nGenerating global feature importance plot (summary_plot.png)...")

# Robustly build a (n_samples, n_features) matrix by stacking per-sample vectors
shap_vals_matrix = stack_sample_shap_values(shap_explanation, n_features_expected=len(feature_columns))

mean_abs_shap = np.abs(shap_vals_matrix).mean(axis=0)  # (n_features,)

# Build a fresh Explanation with values aligned to feature_names
plot_explanation = shap.Explanation(values=mean_abs_shap, feature_names=feature_columns)

plt.figure()
shap.plots.bar(plot_explanation, show=False)
plt.xlabel("mean(|SHAP value|) (average impact on model output magnitude)")
plt.savefig("summary_plot.png", bbox_inches="tight")
plt.close()
print("Saved: summary_plot.png")

# --- 4b. Local Explanation (Force Plot) ---
print("\nGenerating local explanation for one card (force_plot.html)...")
instance_to_explain_index = 0
single_explanation = shap_explanation[instance_to_explain_index]

# Some SHAP versions drop display_data on slicing; pull directly if needed
if getattr(single_explanation, "display_data", None) is None:
    row_unscaled = explain_instances.values[instance_to_explain_index]
else:
    row_unscaled = single_explanation.display_data
features_row = np.atleast_2d(np.asarray(row_unscaled, dtype=float))

base_val = compute_base_value_safe(shap_explanation, instance_to_explain_index, model, background_tensor)
phi = np.asarray(single_explanation.values).reshape(-1,)  # (n_features,)

force_plot = shap.force_plot(
    base_val,
    phi,
    features=features_row,
    feature_names=feature_columns
)
shap.save_html("force_plot.html", force_plot)
print("Saved: force_plot.html (open in a browser)")

# --- 4c. Optional: local waterfall PNG (often clearer) ---
try:
    print("Generating local waterfall plot (waterfall_single.png)...")
    plt.figure()
    shap.plots.waterfall(single_explanation, show=False, max_display=20)
    plt.savefig("waterfall_single.png", bbox_inches="tight")
    plt.close()
    print("Saved: waterfall_single.png")
except Exception as e:
    # Waterfall support varies across SHAP versions; best-effort only.
    print(f"Waterfall plot skipped (reason: {e})")

# --- 5. Print metadata for the explained card ---
original_card_data = full_data.loc[explain_idx[instance_to_explain_index]]
name_val = original_card_data.get("name", "N/A")
tcgp_val = original_card_data.get("tcgplayer_id", "N/A")
label_val = original_card_data.get(TARGET_COLUMN, None)
# Treat missing AND NaN labels as "N/A": the previous one-liner used
# bool(label_val), and bool(NaN) is True, so a NaN label printed "RISE".
if label_val is None or pd.isna(label_val):
    label_str = "N/A"
else:
    label_str = "RISE" if bool(label_val) else "NOT RISE"

print("\n--- Card Explained in force_plot.html / waterfall_single.png ---")
print(f"Name: {name_val}")
print(f"TCGPlayer ID: {tcgp_val}")
print(f"Actual Outcome in Dataset: {label_str}")

# TODO: package the model in a format that can be published on Hugging Face and pulled down for reuse.
# TODO: include the SHAP charts (force_plot.html, summary_plot.png) in the model card, plus additional evaluation metrics.
requirements.txt CHANGED
@@ -3,4 +3,5 @@ pandas
3
  numpy
4
  scikit-learn
5
  safetensors
6
- gradio
 
 
3
  numpy
4
  scikit-learn
5
  safetensors
6
+ gradio
7
+ tabulate