"""Gradio app that serves next-day close-price predictions from a pre-trained LSTM.

The Keras model and a per-ticker mapping of scalers are downloaded from the
Hugging Face Hub repository ``MODEL_REPO`` when this module is imported.

NOTE(review): ``forecast_stock`` currently feeds a synthetic input window to the
model instead of the ticker's real recent prices, so its output is a
demonstration value rather than a genuine forecast.
"""

import os

# Hide all GPUs *before* TensorFlow is imported: TensorFlow may initialise the
# CUDA runtime during/soon after import, after which changing this variable
# is not guaranteed to have any effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import gradio as gr  # noqa: E402  (must follow the env-var setup above)
import joblib  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402

# -------------------------------------------------------
# CONFIG
# -------------------------------------------------------
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"
# Number of time steps fed to the model per prediction (input shape is
# (1, WINDOW_SIZE, 1)). Presumably equals the training window — TODO confirm.
WINDOW_SIZE = 60


# -------------------------------------------------------
# LOAD MODEL AND SCALERS
# -------------------------------------------------------
def _hub_file(filename):
    """Download ``filename`` from ``MODEL_REPO`` via Keras' file cache; return its local path."""
    return tf.keras.utils.get_file(
        filename,
        f"https://huggingface.co/{MODEL_REPO}/resolve/main/{filename}",
    )


def _load_artifacts():
    """Download and load the LSTM model and the per-ticker scalers.

    Returns:
        ``(model, scalers)`` on success, or ``(None, None)`` if any step fails.
        Failures are printed rather than raised so the Gradio UI can still
        start and report the problem on each request.
    """
    print("📦 Loading model and scalers...")
    try:
        model_path = _hub_file(MODEL_FILENAME)
        scalers_path = _hub_file(SCALER_FILENAME)
        loaded_model = tf.keras.models.load_model(
            model_path,
            # Resolve the serialized "mse" loss name to a loss object so the
            # saved model deserializes.
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()},
        )
        # SECURITY NOTE(review): joblib.load unpickles arbitrary objects; only
        # point MODEL_REPO at a repository you trust.
        loaded_scalers = joblib.load(scalers_path)
    except Exception as e:  # app-level boundary: report and degrade gracefully
        print(f"❌ Error loading model or scalers: {e}")
        return None, None
    print("✅ Model and scalers loaded successfully.")
    return loaded_model, loaded_scalers


model, scalers = _load_artifacts()


# -------------------------------------------------------
# FORECAST FUNCTION
# -------------------------------------------------------
def forecast_stock(ticker):
    """Return a Markdown message with the predicted next-day close for ``ticker``.

    Args:
        ticker: Ticker text entered by the user. Surrounding whitespace and
            letter case are ignored; the upper-cased value must be a key of
            the loaded ``scalers`` mapping. ``None`` is treated as empty.

    Returns:
        A user-facing Markdown string: the prediction, or an error/warning
        explaining why no prediction could be made.
    """
    # Compare against the sentinel explicitly rather than relying on the
    # truthiness of a Keras model / unpickled object.
    if model is None or scalers is None:
        return "❌ Model or scalers not loaded properly."

    ticker = (ticker or "").strip().upper()
    if not ticker:
        return "⚠️ Please enter a ticker symbol."
    if ticker not in scalers:
        return f"⚠️ No scaler found for ticker '{ticker}'. Please check spelling."

    scaler = scalers[ticker]

    # TODO: replace with the ticker's last WINDOW_SIZE closing prices, scaled
    # with ``scaler.transform``. This synthetic ramp is identical for every
    # ticker, so the model output is the same for all of them and only the
    # inverse scaling below differs per ticker.
    window = np.linspace(0.9, 1.1, WINDOW_SIZE).reshape(1, WINDOW_SIZE, 1)

    # Predict in scaled space, then map back to the price scale.
    pred_scaled = model.predict(window, verbose=0)[0][0]
    pred_actual = scaler.inverse_transform(np.array([[pred_scaled]]))[0][0]

    return f"🔮 Predicted next day close for **{ticker}**: ${pred_actual:.2f}"


# -------------------------------------------------------
# GRADIO INTERFACE
# -------------------------------------------------------
# NOTE(review): the label/description mention company names, but
# forecast_stock only matches upper-cased keys of ``scalers``; company-name
# lookup is not implemented here.
iface = gr.Interface(
    fn=forecast_stock,
    inputs=gr.Textbox(label="Enter Ticker or Company Name"),
    outputs=gr.Markdown(label="Prediction Result"),
    title="📊 Stock Price Forecaster (LSTM)",
    description="Enter a stock ticker or company name to predict the next day's closing price using a trained LSTM model.",
)

if __name__ == "__main__":
    iface.launch()