OffWorldTensor commited on
Commit
b475c91
·
1 Parent(s): f3b3602

Update model artifacts and explanations

Browse files
Files changed (7) hide show
  1. README.md +131 -3
  2. config.json +70 -0
  3. force_plot.html +0 -0
  4. model.safetensors +3 -0
  5. network.py +23 -0
  6. scaler.pkl +3 -0
  7. summary_plot.png +0 -0
README.md CHANGED
@@ -1,3 +1,131 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ language: en
4
+ library_name: pytorch
5
+ tags:
6
+ - pytorch
7
+ - tabular-classification
8
+ - pokemon
9
+ - finance
10
+ - scikit-learn
11
+ - shap
12
+ ---
13
+
14
+ # Pokémon TCG Price Predictor
15
+
16
+ This repository contains a PyTorch model trained to predict whether a Pokémon TCG card's price will rise by at least 30% within the next 6 months.
17
+
18
+ This model is the backend for the **[PokePrice Gradio Demo](https://huggingface.co/spaces/OffWorldTensor/PokePrice)**.
19
+
20
+ ## Model Description
21
+
22
+ The model is a simple Multi-Layer Perceptron (MLP) implemented in PyTorch. It takes various features of a Pokémon card as input—such as its rarity, type, and historical price data—and outputs a single logit. A sigmoid function can be applied to this logit to get a probability score for the price rising.
23
+
24
+ - **Model type:** Tabular Binary Classification
25
+ - **Architecture:** `PricePredictor` (MLP)
26
+ - **Framework:** PyTorch
27
+ - **Training Data:** A custom dataset derived from the PokemonTCG/pokemon-tcg-data repository, augmented with pricing history.
28
+
29
+ ## How to Use
30
+
31
+ To use this model, you will need `torch`, `scikit-learn`, `pandas`, and `huggingface_hub`. You can download the model artifacts directly from the Hub.
32
+
33
+ First, ensure you have `network.py` (which defines the model class) in your working directory.
34
+
35
+ ```python
36
+ import torch
37
+ import joblib
38
+ import json
39
+ import pandas as pd
40
+ from huggingface_hub import hf_hub_download
41
+ from safetensors.torch import load_file
42
+
43
+ # Make sure you have network.py in the same directory
44
+ from network import PricePredictor
45
+
46
+ REPO_ID = "your-username/pokemon-price-predictor"
47
+ MODEL_FILENAME = "model.safetensors"
48
+ CONFIG_FILENAME = "config.json"
49
+ SCALER_FILENAME = "scaler.pkl"
50
+
51
+ print("Downloading model files from the Hub...")
52
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
53
+ config_path = hf_hub_download(repo_id=REPO_ID, filename=CONFIG_FILENAME)
54
+ scaler_path = hf_hub_download(repo_id=REPO_ID, filename=SCALER_FILENAME)
55
+ print("Downloads complete.")
56
+
57
+ with open(config_path, "r") as f:
58
+ config = json.load(f)
59
+
60
+ feature_columns = config["feature_columns"]
61
+ input_size = config["input_size"]
62
+
63
+ model = PricePredictor(input_size=input_size)
64
+ model.load_state_dict(load_file(model_path))
65
+ model.eval()
66
+
67
+ scaler = joblib.load(scaler_path)
68
+
69
+ data_to_predict = {
70
+ 'rawPrice': [10.0], 'gradedPriceTen': [100.0], 'gradedPriceNine': [50.0],
71
+ }
72
+
73
+ input_df = pd.DataFrame(data_to_predict)
74
+ missing_cols = set(feature_columns) - set(input_df.columns)
75
+ for c in missing_cols:
76
+ input_df[c] = 0.0
77
+ input_df = input_df[feature_columns]
78
+
79
+
80
+ input_scaled = scaler.transform(input_df.values)
81
+ input_tensor = torch.tensor(input_scaled, dtype=torch.float32)
82
+
83
+ with torch.no_grad():
84
+ logits = model(input_tensor)
85
+ probability = torch.sigmoid(logits).item()
86
+
87
+ print(f"\nPrediction for the input card:")
88
+ print(f" - Probability of 30% price rise in 6 months: {probability:.4f}")
89
+
90
+ if probability > 0.5:
91
+ print(" - Prediction: Price WILL LIKELY rise.")
92
+ else:
93
+ print(" - Prediction: Price WILL LIKELY NOT rise.")
94
+ ```
95
+
96
+ ## Model Performance
97
+
98
+ The model was evaluated on a 20% held-out test set.
99
+
100
+ - **Accuracy:** 0.9515
101
+ - **Precision:** 0.9323
102
+ - **Recall:** 0.8986
103
+ - **F1-Score:** 0.9151
104
+
105
+ ## Model Explainability
106
+
107
+ To understand the model's decisions, SHAP (SHapley Additive exPlanations) values were computed.
108
+
109
+ ### Global Feature Importance
110
+
111
+ This plot shows the average impact of each feature on the model's output magnitude. Features at the top are most influential.
112
+
113
+ ![Global Feature Importance](explanation_outputs/summary_plot.png)
114
+
115
+ ### Local Explanation for a Single Card
116
+
117
+ A static waterfall plot provides a clear view of features pushing the prediction for a single card.
118
+
119
+ ![Local Waterfall Plot](explanation_outputs/force_plot.html)
120
+
121
+ An interactive force plot is also available. You can view it by downloading `force_plot.html` from this repository and opening it in your browser.
122
+
123
+ ## Limitations and Bias
124
+
125
+ - The model is trained on historical data and may not predict future trends accurately, especially in a volatile market.
126
+ - The definition of "price rise" is fixed at 30% over 6 months. The model is not trained for other thresholds or timeframes.
127
+ - The dataset may have inherent biases related to card popularity, set releases, or data collection artifacts.
128
+
129
+ ## Author
130
+
131
+ Callum Anderson
config.json ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "input_size": 64,
3
+ "model_class": "PricePredictor",
4
+ "feature_columns": [
5
+ "rawPrice",
6
+ "gradedPriceTen",
7
+ "gradedPriceNine",
8
+ "first_raw",
9
+ "price_ratio_to_first",
10
+ "log_raw",
11
+ "log_g10",
12
+ "log_g9",
13
+ "price_vs_rolling_avg",
14
+ "rawPrice_missing",
15
+ "gradedPriceTen_missing",
16
+ "gradedPriceNine_missing",
17
+ "rarity_ACE SPEC Rare",
18
+ "rarity_Amazing Rare",
19
+ "rarity_Black White Rare",
20
+ "rarity_Classic Collection",
21
+ "rarity_Code Card",
22
+ "rarity_Common",
23
+ "rarity_Double Rare",
24
+ "rarity_Holo Rare",
25
+ "rarity_Hyper Rare",
26
+ "rarity_Illustration Rare",
27
+ "rarity_Prism Rare",
28
+ "rarity_Promo",
29
+ "rarity_Radiant Rare",
30
+ "rarity_Rare",
31
+ "rarity_Rare Ace",
32
+ "rarity_Rare BREAK",
33
+ "rarity_Secret Rare",
34
+ "rarity_Shiny Holo Rare",
35
+ "rarity_Shiny Rare",
36
+ "rarity_Shiny Ultra Rare",
37
+ "rarity_Special Illustration Rare",
38
+ "rarity_Ultra Rare",
39
+ "rarity_Uncommon",
40
+ "energyType_Colorless",
41
+ "energyType_Darkness",
42
+ "energyType_Dragon",
43
+ "energyType_Energy",
44
+ "energyType_Fairy",
45
+ "energyType_Fighting",
46
+ "energyType_Fire",
47
+ "energyType_Grass",
48
+ "energyType_Lightning",
49
+ "energyType_Metal",
50
+ "energyType_Psychic",
51
+ "energyType_Water",
52
+ "energyType_nan",
53
+ "cardType_Energy",
54
+ "cardType_Item",
55
+ "cardType_Pokemon",
56
+ "cardType_Stadium",
57
+ "cardType_Supporter",
58
+ "cardType_Tool",
59
+ "cardType_Trainer",
60
+ "cardType_nan",
61
+ "variant_1st Edition",
62
+ "variant_1st Edition Holofoil",
63
+ "variant_Holofoil",
64
+ "variant_Normal",
65
+ "variant_Reverse Holofoil",
66
+ "variant_Unlimited",
67
+ "variant_Unlimited Holofoil",
68
+ "variant_nan"
69
+ ]
70
+ }
force_plot.html ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38b217807a8bf227beba2a74448010f2234742071f415f52ea9429915d37cd54
3
+ size 199132
network.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ """
5
+ Neural Network Classifier Architecture
6
+ """
7
+
8
+ class PricePredictor(nn.Module):
9
+
10
+ def __init__(self, input_size: int):
11
+ super(PricePredictor, self).__init__()
12
+ self.model = nn.Sequential(
13
+ nn.Linear(input_size, 256),
14
+ nn.ReLU(),
15
+ nn.Dropout(0.4),
16
+ nn.Linear(256, 128),
17
+ nn.ReLU(),
18
+ nn.Dropout(0.4),
19
+ nn.Linear(128, 1),
20
+ )
21
+
22
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
23
+ return self.model(x)
scaler.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57bae1c7e9c16028c4f21def0302ba1514e7a3d8be131937702da75007ccd866
3
+ size 2151
summary_plot.png ADDED