# PlayflySports / app.py
# Author: Mpodszus
# Last commit: 5981fc9 (verified) - "Update app.py"
import math
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
# Load the trained XGBoost revenue model from disk. The `with` block closes
# the file handle as soon as unpickling finishes.
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only ever point this at a model file that ships with the app and is trusted.
with open("revenue_model.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# SHAP explainer for the loaded model (shap.Explainer picks the algorithm from
# the model type); main_func calls it on each input row to build the plot.
explainer = shap.Explainer(loaded_model)
def safe_convert(value, default, min_val, max_val):
    """Parse *value* as a float and clamp it to ``[min_val, max_val]``.

    Args:
        value: anything ``float()`` accepts (a number or numeric string).
        default: returned unchanged when *value* is not a usable number.
        min_val: lower bound of the clamp.
        max_val: upper bound of the clamp.

    Returns:
        ``float(value)`` clamped into range (the bound itself is returned when
        clamping kicks in), or *default* if *value* cannot be parsed or is NaN.
    """
    try:
        num = float(value)
    # OverflowError: float() of an int too large for a double (e.g. 10**400).
    except (TypeError, ValueError, OverflowError):
        return default
    # NaN compares false against both bounds, so without this check the
    # max/min clamp below would silently turn it into min_val.
    if math.isnan(num):
        return default
    return max(min_val, min(num, max_val))
# Label -> integer code lookups for each categorical model feature.
# They live at module level (not inside main_func) because the Gradio UI also
# reads their keys to populate its dropdowns.
# NOTE(review): the codes must match the encoding used when revenue_model.pkl
# was trained - confirm against the training pipeline before editing them.
inventory_mapping = {
    'Presenting Sponsor': 0,
    'Rotational Signage': 1,
    'Field-Level Signage': 2,
    'Concourse Signage': 3,
    'Fan Engagement Area': 4,
    'Scoreboard Signage': 5,
    'Marketing Space': 6,
    'Tabling': 7,
    'Game Sponsor': 8,
    'Game Entitlement': 9,
    'Rights & Licensing': 10,
    'Autographed Item': 11,
    'Ticket': 12,
    'Schedule': 13,
    'Program and Guide (Print)': 14,
    'Activation': 15,
    'Entitlement/Sponsorship': 16,
    'Spot': 17,
    'Live Read': 18,
    'Live Mention': 19,
    'Post Season Revenue': 20,
    'Feature': 21,
    'Season Ticket': 22,
    'Field or Event Access': 23,
    'LED Ribbon Board Signage': 24,
    'Parking': 25,
    'Club Ticket': 26,
    'Courtside Ticket': 27,
    'Roster Card': 28,
    'Videoboard Promotion': 29,
    'In-Game Promotion': 30,
    'Miscellaneous Signage': 31,
    'Sideline Branding': 32,
    'Sideline Signage': 33,
    'Sideline Assets': 34,
    'Tailgating': 35,
    'Suite': 36,
    'Giveaway': 37,
    'PA Announcement': 38,
    'Venue': 39,
    'Concession Signage': 40,
    'Venue Entitlement': 41,
    'Photo & Meet and Greet': 42,
    'Billboard': 43,
    'Group Tickets': 44,
    'Post Season Ticket': 45,
    'Display': 46,
    'Yearbook': 47,
    'Program and Guide (Digital)': 48,
    'On-Field Promotion': 49,
    'Tour': 50,
    'Poster': 51,
    'Stair Sign': 52,
    'Pre-Game Specialty Sponsor': 53,
    'Box Ticket': 54,
    'Vomitory Signage': 55,
    'Trip': 56,
    'Virtual Signage': 57,
}

category_mapping = {
    'Entitlement & Sponsorship': 0,
    'Signage': 1,
    'Activation': 2,
    'Intellectual Property': 3,
    'Miscellaneous': 4,
    'Print': 5,
    'Promotion': 6,
    'Radio': 7,
    'Tickets & Hospitality': 8,
    'Television': 9,
    'Experiential': 10,
}

# Double quotes so the apostrophes in "Men's"/"Women's" don't end the string.
product_mapping = {
    "Arena Sports": 0,
    "Baseball": 1,
    "Basketball - Men's": 2,
    "Basketball - Women's": 3,
    "Football": 4,
    "Gymnastics": 5,
    "Olympic Sports": 6,
    "Soccer": 7,
    "Softball": 8,
    "Volleyball": 9,
    "Lacrosse - Men's": 10,
    "Lacrosse - Women's": 11,
    "Soccer - Men's": 12,
    "Hockey": 13,
    "Basketball": 14,
    "Soccer - Women's": 15,
    "Wrestling": 16,
    "Lacrosse": 17,
}

geography_mapping = {
    'South': 0,
    'East Coast': 1,
    'Midwest': 2,
    'West Coast': 3,
    'Mountain West': 4,
}

conference_mapping = {
    'Power Four': 0,
    'Group of Five': 1,
    'Non-Power Four / Group of Five': 2,
    'Conference': 3,
}


def main_func(InventoryBucket, Category, Product, Geography, Conference):
    """Predict revenue for one combination of selections and explain it.

    Each argument is a label chosen in the UI. It is translated to its integer
    code through the matching ``*_mapping`` dict, assembled into a one-row
    float DataFrame and passed to ``loaded_model`` and ``explainer``.

    Args:
        InventoryBucket: a key of ``inventory_mapping``.
        Category: a key of ``category_mapping``.
        Product: a key of ``product_mapping``.
        Geography: a key of ``geography_mapping``.
        Conference: a key of ``conference_mapping``.

    Returns:
        ``(text, figure)``: the formatted predicted revenue and a matplotlib
        Figure holding the SHAP waterfall plot for this row.

    Raises:
        gr.Error: if any selection is empty (``None``) or not a known label,
            so the UI shows a readable message instead of a KeyError.
    """
    # Column name -> (selected label, lookup table). Insertion order fixes the
    # DataFrame column order, which matches the original implementation.
    selections = {
        "InventoryBucket": (InventoryBucket, inventory_mapping),
        "Category": (Category, category_mapping),
        "Product": (Product, product_mapping),
        "Geography": (Geography, geography_mapping),
        "Conference": (Conference, conference_mapping),
    }
    encoded = {}
    for column, (label, mapping) in selections.items():
        if label not in mapping:
            raise gr.Error(f"Please select a valid value for {column}.")
        encoded[column] = [mapping[label]]
    new_row = pd.DataFrame(encoded).astype(float)

    # float() turns the numpy scalar from predict() into a plain Python float
    # before formatting.
    prediction = float(loaded_model.predict(new_row)[0])
    shap_values = explainer(new_row)

    # Give shap a fresh current figure to draw into. show=False stops it from
    # calling plt.show(); the figure is then taken from pyplot's current
    # figure, as the original implementation also did.
    # NOTE(review): pyplot state is global; if Gradio serves concurrent
    # requests, confirm figures from different calls cannot interleave.
    plt.figure(figsize=(8, 4))
    # shap_values[0] is the single-row Explanation and carries the DataFrame
    # column names, so the plot labels features by name.
    shap.plots.waterfall(shap_values[0], show=False)
    fig = plt.gcf()
    fig.tight_layout()
    # Drop pyplot's reference so figures don't pile up across calls; the
    # Figure object itself can still be rendered by gr.Plot.
    plt.close(fig)
    return f"Predicted Revenue: ${prediction:,.2f}", fig
# Colors for the Analyze button, targeted through its elem_id. Gradio's
# `js=` event hook expects a JavaScript function expression, and the original
# snippet was bare statements; plain CSS needs no client-side script at all.
ANALYZE_BUTTON_CSS = "#analyze_btn {background-color: #4169E1; color: white;}"

with gr.Blocks(title="Playfly Revenue Predictor", css=ANALYZE_BUTTON_CSS) as demo:
    with gr.Row():
        gr.Markdown("## Playfly Revenue Predictor & Interpreter")
        # gr.Image in Gradio 4 has no `shape=` argument; height/width size
        # the displayed logo instead.
        logo = gr.Image(
            value="play-fly-logo.png",
            label="",
            show_label=False,
            height=75,
            width=75,
        )
    gr.Markdown("This app predicts **revenue** based on selected inventory, category, sport, geography, and conference.")
    gr.Markdown("---")

    # Outputs of main_func: the formatted prediction and the SHAP figure.
    label = gr.Label(label="Revenue Prediction")
    local_plot = gr.Plot(label="SHAP Waterfall Plot")

    # One dropdown per model feature, populated from the module-level
    # label -> code mappings that main_func also uses.
    with gr.Row():
        InventoryBucket = gr.Dropdown(list(inventory_mapping.keys()), label="Inventory Bucket")
        Category = gr.Dropdown(list(category_mapping.keys()), label="Category")
        Product = gr.Dropdown(list(product_mapping.keys()), label="Product")
        Geography = gr.Dropdown(list(geography_mapping.keys()), label="Geography")
        Conference = gr.Dropdown(list(conference_mapping.keys()), label="Conference")

    inputs = [InventoryBucket, Category, Product, Geography, Conference]
    outputs = [label, local_plot]

    analyze_btn = gr.Button("Analyze", elem_id="analyze_btn")
    analyze_btn.click(main_func, inputs, outputs)

    gr.Markdown("---")
    # cache_examples=True runs main_func on these rows when the app is built.
    gr.Examples(
        [
            ["Field-Level Signage", "Signage", "Football", "East Coast", "Power Four"],
            ["Program and Guide (Digital)", "Print", "Basketball - Women's", "Midwest", "Group of Five"],
        ],
        inputs,
        outputs,
        main_func,
        cache_examples=True,
    )

# Only start the server when run as a script, so importing this module
# (e.g. from tests) does not launch the app.
if __name__ == "__main__":
    demo.launch()