# munem420's picture
# Update app.py
# 6d7a718 verified
# raw
# history blame
# 3.4 kB
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
# Hugging Face Hub repository and the artifact file names stored in it.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

# Expose no CUDA devices to this process ("-1" matches no GPU index).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Fetch the model and scaler artifacts from the Hub at import time.
# On any failure both paths are set to None so the later loading steps
# are skipped instead of crashing the whole app.
print("--- Downloading model and scalers ---")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    # Startup boundary: report the problem and continue without artifacts.
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None
# Both handles stay None unless the matching artifact loads successfully;
# forecast_stock() reports an error to the caller in that case.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        # "mse" is supplied through custom_objects so that name can be
        # resolved when the saved model is deserialized.
        loaded_model_lstm = tf.keras.models.load_model(
            model_path,
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()},
        )
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # SECURITY NOTE(review): joblib.load unpickles the file and can run
        # arbitrary code; only safe while MODEL_REPO is a trusted source.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
# Company display names keyed by ticker symbol.
ticker_to_name = {
    "ZURVY": "Zurich Insurance Group AG",
}
def get_ticker_from_input(input_name):
    """Normalize caller input into an upper-case ticker symbol.

    Args:
        input_name: Raw ticker text supplied by the caller.

    Returns:
        The text with surrounding whitespace removed and upper-cased, or an
        empty string when ``input_name`` is not a string.
    """
    if not isinstance(input_name, str):
        return ""
    return input_name.strip().upper()


def forecast_stock(input_name, model, scalers_dict, input_width=60, data_df=None):
    """Predict the next closing price for a ticker from its recent closes.

    Args:
        input_name: Ticker text supplied by the caller.
        model: Object exposing ``predict(x, verbose=0)``. It receives an
            array of shape ``(1, input_width, 1)`` and element ``[0][0]`` of
            its output is taken as the scaled prediction.
        scalers_dict: Mapping of ticker symbol to a scaler exposing
            ``transform`` and ``inverse_transform``. The ``'ZURVY'`` entry
            is the fallback for tickers without their own scaler.
        input_width: Number of trailing rows of ``data_df`` fed to the model.
        data_df: Price history with a ``'Close'`` column; the last row is
            treated as the most recent close. When omitted, a module-level
            ``data_df`` is used if one exists.

    Returns:
        A two-line summary string on success, or a string starting with
        ``"Error:"`` explaining why no forecast could be produced.
    """
    if not model or not scalers_dict:
        return "Error: Model or scalers not loaded. The backend may have failed to start."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."

    print(f"\n--- Generating forecast for {ticker} ---")

    if data_df is None:
        # The body used to read an undefined free variable named ``data_df``;
        # keep honouring a module-level one if the file provides it.
        data_df = globals().get("data_df")
    if data_df is None:
        return f"Error: No historical price data available for {ticker}."
    if "Close" not in data_df:
        return "Error: Historical data has no 'Close' column."

    if len(data_df) < input_width:
        return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."

    recent_data = data_df.tail(input_width)
    # Column vector (n_samples, 1): the layout scaler.transform is given.
    close_prices = recent_data["Close"].values.reshape(-1, 1)

    # Compare against None explicitly: a scaler object's truthiness is not
    # a reliable signal of presence.
    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get("ZURVY")
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    scaled_data = scaler.transform(close_prices)
    # One batch of one sequence: (batch=1, timesteps=input_width, features=1).
    X_pred = scaled_data.reshape(1, input_width, 1)
    prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
    # Map the scaled prediction back to price units.
    prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]

    last_close = recent_data["Close"].iloc[-1]
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result
def predict_api(ticker_symbol):
    """Run a forecast for ``ticker_symbol`` using the module-level artifacts.

    Passes the globally loaded LSTM model and scalers to ``forecast_stock``
    and returns its result string unchanged.
    """
    return forecast_stock(
        ticker_symbol,
        model=loaded_model_lstm,
        scalers_dict=loaded_scalers,
    )
# Gradio app: both textboxes are hidden, so the only exposed entry point
# is the "predict" API route registered on the submit event.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    ticker_input.submit(
        fn=predict_api,
        inputs=[ticker_input],
        outputs=[output_text],
        api_name="predict",
    )

# NOTE(review): could not confirm `mount_static_directory` as a public
# Gradio API; verify against the installed version. Presumably meant to
# serve the "build" directory of the front end -- TODO confirm.
app = gr.mount_static_directory(app, "build")

if __name__ == "__main__":
    app.launch()