Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import joblib | |
| from model import LSTMModel | |
# Constants
# Features per timestep. The model constructor, every shape check and the
# UI table headers are all derived from this single value.
input_size = 20

# Load model and scalers
model = LSTMModel(input_size=input_size)
model.load_state_dict(
    # map_location="cpu" remaps any tensors saved on another device onto the CPU.
    torch.load("model.pth", map_location="cpu")
)
# Inference mode (matters for layers such as dropout, if the model has any).
model.eval()

# Pre-fitted scalers: the first normalises model inputs, the second maps the
# model's scaled output back to the original target scale.
input_scaler, target_scaler = (
    joblib.load(scaler_path)
    for scaler_path in ("input_scaler.pkl", "target_scaler.pkl")
)
# Prediction logic
def predict_from_array(array_2d):
    """Predict Plus/Minus for one window of features and return a display string.

    Args:
        array_2d: Array-like of shape (10, input_size): 10 rows of
            ``input_size`` features each. Anything ``np.asarray`` accepts works.

    Returns:
        "✅ Predicted Plus/Minus: <value>" on success. Otherwise a message
        starting with "❌" that explains why no prediction was made. Failures
        are reported through the returned string rather than raised.
    """
    try:
        # np.asarray lets lists and other array-likes through, and makes the
        # ``.shape`` check safe even for inputs that are not already arrays.
        features = np.asarray(array_2d)
        if features.shape != (10, input_size):
            return f"❌ Expected shape (10, {input_size}), got {features.shape}"
        # NaN/inf would pass through the scaler and the model and come out as a
        # meaningless "nan" prediction, so reject them up front. The float64
        # view is only used for this check.
        if not np.isfinite(features.astype(np.float64)).all():
            return "❌ Input contains NaN or infinite values"
        scaled = input_scaler.transform(features)
        # Add a leading batch dimension -> (1, 10, input_size).
        # NOTE(review): assumes LSTMModel is batch-first — confirm in model.py.
        tensor = torch.tensor(scaled, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            pred_scaled = model(tensor).item()
        # The model predicts in scaled space; map the value back to real units.
        pred_real = target_scaler.inverse_transform([[pred_scaled]])[0][0]
        return f"✅ Predicted Plus/Minus: {pred_real:.2f}"
    except Exception as e:
        return f"❌ Error during prediction: {e}"
# CSV Upload Handler
def predict_from_csv(file):
    """Read an uploaded (10, input_size) CSV and return a prediction string.

    Args:
        file: The value handed over by ``gr.File``. This is either a
            filesystem path (``str`` / ``os.PathLike``, which Gradio 4's
            default ``type="filepath"`` produces) or a tempfile-like object
            that exposes the path as ``.name`` (Gradio 3.x). It is ``None``
            when nothing has been uploaded.

    The first CSV row is parsed as the header (the ``pandas.read_csv``
    default). The file therefore needs a header line followed by 10 data
    rows of ``input_size`` numeric columns.

    Returns:
        The string produced by ``predict_from_array`` on success, otherwise a
        message starting with "❌".
    """
    import os

    if file is None:
        return "❌ No CSV file uploaded."
    try:
        # A str/PathLike already *is* the path. ``Path.name`` would be only the
        # basename, so ``.name`` is used only for non-path upload wrappers.
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        df = pd.read_csv(path)
        if df.shape != (10, input_size):
            return f"❌ CSV must have shape (10, {input_size}). Got {df.shape}"
        # Raises ValueError for non-numeric cells, which is reported below.
        features = df.to_numpy(dtype=np.float32)
    except Exception as e:
        return f"❌ Error reading CSV: {e}"
    return predict_from_array(features)
# Generate random input
def generate_random_input(rows=10, cols=None, low=0, high=100, rng=None):
    """Return a DataFrame of uniformly distributed random features for the demo.

    Args:
        rows: Number of rows. Defaults to 10, the window length that
            ``predict_from_array`` accepts.
        cols: Number of feature columns. Defaults to ``input_size``.
        low, high: Bounds of the uniform distribution, sampled from
            ``[low, high)``.
        rng: Optional ``numpy.random.Generator`` for reproducible samples.
            When omitted, the global ``np.random`` state is used, as before.

    Returns:
        A pandas.DataFrame of shape (rows, cols) whose columns are named
        F1..F<cols>. These names match the headers declared on the UI table;
        the DataFrame's default integer column labels would not.
    """
    if cols is None:
        cols = input_size
    sampler = np.random if rng is None else rng
    values = sampler.uniform(low, high, size=(rows, cols))
    return pd.DataFrame(values, columns=[f"F{i + 1}" for i in range(cols)])
def predict_from_table(df):
    """Validate the contents of the UI table and return a prediction string.

    Args:
        df: The value passed by ``gr.Dataframe``. This is a pandas DataFrame
            under the component's default ``type``, or ``None`` when there is
            no data.

    Returns:
        The string produced by ``predict_from_array`` on success, otherwise a
        message starting with "❌".
    """
    if df is None:
        return "❌ No table data provided."
    try:
        if df.shape != (10, input_size):
            return f"❌ Table must have shape (10, {input_size}), got {df.shape}"
        # Raises ValueError for non-numeric cells (e.g. blanks), which is
        # reported below instead of crashing the handler.
        features = df.to_numpy(dtype=np.float32)
    except Exception as e:
        return f"❌ Error: {e}"
    return predict_from_array(features)
# Build UI
# Two tabs feed the same model: a random/editable 10 x input_size table, and a
# CSV upload. Labels and table dimensions are derived from ``input_size`` so
# they stay correct if the feature count changes.
with gr.Blocks() as app:
    gr.Markdown("## 🏀 NBA Plus/Minus Predictor")

    with gr.Tab("🎲 Generate Random Input"):
        with gr.Row():
            generate_btn = gr.Button(f"Generate Random 10x{input_size} Sample")
            table_predict_btn = gr.Button("Predict")
        input_table = gr.Dataframe(
            headers=[f"F{i+1}" for i in range(input_size)],
            # "fixed" prevents users from adding/removing rows or columns,
            # which could only ever produce a shape error downstream.
            row_count=(10, "fixed"),
            col_count=(input_size, "fixed"),
            interactive=True,
            label="Generated Features",
        )
        output_text = gr.Textbox(label="Prediction")
        generate_btn.click(fn=generate_random_input, outputs=input_table)
        table_predict_btn.click(
            fn=predict_from_table, inputs=input_table, outputs=output_text
        )

    with gr.Tab("📂 Upload CSV"):
        with gr.Row():
            file_input = gr.File(
                label=f"Upload CSV with shape (10×{input_size})",
                file_types=[".csv"],
            )
            csv_predict_btn = gr.Button("Predict")
        file_output = gr.Textbox(label="Prediction")
        csv_predict_btn.click(
            fn=predict_from_csv, inputs=file_input, outputs=file_output
        )

app.launch(share=True, show_api=True, show_error=True)