munem420's picture
Update app.py
f0596fa verified
raw
history blame
3.94 kB
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
import yfinance as yf
from huggingface_hub import hf_hub_download
# Hide every CUDA device so TensorFlow runs on the CPU only.
# NOTE(review): this is set *after* `import tensorflow`; it only takes effect
# because TensorFlow has not initialised CUDA yet at this point. Setting it
# before the import would make the ordering guarantee explicit.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# --- 1. Download Model and Scalers from Hugging Face Hub ---
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers ---")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
except Exception as e:
    # Best effort: the app still starts with both paths set to None, and the
    # loading step below then skips loading entirely.
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None
else:
    # Success message kept out of the try so only the downloads are guarded.
    print("✅ Files downloaded successfully.")
# --- 2. Load the Model and Scalers ---
# Both globals stay None if loading fails; forecast_stock() checks for that
# and returns an error message instead of crashing the API.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        # The model is only used for inference (model.predict), so it is
        # loaded uncompiled: this skips deserializing the saved loss/metrics,
        # whose 'mse' entry previously broke loading. The custom_objects
        # mapping for 'mse' is kept as a harmless fallback.
        loaded_model_lstm = tf.keras.models.load_model(
            model_path,
            custom_objects={"mse": tf.keras.losses.MeanSquaredError()},
            compile=False,
        )
    except Exception as e:
        print(f"❌ Error loading model: {e}")
    else:
        print("✅ Model loaded successfully.")

if scalers_path and os.path.exists(scalers_path):
    try:
        # NOTE: joblib.load unpickles the file. That is only acceptable
        # because the file is downloaded from this app's own model repo;
        # never point this at untrusted files.
        loaded_scalers = joblib.load(scalers_path)
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
    else:
        print("✅ Scalers loaded successfully.")
# Ticker symbol -> company name.
# NOTE(review): nothing in this file reads this mapping; confirm whether it
# is still needed (e.g. by the frontend) before relying on or removing it.
ticker_to_name = dict(ZURVY='Zurich Insurance Group AG')
def get_ticker_from_input(input_name):
    """Normalize raw user input into a ticker symbol.

    Surrounding whitespace is stripped and the text is upper-cased, so
    " aapl " becomes "AAPL". ``None`` yields an empty string so callers can
    reject it with a simple falsiness check instead of crashing.
    """
    if input_name is None:
        return ""
    return str(input_name).strip().upper()
def forecast_stock(input_name, model, scalers_dict, input_width=60):
    """Predict the next trading day's closing price for a stock ticker.

    Downloads one year of daily prices with ``yf.download``. It takes the last
    ``input_width`` non-missing closing prices and scales them with the
    ticker's scaler, falling back to the 'ZURVY' scaler. The scaled window is
    fed to ``model`` as a ``(1, input_width, 1)`` array, and the prediction
    is inverse-transformed back to a price.

    Args:
        input_name: Raw ticker text from the user; normalized with
            get_ticker_from_input().
        model: Loaded Keras model, or None if loading failed.
        scalers_dict: Mapping of ticker -> fitted scaler exposing
            ``transform`` / ``inverse_transform``, or None if loading failed.
        input_width: Number of most recent closing prices used as model input.

    Returns:
        A two-line result string on success. Every failure is returned as a
        string starting with "Error" rather than raised, so the API client
        always receives a displayable message.
    """
    # `is None` so the check does not depend on how the model object
    # defines truthiness.
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded. The backend may have failed to start."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."

    print(f"\n--- Generating forecast for {ticker} ---")
    try:
        data_df = yf.download(ticker, period="1y", progress=False)
    except Exception as e:
        return f"Error fetching data for {ticker}: {e}"
    if data_df is None or data_df.empty or "Close" not in data_df.columns:
        return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."

    close = data_df["Close"]
    # Recent yfinance versions return MultiIndex columns (Price, Ticker), so
    # selecting "Close" gives a one-column DataFrame instead of a Series.
    # Collapse it so `.iloc[-1]` below is a scalar that can be formatted.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    # Missing prices would otherwise propagate NaN through scaler and model.
    close = close.dropna()

    if len(close) < input_width:
        return f"Error: Not enough historical data. Need {input_width} days, but only have {len(close)}."

    recent_close = close.tail(input_width)
    close_prices = recent_close.to_numpy(dtype=float).reshape(-1, 1)

    scaler = scalers_dict.get(ticker)
    if scaler is None:
        # NOTE(review): a scaler fitted on ZURVY's prices may not match this
        # ticker's price range, so fallback forecasts may be unreliable.
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    try:
        scaled_window = scaler.transform(close_prices)
        X_pred = scaled_window.reshape(1, input_width, 1)
        prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
        prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
    except Exception as e:
        # Keep the "errors are returned as strings" contract for scaler/model
        # failures too, instead of surfacing a raw exception to the API.
        return f"Error generating forecast for {ticker}: {e}"

    last_close = float(recent_close.iloc[-1])
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${float(prediction_actual):.2f}"
    )
    print(result)
    return result
def predict_api(ticker_symbol):
    """Handler behind the Gradio "predict" API endpoint.

    Forwards *ticker_symbol* to forecast_stock() together with the
    module-level model and scalers loaded at startup. Either may be None if
    loading failed, which forecast_stock() reports as an error string.
    """
    return forecast_stock(
        ticker_symbol,
        model=loaded_model_lstm,
        scalers_dict=loaded_scalers,
    )
# --- 3. Gradio backend: exposes a hidden "predict" API for the React frontend ---
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    # Both textboxes are invisible; they only define the input/output
    # signature of the "predict" API endpoint.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    ticker_input.submit(
        predict_api,
        inputs=[ticker_input],
        outputs=[output_text],
        api_name="predict",
    )

# NOTE(review): confirm that `gr.mount_static_directory` exists in the
# installed Gradio version (it is not part of Gradio's documented API as far
# as can be told from here) and that its return value still has `.launch()`.
# The documented route for serving a React `build` folder alongside Gradio is
# a FastAPI app combined with `gr.mount_gradio_app`.
app = gr.mount_static_directory(app, "build")

if __name__ == "__main__":
    app.launch()