munem420's picture
Update app.py
4ef68bd verified
raw
history blame
3.18 kB
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
# Hide every CUDA device from this process so inference runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Hugging Face Hub repository and the artefact file names stored in it.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"
# --- Fetch artefacts -------------------------------------------------------
# Download the trained network and its fitted scalers from the Hub. The
# returned values are local file paths. Any failure is logged and leaves
# both paths as None so the app can still start and report the problem
# through forecast_stock instead of crashing at import time.
print("--- Downloading model and scalers ---")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None

# --- Load artefacts --------------------------------------------------------
# Both globals stay None unless their file was downloaded AND deserialised.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        # Map the name "mse" to a concrete loss object for deserialisation.
        # NOTE(review): presumably the saved H5 file references the loss by
        # that string and load_model cannot resolve it otherwise - confirm.
        loaded_model_lstm = tf.keras.models.load_model(
            model_path,
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
        )
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Acceptable only while MODEL_REPO is trusted.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
# Known tickers and the company each one stands for.
ticker_to_name = dict(ZURVY="Zurich Insurance Group AG")

# Placeholder price history (swap in real data): 100 consecutive daily rows
# starting 2024-01-01 with closes rising linearly from 100 to 200.
data_df = pd.DataFrame(
    {
        "Date": pd.date_range("2024-01-01", periods=100),
        "Close": np.linspace(100, 200, num=100),
    }
)
def get_ticker_from_input(input_name, name_map=None):
    """Turn free-form user text into a ticker symbol.

    The text is stripped and upper-cased. If it is not already a key of
    the ticker table but equals one of the table's company names
    (compared case-insensitively), the matching ticker is returned
    instead. Unrecognised text is returned stripped and upper-cased, as
    before, so unknown tickers still flow through unchanged.

    Args:
        input_name: Ticker or company name typed by the user. ``None`` is
            treated as an empty string.
        name_map: Optional ``{ticker: company name}`` mapping. Defaults to
            the module-level ``ticker_to_name`` table.

    Returns:
        The resolved ticker symbol, or ``""`` for empty input.
    """
    if name_map is None:
        # Looked up at call time so the module table can be swapped later.
        name_map = ticker_to_name

    cleaned = (input_name or "").strip()
    if not cleaned:
        return ""

    candidate = cleaned.upper()
    if candidate in name_map:
        return candidate

    # Not a known ticker: see whether the user typed a full company name.
    folded = cleaned.casefold()
    for ticker, company in name_map.items():
        if company.casefold() == folded:
            return ticker

    return candidate
def forecast_stock(input_name):
    """Predict the next-day close for the requested ticker.

    Takes the last 60 rows of the module-level ``data_df``, scales their
    ``Close`` values with the ticker's scaler (falling back to the
    ``'ZURVY'`` scaler when the ticker has none), feeds that window to the
    loaded LSTM and inverse-transforms the output with the same scaler.

    Args:
        input_name: Raw text from the UI; normalised through
            ``get_ticker_from_input``.

    Returns:
        A two-line summary with the last close and the predicted close, or
        a message starting with ``"Error:"`` when the model/scalers are
        missing, the history is shorter than the window, or no scaler
        (including the fallback) exists.
    """
    window = 60
    net, scaler_lookup = loaded_model_lstm, loaded_scalers
    if not (net and scaler_lookup):
        return "Error: Model or scalers not loaded."

    ticker = get_ticker_from_input(input_name)

    rows_available = len(data_df)
    if rows_available < window:
        return (
            f"Error: Not enough historical data. Need {window} days, "
            f"but only have {rows_available}."
        )

    history = data_df.tail(window)
    # Column vector: one row per day, a single feature.
    closes = history['Close'].values.reshape(-1, 1)

    scaler = scaler_lookup.get(ticker)
    if not scaler:
        print(f"⚠️ No specific scaler for {ticker}. Using fallback.")
        scaler = scaler_lookup.get('ZURVY')
    if not scaler:
        return "Error: No default scaler found."

    # Reshape to a single sample of `window` time steps with one feature.
    model_input = scaler.transform(closes).reshape(1, window, 1)
    scaled_forecast = net.predict(model_input, verbose=0)[0][0]
    forecast = scaler.inverse_transform(np.array([[scaled_forecast]]))[0][0]
    latest_close = history['Close'].iloc[-1]

    return (
        f"Last close for {ticker}: ${latest_close:.2f}\n"
        f"Predicted next day close: ${forecast:.2f}"
    )
# ✅ Simple Gradio interface
# One text box in, one text box out; forecast_stock produces the reply.
iface = gr.Interface(
    forecast_stock,
    gr.Textbox(label="Enter Ticker or Company Name"),
    gr.Textbox(label="Predicted Next Day Close"),
    title="Stock Price Forecaster (LSTM)",
    description="Enter a stock ticker or company name to predict the next day's close price.",
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()