munem420's picture
Update app.py
77fb042 verified
raw
history blame
4.07 kB
import os
import gradio as gr
import tensorflow as tf
import joblib
import numpy as np
import pandas as pd
import yfinance as yf
from huggingface_hub import hf_hub_download
# --- 0. Force CPU-only execution ---
# CUDA_VISIBLE_DEVICES is only honoured if it is set before TensorFlow
# initialises its GPU runtime. `tensorflow` is already imported above, so the
# variable alone is not a guaranteed switch. GPUs are therefore also hidden via
# the tf.config API.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
try:
    tf.config.set_visible_devices([], 'GPU')
except (RuntimeError, ValueError) as e:
    # RuntimeError: devices were already initialised, so visibility can no longer
    # change. Warn instead of aborting start-up; inference still works on
    # whatever device TF picked.
    print(f"Warning: could not hide GPUs from TensorFlow: {e}")
# --- 1. Download Model and Scalers from Hugging Face Hub ---
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.keras"
SCALER_FILENAME = "scalers.joblib"


def _fetch_artifacts():
    """Download the Keras model file and the scaler bundle from the Hub.

    Returns:
        ``(model_path, scalers_path)`` as local file paths, or ``(None, None)``
        if any part of the download fails. Failure is reported with a print,
        not raised, so the app can still start and report "not loaded".
    """
    print("--- Downloading model and scalers ---")
    try:
        local_model = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
        local_scalers = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
        print("βœ… Files downloaded successfully.")
    except Exception as err:
        print(f"❌ Error downloading files: {err}")
        return None, None
    return local_model, local_scalers


model_path, scalers_path = _fetch_artifacts()
# --- 2. Load the Model and Scalers ---
# Both globals stay None unless loading succeeds; forecast_stock() reports
# "not loaded" in that case.
loaded_model_lstm = None
loaded_scalers = None


def _have_file(path):
    """Return True when *path* is set (truthy) and exists on disk."""
    return bool(path) and os.path.exists(path)


if _have_file(model_path):
    try:
        loaded_model_lstm = tf.keras.models.load_model(model_path)
        print("βœ… Model loaded successfully.")
    except Exception as load_err:
        print(f"❌ Error loading model: {load_err}")

if _have_file(scalers_path):
    try:
        loaded_scalers = joblib.load(scalers_path)
        print("βœ… Scalers loaded successfully.")
    except Exception as load_err:
        print(f"❌ Error loading scalers: {load_err}")
# Human-readable company names keyed by ticker symbol.
ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}


def get_ticker_from_input(input_name):
    """Normalise raw user input into a ticker symbol.

    Surrounding whitespace is removed and the text is upper-cased, so
    ``"  zurvy "`` becomes ``"ZURVY"``. ``None`` is treated as empty input and
    yields ``""``, so callers can reject it with a plain truthiness check.

    Args:
        input_name: Raw text from the UI / API request (may be None).

    Returns:
        The normalised ticker string, possibly empty.
    """
    if input_name is None:
        return ""
    return str(input_name).strip().upper()
# --- 3. The Main Forecasting Function ---
def _close_series(data_df):
    """Return the 'Close' column of a yfinance frame as a 1-D float Series without NaNs."""
    close = data_df['Close']
    # Newer yfinance releases return MultiIndex columns even for one ticker.
    # data_df['Close'] is then a one-column DataFrame, and .iloc[-1] would
    # yield a Series that cannot be formatted with ":.2f".
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close.dropna().astype(float)


def forecast_stock(input_name, model, scalers_dict, input_width=60):
    """Predict the next day's close price for a ticker.

    Downloads one year of daily prices with yfinance. The last
    ``input_width`` closes are scaled with the ticker's scaler (falling back to
    the 'ZURVY' scaler), passed through the model once, and the output is
    inverse-scaled back to a price.

    Args:
        input_name: Raw ticker text; normalised via get_ticker_from_input().
        model: Loaded Keras model, or None if loading failed.
        scalers_dict: Mapping of ticker -> fitted scaler exposing ``transform``
            and ``inverse_transform``; None/empty if loading failed.
        input_width: Number of trailing closes fed to the model.

    Returns:
        A human-readable result string. Failures are also returned as strings
        starting with "Error" rather than raised, so the Gradio API always
        answers with text.
    """
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded. The backend may have failed to start."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."
    print(f"\n--- Generating forecast for {ticker} ---")

    # NOTE(review): yfinance's default price adjustment (auto_adjust) has changed
    # across releases; confirm it matches the data the scalers were fitted on.
    try:
        data_df = yf.download(ticker, period="1y", progress=False)
    except Exception as e:
        return f"Error fetching data for {ticker}: {e}"
    if data_df is None or data_df.empty:
        return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."

    # Count only usable (non-NaN) closes so the model never receives NaN inputs.
    close_series = _close_series(data_df)
    if len(close_series) < input_width:
        return f"Error: Not enough historical data. Need {input_width} days, but only have {len(close_series)}."

    recent_close = close_series.tail(input_width)
    close_prices = recent_close.to_numpy().reshape(-1, 1)

    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    try:
        scaled_data = scaler.transform(close_prices)
        # Shape (batch=1, timesteps=input_width, features=1) as built here.
        X_pred = scaled_data.reshape(1, input_width, 1)
        prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
        prediction_actual = float(
            scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
        )
    except Exception as e:
        # Keep the function's contract of reporting failures as text.
        return f"Error generating forecast for {ticker}: {e}"

    last_close = float(recent_close.iloc[-1])
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result
# --- 4. Create the Gradio Interface ---
def predict_api(ticker_symbol):
    """Gradio endpoint: forecast *ticker_symbol* using the module-level model and scalers."""
    # Read the globals at call time so the endpoint sees whatever loaded at start-up.
    model, scalers = loaded_model_lstm, loaded_scalers
    return forecast_stock(ticker_symbol, model, scalers)
# Minimal Blocks app. Both Textboxes are hidden (visible=False); they exist so
# that `submit` can register predict_api as the named "predict" API endpoint,
# presumably called by the React front-end mentioned in the Markdown text.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    # Hidden input carrying the ticker string for the API call.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    # Hidden output receiving forecast_stock()'s result string.
    output_text = gr.Textbox(label="Forecast", visible=False)
    # api_name="predict" exposes this event as the named "predict" endpoint.
    ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
# --- 5. Serve the React App's Static Files ---
# NOTE(review): `gr.mount_static_directory` is not part of Gradio's public API.
# Calling it unconditionally raised AttributeError at import time, so the
# server never launched.
# It is now used only if present. Otherwise the React build directory is
# registered with gr.set_static_paths (recent Gradio releases), which serves
# its files through Gradio's file route.
# TODO: serving index.html at "/" needs a FastAPI wrapper (gr.mount_gradio_app).
STATIC_DIR = "build"

_mount_static = getattr(gr, "mount_static_directory", None)
if _mount_static is not None:
    app = _mount_static(app, STATIC_DIR)
elif not os.path.isdir(STATIC_DIR):
    print(f"Warning: static directory '{STATIC_DIR}' not found; serving the API only.")
elif hasattr(gr, "set_static_paths"):
    gr.set_static_paths(paths=[STATIC_DIR])
else:
    print("Warning: this Gradio version cannot serve static files; serving the API only.")

# Launch the server
if __name__ == "__main__":
    app.launch()