"""Gradio UI for an LSTM-based NBA Plus/Minus predictor.

Inputs are a single sample of SEQ_LEN rows x input_size feature columns,
either generated randomly, edited in a table, or uploaded as a CSV file.
"""
import os

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import torch

from model import LSTMModel

# Constants
SEQ_LEN = 10  # rows (time steps) per sample expected by the model
input_size = 20  # feature columns per row
FEATURE_NAMES = [f"F{i + 1}" for i in range(input_size)]

# Load model and scalers once at startup.
# NOTE: torch.load and joblib.load unpickle their inputs; only load trusted files.
model = LSTMModel(input_size=input_size)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))
model.eval()

input_scaler = joblib.load("input_scaler.pkl")
target_scaler = joblib.load("target_scaler.pkl")


# Prediction logic
def predict_from_array(array_2d):
    """Run the model on one sample and return a user-facing status string.

    Args:
        array_2d: Array-like of raw (unscaled) feature values with shape
            ``(SEQ_LEN, input_size)``.

    Returns:
        ``"✅ Predicted Plus/Minus: <value>"`` on success, otherwise a
        ``"❌ ..."`` message describing the problem. Does not raise.
    """
    try:
        values = np.asarray(array_2d, dtype=np.float32)
    except (TypeError, ValueError) as e:
        return f"❌ Error during prediction: {e}"
    if values.shape != (SEQ_LEN, input_size):
        return f"❌ Expected shape ({SEQ_LEN}, {input_size}), got {values.shape}"
    # Missing values (e.g. blank CSV fields, which pandas reads as NaN) would
    # otherwise flow through the model and be reported as a "nan" prediction.
    if not np.isfinite(values).all():
        return "❌ Input contains missing or non-finite values"
    try:
        scaled = input_scaler.transform(values)
        # Add a leading batch dimension of size 1.
        tensor = torch.tensor(scaled, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            pred_scaled = model(tensor).item()
        # inverse_transform is given a 2-D (1, 1) array; take its single value.
        pred_real = target_scaler.inverse_transform([[pred_scaled]])[0][0]
    except Exception as e:  # UI boundary: report the failure instead of crashing
        return f"❌ Error during prediction: {e}"
    return f"✅ Predicted Plus/Minus: {pred_real:.2f}"


# CSV Upload Handler
def predict_from_csv(file):
    """Predict from an uploaded CSV file.

    ``pd.read_csv`` is called with its defaults, so the first row of the file
    is treated as a header; the file therefore needs a header row followed by
    SEQ_LEN data rows of input_size columns.

    Args:
        file: Value produced by ``gr.File`` -- depending on the Gradio version
            either a path string or an object exposing the path as ``.name``;
            ``None`` when nothing has been uploaded.

    Returns:
        A user-facing status string (see ``predict_from_array``).
    """
    if file is None:
        return "❌ Please upload a CSV file first."
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    try:
        df = pd.read_csv(path)
    except Exception as e:  # UI boundary: unreadable / malformed upload
        return f"❌ Error reading CSV: {e}"
    if df.shape != (SEQ_LEN, input_size):
        return f"❌ CSV must have shape ({SEQ_LEN}, {input_size}). Got {df.shape}"
    try:
        values = df.to_numpy(dtype=np.float32)
    except (TypeError, ValueError) as e:  # non-numeric cells
        return f"❌ Error reading CSV: {e}"
    return predict_from_array(values)


# Generate random input
def generate_random_input():
    """Return a random sample as a DataFrame with values uniform in [0, 100).

    Columns are labelled with FEATURE_NAMES so they match the table headers.
    """
    return pd.DataFrame(
        np.random.uniform(0, 100, size=(SEQ_LEN, input_size)),
        columns=FEATURE_NAMES,
    )


def predict_from_table(df):
    """Predict from the editable Gradio table (received as a DataFrame).

    Returns:
        A user-facing status string (see ``predict_from_array``).
    """
    if df is None or df.shape != (SEQ_LEN, input_size):
        return f"❌ Table must have shape ({SEQ_LEN}, {input_size})"
    try:
        values = df.to_numpy(dtype=np.float32)
    except (TypeError, ValueError) as e:  # edited cells that are not numbers
        return f"❌ Error: {e}"
    return predict_from_array(values)


# Build UI
with gr.Blocks() as app:
    gr.Markdown("## 🏀 NBA Plus/Minus Predictor")

    with gr.Tab("🎲 Generate Random Input"):
        with gr.Row():
            generate_btn = gr.Button("Generate Random 10x20 Sample")
            predict_table_btn = gr.Button("Predict")
        input_table = gr.Dataframe(
            headers=FEATURE_NAMES,
            row_count=SEQ_LEN,
            col_count=input_size,
            interactive=True,
            label="Generated Features",
        )
        output_text = gr.Textbox(label="Prediction")
        generate_btn.click(fn=generate_random_input, outputs=input_table)
        predict_table_btn.click(
            fn=predict_from_table, inputs=input_table, outputs=output_text
        )

    with gr.Tab("📂 Upload CSV"):
        with gr.Row():
            file_input = gr.File(label="Upload CSV with shape (10×20)")
            predict_csv_btn = gr.Button("Predict")
        file_output = gr.Textbox(label="Prediction")
        predict_csv_btn.click(
            fn=predict_from_csv, inputs=file_input, outputs=file_output
        )

# Only start the (blocking, publicly shared) server when run as a script,
# not when this module is imported.
if __name__ == "__main__":
    app.launch(share=True, show_api=True, show_error=True)