File size: 5,021 Bytes
59cce6a
 
 
 
 
 
 
 
 
 
 
 
 
 
aff826c
59cce6a
 
5256800
59cce6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5256800
59cce6a
 
5256800
 
 
59cce6a
5256800
 
59cce6a
 
5256800
59cce6a
 
 
 
 
 
 
 
5256800
aff826c
 
59cce6a
5256800
 
 
 
 
187e2a5
59cce6a
 
 
 
 
aff826c
59cce6a
 
 
 
 
5256800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59cce6a
 
5256800
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import gradio as gr
import torch
import joblib
import pandas as pd
import os
import json
from safetensors.torch import load_file
from typing import List, Tuple
from network import PricePredictor

# Directories (relative to the working directory) holding the app's assets.
MODEL_DIR = "model"
DATA_DIR = "data"

# Fitted feature scaler and the labelled card table, both stored under DATA_DIR.
SCALER_PATH = os.path.join(DATA_DIR, "scaler.pkl")
DATA_PATH = os.path.join(DATA_DIR, "pokemon_final_with_labels.csv")

# Dataset column whose value is reported as the card's actual outcome.
TARGET_COLUMN = "price_will_rise_30_in_6m"



def load_model_and_config(model_dir: str) -> Tuple[torch.nn.Module, List[str]]:
    """Load the trained ``PricePredictor`` and its feature list from *model_dir*.

    Reads ``config.json`` (which must provide ``input_size`` and
    ``feature_columns``) and the weights in ``model.safetensors``.

    Args:
        model_dir: Directory containing ``config.json`` and ``model.safetensors``.

    Returns:
        ``(model, feature_columns)``: the network switched to eval mode, and the
        ``feature_columns`` list from the config. Callers use that list to
        select the model's input columns from the dataset.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist (a missing weights
            file is reported by ``safetensors``).
        KeyError: If the config lacks ``input_size`` or ``feature_columns``.
    """
    config_path = os.path.join(model_dir, "config.json")
    # JSON is UTF-8 by specification; don't depend on the host's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        model_config = json.load(f)

    model = PricePredictor(input_size=model_config["input_size"])
    weights_path = os.path.join(model_dir, "model.safetensors")
    model.load_state_dict(load_file(weights_path))
    # Inference mode: affects training-only layers (dropout, batch-norm) if the
    # network has any.
    model.eval()
    return model, model_config["feature_columns"]


def perform_prediction(
    model: torch.nn.Module,
    scaler,
    input_features: pd.Series,
    threshold: float = 0.5,
) -> Tuple[bool, float]:
    """Score a single card with the binary classifier.

    Args:
        model: Network that maps a ``(1, n_features)`` float tensor to a single
            raw logit.
        scaler: Fitted transformer exposing ``transform``. It is applied to the
            features before the model (presumably the same scaler used in
            training; confirm against the training pipeline).
        input_features: Numeric feature values for one card, in the order the
            model expects.
        threshold: Probability at or above which the card is classified as
            "will rise". Defaults to 0.5.

    Returns:
        ``(predicted_class, probability)``. ``probability`` is the sigmoid of
        the model's logit; ``predicted_class`` is ``probability >= threshold``.

    Raises:
        ValueError: If a feature cannot be converted to float, or if any scaled
            feature is NaN or infinite (e.g. missing values in the dataset row).
    """
    features_np = input_features.to_numpy(dtype="float32").reshape(1, -1)
    features_scaled = scaler.transform(features_np)
    features_tensor = torch.tensor(features_scaled, dtype=torch.float32)

    # A NaN input would yield a NaN probability, which is meaningless to report.
    if not bool(torch.isfinite(features_tensor).all()):
        raise ValueError("Input features contain missing or non-finite values.")

    with torch.no_grad():
        logit = model(features_tensor)
        probability = torch.sigmoid(logit).item()

    # Explicit threshold rather than round(): round() uses banker's rounding,
    # which would classify a probability of exactly 0.5 as False.
    predicted_class = probability >= threshold
    return predicted_class, probability

# Load the model, scaler and dataset once at import time. On failure the app
# still starts, and every prediction returns an "Application Error" message.
model = None
feature_columns: List[str] = []
scaler = None
full_data = None
try:
    model, feature_columns = load_model_and_config(MODEL_DIR)
    # NOTE: joblib.load unpickles the file; only ship a scaler.pkl you produced yourself.
    scaler = joblib.load(SCALER_PATH)
    full_data = pd.read_csv(DATA_PATH)
    ASSETS_LOADED = True
except FileNotFoundError as e:
    print(f"Error loading necessary files: {e}")
    print("Please make sure you have uploaded the 'model' and 'data' directories to your Hugging Face Space.")
    ASSETS_LOADED = False
except (OSError, KeyError, ValueError, RuntimeError) as e:
    # The files exist but are unreadable or inconsistent: invalid JSON, missing
    # config keys, weights that don't match the architecture, malformed CSV.
    print(f"Error loading necessary files: {type(e).__name__}: {e}")
    print("Please check that the files in the 'model' and 'data' directories are complete and match the model code.")
    ASSETS_LOADED = False


