Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,9 +8,9 @@ import yfinance as yf
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
| 10 |
# --- 1. Download Model and Scalers from Hugging Face Hub ---
|
| 11 |
-
# This is better than manually uploading them. The Space will fetch them automatically.
|
| 12 |
MODEL_REPO = "munem420/stock-forecaster-lstm"
|
| 13 |
-
|
|
|
|
| 14 |
SCALER_FILENAME = "scalers.joblib"
|
| 15 |
|
| 16 |
print("--- Downloading model and scalers ---")
|
|
@@ -20,6 +20,7 @@ try:
|
|
| 20 |
print("β
Files downloaded successfully.")
|
| 21 |
except Exception as e:
|
| 22 |
print(f"β Error downloading files: {e}")
|
|
|
|
| 23 |
model_path, scalers_path = None, None
|
| 24 |
|
| 25 |
# --- 2. Load the Model and Scalers ---
|
|
@@ -28,6 +29,7 @@ loaded_scalers = None
|
|
| 28 |
|
| 29 |
if model_path and os.path.exists(model_path):
|
| 30 |
try:
|
|
|
|
| 31 |
loaded_model_lstm = tf.keras.models.load_model(model_path)
|
| 32 |
print("β
Model loaded successfully.")
|
| 33 |
except Exception as e:
|
|
@@ -40,26 +42,21 @@ if scalers_path and os.path.exists(scalers_path):
|
|
| 40 |
except Exception as e:
|
| 41 |
print(f"β Error loading scalers: {e}")
|
| 42 |
|
| 43 |
-
# This dictionary is part of the original model's logic.
|
| 44 |
-
# A more robust solution would fetch this dynamically or store it better.
|
| 45 |
ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
|
| 46 |
|
| 47 |
def get_ticker_from_input(input_name):
|
| 48 |
-
# Simplified version for this app
|
| 49 |
return input_name.upper()
|
| 50 |
|
| 51 |
-
# --- 3. The Main Forecasting Function
|
| 52 |
def forecast_stock(input_name, model, scalers_dict, input_width=60):
|
| 53 |
if not model or not scalers_dict:
|
| 54 |
-
return "Error: Model or scalers not loaded."
|
| 55 |
|
| 56 |
ticker = get_ticker_from_input(input_name)
|
| 57 |
if not ticker:
|
| 58 |
return "Error: Invalid stock ticker."
|
| 59 |
|
| 60 |
print(f"\n--- Generating forecast for {ticker} ---")
|
| 61 |
-
|
| 62 |
-
# Fetch recent data using yfinance
|
| 63 |
try:
|
| 64 |
data_df = yf.download(ticker, period="1y", progress=False)
|
| 65 |
if data_df.empty:
|
|
@@ -67,16 +64,12 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
|
|
| 67 |
except Exception as e:
|
| 68 |
return f"Error fetching data for {ticker}: {e}"
|
| 69 |
|
| 70 |
-
|
| 71 |
if len(data_df) < input_width:
|
| 72 |
-
return f"Error: Not enough historical data
|
| 73 |
|
| 74 |
recent_data = data_df.tail(input_width)
|
| 75 |
close_prices = recent_data['Close'].values.reshape(-1, 1)
|
| 76 |
|
| 77 |
-
# Note: The original scalers were trained on specific stocks.
|
| 78 |
-
# Using a scaler for a different stock (e.g., AAPL) on a new ticker might not be accurate.
|
| 79 |
-
# For this example, we'll try to find a matching scaler or default to a common one.
|
| 80 |
scaler = scalers_dict.get(ticker)
|
| 81 |
if not scaler:
|
| 82 |
print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
|
|
@@ -91,7 +84,6 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
|
|
| 91 |
prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
|
| 92 |
|
| 93 |
last_close = recent_data['Close'].iloc[-1]
|
| 94 |
-
|
| 95 |
result = (
|
| 96 |
f"Last known close for {ticker}: ${last_close:.2f}\n"
|
| 97 |
f"Predicted next day's close price: ${prediction_actual:.2f}"
|
|
@@ -100,27 +92,19 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
|
|
| 100 |
return result
|
| 101 |
|
| 102 |
# --- 4. Create the Gradio Interface ---
|
| 103 |
-
# We create a simple function that Gradio can expose as an API endpoint.
|
| 104 |
def predict_api(ticker_symbol):
|
| 105 |
return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
|
| 106 |
|
| 107 |
-
# We use a dummy Gradio interface because we only need its backend API capabilities.
|
| 108 |
-
# The `gr.Blocks()` allows us to run the server without displaying a UI.
|
| 109 |
with gr.Blocks() as app:
|
| 110 |
gr.Markdown("This is the backend for the React Stock Forecaster App.")
|
| 111 |
-
# This creates an API endpoint at `/run/predict`
|
| 112 |
ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
|
| 113 |
output_text = gr.Textbox(label="Forecast", visible=False)
|
| 114 |
-
|
| 115 |
-
# The Gradio API function must be tied to an event
|
| 116 |
-
# We will call this endpoint from our React app.
|
| 117 |
ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
|
| 118 |
|
| 119 |
# --- 5. Mount and Serve the React App's Static Files ---
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
# Gradio will serve the `index.html` from this directory.
|
| 123 |
-
app.mount_static_directory("./build")
|
| 124 |
|
| 125 |
# Launch the server
|
| 126 |
-
|
|
|
|
|
|
| 8 |
from huggingface_hub import hf_hub_download
|
| 9 |
|
| 10 |
# --- 1. Download Model and Scalers from Hugging Face Hub ---
|
|
|
|
| 11 |
MODEL_REPO = "munem420/stock-forecaster-lstm"
|
| 12 |
+
# FIX #1: Corrected model filename from .h5 to .keras
|
| 13 |
+
MODEL_FILENAME = "model_lstm.keras"
|
| 14 |
SCALER_FILENAME = "scalers.joblib"
|
| 15 |
|
| 16 |
print("--- Downloading model and scalers ---")
|
|
|
|
| 20 |
print("β
Files downloaded successfully.")
|
| 21 |
except Exception as e:
|
| 22 |
print(f"β Error downloading files: {e}")
|
| 23 |
+
# Exit gracefully if files can't be downloaded
|
| 24 |
model_path, scalers_path = None, None
|
| 25 |
|
| 26 |
# --- 2. Load the Model and Scalers ---
|
|
|
|
| 29 |
|
| 30 |
if model_path and os.path.exists(model_path):
|
| 31 |
try:
|
| 32 |
+
# No custom_objects needed for the .keras format in this case
|
| 33 |
loaded_model_lstm = tf.keras.models.load_model(model_path)
|
| 34 |
print("β
Model loaded successfully.")
|
| 35 |
except Exception as e:
|
|
|
|
| 42 |
except Exception as e:
|
| 43 |
print(f"β Error loading scalers: {e}")
|
| 44 |
|
|
|
|
|
|
|
| 45 |
ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
|
| 46 |
|
| 47 |
def get_ticker_from_input(input_name):
|
|
|
|
| 48 |
return input_name.upper()
|
| 49 |
|
| 50 |
+
# --- 3. The Main Forecasting Function ---
|
| 51 |
def forecast_stock(input_name, model, scalers_dict, input_width=60):
|
| 52 |
if not model or not scalers_dict:
|
| 53 |
+
return "Error: Model or scalers not loaded. The backend may have failed to start."
|
| 54 |
|
| 55 |
ticker = get_ticker_from_input(input_name)
|
| 56 |
if not ticker:
|
| 57 |
return "Error: Invalid stock ticker."
|
| 58 |
|
| 59 |
print(f"\n--- Generating forecast for {ticker} ---")
|
|
|
|
|
|
|
| 60 |
try:
|
| 61 |
data_df = yf.download(ticker, period="1y", progress=False)
|
| 62 |
if data_df.empty:
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
return f"Error fetching data for {ticker}: {e}"
|
| 66 |
|
|
|
|
| 67 |
if len(data_df) < input_width:
|
| 68 |
+
return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
|
| 69 |
|
| 70 |
recent_data = data_df.tail(input_width)
|
| 71 |
close_prices = recent_data['Close'].values.reshape(-1, 1)
|
| 72 |
|
|
|
|
|
|
|
|
|
|
| 73 |
scaler = scalers_dict.get(ticker)
|
| 74 |
if not scaler:
|
| 75 |
print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
|
|
|
|
| 84 |
prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
|
| 85 |
|
| 86 |
last_close = recent_data['Close'].iloc[-1]
|
|
|
|
| 87 |
result = (
|
| 88 |
f"Last known close for {ticker}: ${last_close:.2f}\n"
|
| 89 |
f"Predicted next day's close price: ${prediction_actual:.2f}"
|
|
|
|
| 92 |
return result
|
| 93 |
|
| 94 |
# --- 4. Create the Gradio Interface ---
|
|
|
|
| 95 |
def predict_api(ticker_symbol):
|
| 96 |
return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
|
| 97 |
|
|
|
|
|
|
|
| 98 |
with gr.Blocks() as app:
|
| 99 |
gr.Markdown("This is the backend for the React Stock Forecaster App.")
|
|
|
|
| 100 |
ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
|
| 101 |
output_text = gr.Textbox(label="Forecast", visible=False)
|
|
|
|
|
|
|
|
|
|
| 102 |
ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
|
| 103 |
|
| 104 |
# --- 5. Mount and Serve the React App's Static Files ---
|
| 105 |
+
# FIX #2: Changed to the correct function call format for Gradio
|
| 106 |
+
app = gr.mount_static_directory(app, "build")
|
|
|
|
|
|
|
| 107 |
|
| 108 |
# Launch the server
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
app.launch()
|