# app.py — uploaded to the Hugging Face Hub by munem420
# (commit 4b9dc02 "Create app.py", 4.89 kB).
import gradio as gr
import tensorflow as tf
import joblib
import os
import numpy as np
import pandas as pd
import yfinance as yf
from huggingface_hub import hf_hub_download
# --- 1. Download Model and Scalers from Hugging Face Hub ---
# Fetching the artifacts from the Hub at startup means they never have to be
# uploaded to the Space by hand.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers ---")
try:
    # Each call yields the local path of the downloaded file.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    # Best-effort startup: report the failure and leave both paths as None so
    # the loading step below is skipped instead of crashing the Space.
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None
# --- 2. Load the Model and Scalers ---
# Both globals stay None if the corresponding file is missing or fails to
# load; consumers must check for that before use.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        loaded_model_lstm = tf.keras.models.load_model(model_path)
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code; only load artifacts from a trusted repository.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
# Static ticker -> company-name table carried over from the original model
# code. It is not referenced anywhere else in this file; looking names up
# dynamically (or loading them from configuration) would be more robust.
ticker_to_name = dict(ZURVY="Zurich Insurance Group AG")
def get_ticker_from_input(input_name):
    """Normalize raw user input into a ticker symbol.

    Simplified for this app: the input is assumed to already be a ticker, so
    it is only trimmed and upper-cased (no company-name lookup is performed).

    Args:
        input_name: Raw text supplied by the caller; ``None`` is treated as
            empty input.

    Returns:
        The whitespace-stripped, upper-cased ticker, or ``""`` when the input
        is ``None`` or blank (``forecast_stock`` rejects an empty ticker).
    """
    if input_name is None:
        return ""
    # Stray whitespace (e.g. " aapl ") would otherwise yield an invalid ticker.
    return str(input_name).strip().upper()
# --- 3. The Main Forecasting Function (Adapted from your code) ---
def _extract_close_series(data_df):
    """Return the 'Close' prices of a yfinance frame as a 1-D Series without NaNs."""
    close = data_df["Close"]
    # Recent yfinance releases return MultiIndex columns even for a single
    # ticker, in which case data_df["Close"] is a one-column DataFrame.
    # Collapse it so positional access (e.g. .iloc[-1]) yields a scalar.
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    # Missing quotes would otherwise propagate NaN into the model input.
    return close.dropna()


def _select_scaler(scalers_dict, ticker):
    """Return the scaler for *ticker*, falling back to ZURVY's; None if neither exists."""
    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
    return scaler


def forecast_stock(input_name, model, scalers_dict, input_width=60):
    """Predict the next day's closing price for a stock ticker.

    Downloads one year of daily prices, scales the last ``input_width``
    closing prices, runs them through ``model`` and inverse-scales the output.

    Args:
        input_name: User-supplied ticker text (normalized by
            ``get_ticker_from_input``).
        model: Loaded Keras model, or None if loading failed.
        scalers_dict: Mapping of ticker -> fitted scaler; ``'ZURVY'`` is used
            as a fallback when the ticker has no dedicated scaler.
        input_width: Number of most recent closing prices fed to the model.

    Returns:
        A human-readable result string. Every failure is also reported as a
        string starting with ``"Error"`` rather than raised.
    """
    # Explicit None check: model objects do not guarantee meaningful truthiness.
    if model is None or not scalers_dict:
        return "Error: Model or scalers not loaded."
    if input_width < 1:
        return "Error: input_width must be a positive integer."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."

    print(f"\n--- Generating forecast for {ticker} ---")

    # Fetch recent data using yfinance; keep the try limited to the network call.
    try:
        data_df = yf.download(ticker, period="1y", progress=False)
    except Exception as e:
        return f"Error fetching data for {ticker}: {e}"
    if data_df is None or data_df.empty:
        return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."

    close = _extract_close_series(data_df)
    if len(close) < input_width:
        return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only have {len(close)}."

    recent_close = close.tail(input_width)
    close_prices = recent_close.to_numpy(dtype=float).reshape(-1, 1)

    # Note: the scalers were fitted per stock, so applying the ZURVY fallback
    # to a different ticker may give inaccurate predictions.
    scaler = _select_scaler(scalers_dict, ticker)
    if scaler is None:
        return "Error: Default scaler 'ZURVY' not found."

    try:
        scaled_data = scaler.transform(close_prices)
        # Model input shape: (1 sample, input_width steps, 1 feature).
        X_pred = scaled_data.reshape(1, input_width, 1)
        prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
        prediction_actual = float(
            scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
        )
    except ValueError as e:
        # e.g. a scaler fitted on a different number of features, or a model
        # expecting a different input shape.
        return f"Error generating forecast for {ticker}: {e}"

    last_close = float(recent_close.iloc[-1])
    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result
# --- 4. Create the Gradio Interface ---
# Thin adapter that Gradio exposes as the API handler.
def predict_api(ticker_symbol):
    """Forecast *ticker_symbol* using the module-level model and scalers.

    The globals are read at call time, and whatever ``forecast_stock``
    returns is passed straight back to Gradio.
    """
    model, scalers = loaded_model_lstm, loaded_scalers
    return forecast_stock(ticker_symbol, model, scalers)
# Gradio acts mainly as the API backend here. The page shows only a short
# note, and the two hidden textboxes exist so that an event (and therefore a
# named API route) can be registered for the React frontend to call.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")

    # Hidden input/output components wired to predict_api.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)

    # api_name="predict" publishes the handler as a named API endpoint; the
    # exact URL path depends on the installed Gradio version.
    ticker_input.submit(
        predict_api,
        inputs=[ticker_input],
        outputs=[output_text],
        api_name="predict",
    )
# --- 5. Serve the React App's Static Files ---
# Build the React app first with `npm run build`; that creates a `build`
# directory of static files.
#
# NOTE: gr.Blocks has no `mount_static_directory` method. The previous call
# raised AttributeError before the server could start. Gradio only serves
# files from directories whitelisted through `allowed_paths` (under its
# file-serving route). Serving `index.html` at the site root would
# additionally require mounting this app into a FastAPI application
# (`gr.mount_gradio_app`) next to a static-files route.
REACT_BUILD_DIR = "./build"
if os.path.isdir(REACT_BUILD_DIR):
    static_paths = [REACT_BUILD_DIR]
else:
    static_paths = None
    print(f"Warning: React build directory '{REACT_BUILD_DIR}' not found; serving API only.")

# Launch the server.
app.launch(allowed_paths=static_paths)