def predict_price_trend(card_identifier: str) -> str:
    """Return a Markdown report predicting the price trend of one card.

    The card is looked up by TCGPlayer ID in the module-level ``full_data``
    table. Its ``feature_columns`` are scored with ``perform_prediction``, and
    the dataset's recorded outcome (``TARGET_COLUMN``) is appended when present.

    Args:
        card_identifier: Raw text from the input box. It should be a numeric
            TCGPlayer product ID; surrounding whitespace is ignored.

    Returns:
        Markdown text: either the prediction report, or a user-facing error
        section when:
            - the assets failed to load,
            - the input is not numeric,
            - the ID is not in the dataset, or
            - the card's stored features cannot be scored.
    """
    if not ASSETS_LOADED:
        return "## Application Error\nAssets could not be loaded. Please check the logs on Hugging Face Spaces for details. You may need to upload your `model` and `data` directories."

    raw_id = (card_identifier or "").strip()
    # isdecimal(), unlike isdigit(), accepts only characters int() can parse
    # ("²".isdigit() is True, but int("²") raises ValueError).
    if not raw_id.isdecimal():
        return "## Input Error\nPlease enter a valid, numeric TCGPlayer ID."

    card_id = int(raw_id)
    card_data = full_data[full_data['tcgplayer_id'] == card_id]
    if card_data.empty:
        return f"## Card Not Found\nCould not find a card with TCGPlayer ID '{card_id}'. Please check the ID and try again."

    # If the ID occurs more than once, the first matching row is used.
    card_sample = card_data.iloc[0]
    sample_features = card_sample[feature_columns]

    try:
        predicted_class, probability = perform_prediction(model, scaler, sample_features)
    except ValueError as err:
        # E.g. the stored row has missing or non-numeric feature values.
        return f"## Prediction Error\nThe stored data for TCGPlayer ID '{card_id}' could not be scored: {err}"

    prediction_text = "**RISE**" if predicted_class else "**NOT RISE**"
    confidence = probability if predicted_class else 1 - probability
    # Use the parsed integer for the link: the stored column may be float-typed
    # (e.g. if it contains NaNs), which would render as "84198.0" in the URL.
    tcgplayer_link = f"https://www.tcgplayer.com/product/{card_id}?Language=English"

    # Markdown treats lines indented by 4+ spaces as a code block, so the
    # report is assembled from unindented lines.
    report_lines = [
        f"## 🔮 Prediction Report for {card_sample['name']}",
        "- **Prediction:** The model predicts the card's price will "
        f"{prediction_text} by 30% in the next 6 months.",
        f"- **Confidence:** {confidence:.2%}",
        f"- **View on TCGPlayer:** [Check Current Price]({tcgplayer_link})",
    ]

    try:
        if TARGET_COLUMN in card_sample and pd.notna(card_sample[TARGET_COLUMN]):
            true_label = bool(card_sample[TARGET_COLUMN])
            report_lines.append(
                f"- **Actual Result in Dataset:** The price did **{'RISE' if true_label else 'NOT RISE'}**."
            )
    except (KeyError, TypeError):
        # Best-effort: an unusable label simply omits the "Actual Result" line.
        pass

    return "\n".join(report_lines)


def _examples_markdown(df: pd.DataFrame, n: int = 5) -> str:
    """Return a Markdown table of up to *n* sample cards (name and TCGPlayer ID).

    The table is built by hand so it does not depend on the optional
    ``tabulate`` package that ``DataFrame.to_markdown`` requires. Sampling uses
    a fixed seed, and is capped at the table size so small datasets don't raise.
    """
    sample = df.sample(min(n, len(df)), random_state=42)
    rows = ["| name | tcgplayer_id |", "|:---|---:|"]
    for name, tcg_id in zip(sample["name"], sample["tcgplayer_id"]):
        # Escape pipes so a card name cannot break the table layout.
        safe_name = str(name).replace("|", "\\|")
        # Float-typed IDs (e.g. a column containing NaNs) are shown without
        # ".0" so users can paste them straight into the input box.
        id_text = f"{tcg_id:.0f}" if isinstance(tcg_id, float) else str(tcg_id)
        rows.append(f"| {safe_name} | {id_text} |")
    return "\n".join(rows)


with gr.Blocks(theme=gr.themes.Soft(), title="PricePoke Predictor") as demo:
    gr.Markdown(
        """
        # 📈 PricePoke: Pokémon Card Price Trend Predictor
        Enter a Pokémon card's TCGPlayer ID to predict whether its market price will increase by 30% or more over the next 6 months.
        This model was trained on historical TCGPlayer market data.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            card_input = gr.Textbox(
                label="TCGPlayer ID",
                placeholder="e.g., '84198'",
                info="Find the ID in the card's URL on TCGPlayer's website (e.g., tcgplayer.com/product/84198/... has ID 84198)."
            )
            predict_button = gr.Button("Predict Trend", variant="primary")

            gr.Markdown("---")
            gr.Markdown("### Example Cards")
            if ASSETS_LOADED:
                gr.Markdown(_examples_markdown(full_data))
            else:
                gr.Markdown("Could not load examples.")

        with gr.Column(scale=2):
            output_markdown = gr.Markdown()

    # Both the button and pressing Enter in the textbox trigger a prediction.
    predict_button.click(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])
    card_input.submit(fn=predict_price_trend, inputs=[card_input], outputs=[output_markdown])

if __name__ == "__main__":
    demo.launch()