# NOTE(review): removed scraped Hugging Face Space page chrome that was not part
# of the source file ("Spaces: Running" status, file size, commit hash 4b9dc02,
# and the editor gutter line numbers 1..126).
import gradio as gr
import tensorflow as tf
import joblib
import os
import numpy as np
import pandas as pd
import yfinance as yf
from huggingface_hub import hf_hub_download
# --- 1. Download Model and Scalers from Hugging Face Hub ---
# This is better than manually uploading them. The Space will fetch them automatically.
MODEL_REPO = "munem420/stock-forecaster-lstm"
MODEL_FILENAME = "model_lstm.h5"
SCALER_FILENAME = "scalers.joblib"

print("--- Downloading model and scalers ---")
try:
    # hf_hub_download caches locally and returns the on-disk path.
    model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
    scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
    print("✅ Files downloaded successfully.")
except Exception as e:
    # Best-effort startup: record failure, downstream code checks for None.
    print(f"❌ Error downloading files: {e}")
    model_path, scalers_path = None, None
# --- 2. Load the Model and Scalers ---
# Both stay None if the download above failed or deserialization errors out;
# forecast_stock() guards against that.
loaded_model_lstm = None
loaded_scalers = None

if model_path and os.path.exists(model_path):
    try:
        loaded_model_lstm = tf.keras.models.load_model(model_path)
        print("✅ Model loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading model: {e}")

if scalers_path and os.path.exists(scalers_path):
    try:
        # NOTE(review): joblib.load deserializes arbitrary pickled objects —
        # acceptable here because the artifact comes from our own model repo.
        loaded_scalers = joblib.load(scalers_path)
        print("✅ Scalers loaded successfully.")
    except Exception as e:
        print(f"❌ Error loading scalers: {e}")
# This dictionary is part of the original model's logic.
# A more robust solution would fetch this dynamically or store it better.
ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}


def get_ticker_from_input(input_name):
    """Normalize a user-supplied ticker string.

    Returns the upper-cased, whitespace-stripped ticker, or "" for
    None/empty input so callers can use a simple falsy check.
    """
    # Bugfix: the original called .upper() unconditionally, so a None input
    # raised AttributeError before the caller's `if not ticker` guard ran.
    if not input_name:
        return ""
    return input_name.strip().upper()
# --- 3. The Main Forecasting Function (Adapted from your code) ---
def forecast_stock(input_name, model, scalers_dict, input_width=60):
    """Forecast the next day's closing price for a stock ticker.

    Args:
        input_name: Raw ticker string from the user (e.g. "ZURVY").
        model: Loaded Keras LSTM model expecting input shaped
            (1, input_width, 1), or None if loading failed.
        scalers_dict: Mapping of ticker -> fitted scaler, or None.
        input_width: Number of trailing trading days fed to the model.

    Returns:
        A human-readable result string, or an "Error: ..." string on failure.
    """
    # Bugfix: `not model` relies on object truthiness, which is fragile for
    # Keras models / loaded containers — test identity against None instead.
    if model is None or scalers_dict is None:
        return "Error: Model or scalers not loaded."

    ticker = get_ticker_from_input(input_name)
    if not ticker:
        return "Error: Invalid stock ticker."

    print(f"\n--- Generating forecast for {ticker} ---")

    # Fetch a year of daily bars; only the trailing window is actually used.
    try:
        data_df = yf.download(ticker, period="1y", progress=False)
        if data_df.empty:
            return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
    except Exception as e:
        return f"Error fetching data for {ticker}: {e}"

    if len(data_df) < input_width:
        return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only have {len(data_df)}."

    recent_data = data_df.tail(input_width)
    close_prices = recent_data['Close'].values.reshape(-1, 1)

    # The scalers were fit per-ticker during training; reusing another
    # ticker's scaler is only a rough fallback and may skew the prediction.
    scaler = scalers_dict.get(ticker)
    if scaler is None:
        print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
        scaler = scalers_dict.get('ZURVY')
        if scaler is None:
            return "Error: Default scaler 'ZURVY' not found."

    scaled_data = scaler.transform(close_prices)
    X_pred = scaled_data.reshape(1, input_width, 1)

    prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
    prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]

    # Bugfix: under newer yfinance versions, download() returns MultiIndex
    # columns and `recent_data['Close'].iloc[-1]` can be a 1-element Series,
    # which would break the `:.2f` format spec — squeeze to a plain float.
    last_close = float(np.squeeze(np.asarray(recent_data['Close'].iloc[-1])))

    result = (
        f"Last known close for {ticker}: ${last_close:.2f}\n"
        f"Predicted next day's close price: ${prediction_actual:.2f}"
    )
    print(result)
    return result
# --- 4. Create the Gradio Interface ---
# Thin wrapper exposed as the Gradio API endpoint; binds the module-level
# model and scalers loaded at startup.
def predict_api(ticker_symbol):
    """Run a next-day forecast for `ticker_symbol` with the global model."""
    forecast = forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
    return forecast
# We use a dummy Gradio interface because we only need its backend API capabilities.
# `gr.Blocks()` lets us run the server without displaying a real UI.
with gr.Blocks() as app:
    gr.Markdown("This is the backend for the React Stock Forecaster App.")
    # Hidden components exist only so the submit event creates an API
    # endpoint (`/run/predict`) that the React app can call.
    ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
    output_text = gr.Textbox(label="Forecast", visible=False)
    # The Gradio API function must be tied to an event.
    ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")

# --- 5. Mount and Serve the React App's Static Files ---
# Before running this, build the React app with `npm run build` so the
# `build` directory with static files exists.
# NOTE(review): `mount_static_directory` is not a documented gr.Blocks method
# in current Gradio releases — confirm the installed version supports it, or
# serve the build dir via FastAPI + gr.mount_gradio_app instead.
app.mount_static_directory("./build")

# Launch the server (stray trailing `|` removed from the original line).
app.launch()