# nba_plusminus / app.py
# Author: aggtamv — commit 30fa9c2 ("show error change")
import torch
import numpy as np
import pandas as pd
import gradio as gr
import joblib
from model import LSTMModel
# Constants
# Number of features per time step. It is passed to the model constructor and
# also used by the prediction handlers' shape checks.
input_size = 20

# Load model and scalers once, at import time. Paths are relative to the
# current working directory.
model = LSTMModel(input_size=input_size)
# map_location="cpu" lets weights saved on a GPU load on a CPU-only host.
_weights = torch.load("model.pth", map_location="cpu")
model.load_state_dict(_weights)
del _weights  # values were copied into `model`; drop the extra reference
model.eval()  # evaluation mode for inference

# NOTE(review): joblib.load unpickles arbitrary objects, so only load trusted
# artifacts. These are presumably scikit-learn scalers fitted at training
# time; verify against the training code.
input_scaler, target_scaler = (
    joblib.load(name) for name in ("input_scaler.pkl", "target_scaler.pkl")
)
# Prediction logic
def predict_from_array(array_2d):
    """Run the LSTM on one 10-step feature window and format the result.

    Args:
        array_2d: Array-like of shape ``(10, input_size)``: one row per time
            step, one column per feature. NumPy arrays are used as-is; lists
            and DataFrames are converted with ``np.asarray``.

    Returns:
        A user-facing status string. On success this is
        ``"✅ Predicted Plus/Minus: <value>"``. On failure it is a ``"❌ ..."``
        message. Exceptions raised while predicting are caught and reported
        in the returned string so the Gradio UI can display them.
    """
    try:
        # np.asarray returns ndarrays unchanged, so existing callers see the
        # same dtype and values. Other array-likes no longer fail on `.shape`.
        window = np.asarray(array_2d)
        if window.shape != (10, input_size):
            return f"❌ Expected shape (10, {input_size}), got {window.shape}"
        # Without this check, NaN/inf would pass through the scaler and model
        # and be reported as a "successful" prediction of nan.
        if not np.isfinite(window).all():
            return "❌ Input contains NaN or infinite values"
        scaled = input_scaler.transform(window)
        # Add a leading batch dimension: (10, F) -> (1, 10, F).
        # NOTE(review): presumes LSTMModel takes batch-first input; verify in model.py.
        tensor = torch.tensor(scaled, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            # .item() requires the model output to hold exactly one element.
            pred_scaled = model(tensor).item()
        # Undo the target scaling to express the prediction in original units.
        pred_real = target_scaler.inverse_transform([[pred_scaled]])[0][0]
        return f"✅ Predicted Plus/Minus: {pred_real:.2f}"
    except Exception as e:
        return f"❌ Error during prediction: {e}"
# CSV Upload Handler
def predict_from_csv(file):
    """Predict from an uploaded CSV that holds one (10, input_size) window.

    The CSV is read with pandas' default header handling. Its first line is
    therefore taken as column names, and the following 10 lines are the data
    rows.

    Args:
        file: The value of the ``gr.File`` component. Depending on the Gradio
            version and config, this is a path string or a file-like object
            whose ``.name`` is the path. ``None`` means nothing was uploaded.

    Returns:
        A user-facing status string: a ``"❌ ..."`` message on failure,
        otherwise the result of ``predict_from_array``.
    """
    if file is None:
        return "❌ Please upload a CSV file."
    # Accept either a file-like wrapper (use its .name) or a plain path
    # string, which has no .name attribute.
    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
        if df.shape != (10, input_size):
            return f"❌ CSV must have shape (10, {input_size}). Got {df.shape}"
        return predict_from_array(df.to_numpy(dtype=np.float32))
    except Exception as e:
        return f"❌ Error reading CSV: {e}"
# Generate random input
def generate_random_input():
    """Return a random (10, input_size) feature table for the demo tab.

    Values are drawn uniformly from [0, 100). Columns are named F1..F<n> to
    match the headers declared on the ``gr.Dataframe``. A DataFrame with
    default integer column labels would replace those headers with 0..n-1
    once it is shown in the UI.
    """
    columns = [f"F{i + 1}" for i in range(input_size)]
    values = np.random.uniform(0, 100, size=(10, input_size))
    return pd.DataFrame(values, columns=columns)
def predict_from_table(df):
    """Predict from the editable table in the "Generate Random Input" tab.

    Args:
        df: The ``gr.Dataframe`` value, expected to be a pandas DataFrame of
            shape ``(10, input_size)``. ``None`` is handled defensively.

    Returns:
        A user-facing status string: a ``"❌ ..."`` message on failure,
        otherwise the result of ``predict_from_array``.
    """
    if df is None:
        return "❌ Table is empty"
    try:
        if df.shape != (10, input_size):
            # Report the observed shape, as the CSV handler does.
            return f"❌ Table must have shape (10, {input_size}). Got {df.shape}"
        # If a cell cannot be converted to float32, the cast raises. That
        # error is reported by the handler below.
        return predict_from_array(df.to_numpy(dtype=np.float32))
    except Exception as e:
        return f"❌ Error: {e}"
# Build UI
# Two tabs feed the same model: an editable random table and a CSV upload.
# Each tab has its own Predict button and output box.
with gr.Blocks() as app:
    gr.Markdown("## 🏀 NBA Plus/Minus Predictor")

    with gr.Tab("🎲 Generate Random Input"):
        with gr.Row():
            generate_btn = gr.Button(f"Generate Random 10x{input_size} Sample")
            table_predict_btn = gr.Button("Predict")
        input_table = gr.Dataframe(
            headers=[f"F{i+1}" for i in range(input_size)],
            row_count=10,
            col_count=input_size,
            interactive=True,
            label="Generated Features",
        )
        output_text = gr.Textbox(label="Prediction")
        generate_btn.click(fn=generate_random_input, outputs=input_table)
        table_predict_btn.click(
            fn=predict_from_table, inputs=input_table, outputs=output_text
        )

    with gr.Tab("📂 Upload CSV"):
        with gr.Row():
            # Restrict the picker to CSV files. The handler still validates
            # the contents.
            file_input = gr.File(
                label=f"Upload CSV with shape (10×{input_size})",
                file_types=[".csv"],
            )
            csv_predict_btn = gr.Button("Predict")
        file_output = gr.Textbox(label="Prediction")
        csv_predict_btn.click(
            fn=predict_from_csv, inputs=file_input, outputs=file_output
        )

# Launch only when run as a script (e.g. `python app.py`), not when the module
# is imported.
if __name__ == "__main__":
    app.launch(share=True, show_api=True, show_error=True)