munem420 commited on
Commit
eeb9a7a
Β·
verified Β·
1 Parent(s): 7df7c8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -51
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
-
3
  import gradio as gr
4
  import tensorflow as tf
5
  import joblib
@@ -8,105 +7,127 @@ import pandas as pd
8
  import yfinance as yf
9
  from huggingface_hub import hf_hub_download
10
 
11
- # This forces TensorFlow to only use the CPU.
 
12
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
 
14
- # --- 1. Download Model and Scalers from Hugging Face Hub ---
15
- # FIX #1: Ensured the repository name is 100% correct.
16
  MODEL_REPO = "munem420/stock-forecaster-lstm"
17
- MODEL_FILENAME = "model_lstm.h5"
18
  SCALER_FILENAME = "scalers.joblib"
19
 
20
- print("--- Downloading model and scalers ---")
21
  try:
22
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
23
  scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
24
  print("βœ… Files downloaded successfully.")
25
  except Exception as e:
26
- print(f"❌ Error downloading files: {e}")
 
27
  model_path, scalers_path = None, None
28
 
29
- # --- 2. Load the Model and Scalers ---
30
  loaded_model_lstm = None
31
  loaded_scalers = None
32
 
33
  if model_path and os.path.exists(model_path):
34
  try:
35
  loaded_model_lstm = tf.keras.models.load_model(model_path)
36
- print("βœ… Model loaded successfully.")
37
  except Exception as e:
38
- print(f"❌ Error loading model: {e}")
39
 
40
  if scalers_path and os.path.exists(scalers_path):
41
  try:
42
  loaded_scalers = joblib.load(scalers_path)
43
  print("βœ… Scalers loaded successfully.")
44
  except Exception as e:
45
- print(f"❌ Error loading scalers: {e}")
46
-
47
- ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
48
-
49
- def get_ticker_from_input(input_name):
50
- return input_name.upper()
51
-
52
- # --- 3. The Main Forecasting Function ---
53
- def forecast_stock(input_name, model, scalers_dict, input_width=60):
54
- if not model or not scalers_dict:
55
- return "Error: Model or scalers not loaded. The backend may have failed to start."
56
- # ... (The rest of the function is the same)
57
- ticker = get_ticker_from_input(input_name)
58
  if not ticker:
59
- return "Error: Invalid stock ticker."
60
 
61
  print(f"\n--- Generating forecast for {ticker} ---")
 
 
62
  try:
63
- data_df = yf.download(ticker, period="1y", progress=False)
 
64
  if data_df.empty:
65
- return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
66
  except Exception as e:
67
- return f"Error fetching data for {ticker}: {e}"
68
 
69
  if len(data_df) < input_width:
70
- return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
71
 
 
72
  recent_data = data_df.tail(input_width)
73
  close_prices = recent_data['Close'].values.reshape(-1, 1)
74
 
75
- scaler = scalers_dict.get(ticker)
 
 
76
  if not scaler:
77
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
78
- scaler = scalers_dict.get('ZURVY')
79
  if not scaler:
80
- return "Error: Default scaler 'ZURVY' not found."
81
 
82
- scaled_data = scaler.transform(close_prices)
83
- X_pred = scaled_data.reshape(1, input_width, 1)
 
 
84
 
85
- prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
86
- prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
 
 
87
 
 
88
  last_close = recent_data['Close'].iloc[-1]
89
- result = (
90
- f"Last known close for {ticker}: ${last_close:.2f}\n"
91
- f"Predicted next day's close price: ${prediction_actual:.2f}"
 
92
  )
93
- print(result)
94
- return result
95
-
96
- # --- 4. Create the Gradio Interface ---
97
- def predict_api(ticker_symbol):
98
- return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
99
-
100
- with gr.Blocks() as app:
101
- gr.Markdown("This is the backend for the React Stock Forecaster App.")
 
 
 
102
  ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
103
  output_text = gr.Textbox(label="Forecast", visible=False)
104
- ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
 
 
 
 
 
 
 
105
 
106
- # --- 5. Mount and Serve the React App's Static Files ---
107
- # This function requires a modern version of Gradio, specified in README.md
108
  app = gr.mount_static_directory(app, "build")
109
 
110
- # Launch the server
111
  if __name__ == "__main__":
112
  app.launch()
 
1
  import os
 
2
  import gradio as gr
3
  import tensorflow as tf
4
  import joblib
 
7
  import yfinance as yf
8
  from huggingface_hub import hf_hub_download
9
 
10
+ # --- 0. Force CPU-only mode for TensorFlow ---
11
+ # This prevents TensorFlow from trying to allocate GPU memory on a CPU-only instance.
12
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
 
14
+ # --- 1. Define Constants and Download Model/Scalers from Hugging Face Hub ---
 
15
  MODEL_REPO = "munem420/stock-forecaster-lstm"
16
+ MODEL_FILENAME = "model_lstm.keras"
17
  SCALER_FILENAME = "scalers.joblib"
18
 
19
+ print("--- Downloading model and scalers from Hugging Face Hub ---")
20
  try:
21
  model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
22
  scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
23
  print("βœ… Files downloaded successfully.")
24
  except Exception as e:
25
+ print(f"❌ Critical Error: Could not download files from the Hub. {e}")
26
+ # Set paths to None so the app knows that loading failed.
27
  model_path, scalers_path = None, None
28
 
29
+ # --- 2. Load the Model and Scalers into Memory ---
30
  loaded_model_lstm = None
31
  loaded_scalers = None
32
 
33
  if model_path and os.path.exists(model_path):
34
  try:
35
  loaded_model_lstm = tf.keras.models.load_model(model_path)
36
+ print("βœ… TensorFlow model loaded successfully.")
37
  except Exception as e:
38
+ print(f"❌ Critical Error: Could not load the TensorFlow model. {e}")
39
 
40
  if scalers_path and os.path.exists(scalers_path):
41
  try:
42
  loaded_scalers = joblib.load(scalers_path)
43
  print("βœ… Scalers loaded successfully.")
44
  except Exception as e:
45
+ print(f"❌ Critical Error: Could not load the scalers file. {e}")
46
+
47
+ # --- 3. The Core Forecasting Function ---
48
+ def forecast_stock(input_name: str, input_width: int = 60) -> str:
49
+ """
50
+ Fetches live stock data, preprocesses it, and returns a prediction string.
51
+ """
52
+ # Fail fast if the model/scalers didn't load during startup
53
+ if not loaded_model_lstm or not loaded_scalers:
54
+ return "Error: Model or scalers are not loaded. The backend may have failed to start correctly. Check the Space logs."
55
+
56
+ ticker = input_name.strip().upper()
 
57
  if not ticker:
58
+ return "Error: Please enter a stock ticker."
59
 
60
  print(f"\n--- Generating forecast for {ticker} ---")
61
+
62
+ # Fetch recent data using yfinance
63
  try:
64
+ # Fetch more than needed to ensure we have enough valid trading days
65
+ data_df = yf.download(ticker, period="200d", progress=False)
66
  if data_df.empty:
67
+ return f"Error: No data found for ticker '{ticker}'. It may be delisted or an invalid symbol."
68
  except Exception as e:
69
+ return f"Error fetching data for '{ticker}': {e}"
70
 
71
  if len(data_df) < input_width:
72
+ return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only found {len(data_df)}."
73
 
74
+ # Prepare the data for the model
75
  recent_data = data_df.tail(input_width)
76
  close_prices = recent_data['Close'].values.reshape(-1, 1)
77
 
78
+ # Find the correct scaler. The original model was trained on specific stocks.
79
+ # We try to find a matching scaler, otherwise, we use a default as a fallback.
80
+ scaler = loaded_scalers.get(ticker)
81
  if not scaler:
82
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
83
+ scaler = loaded_scalers.get('ZURVY')
84
  if not scaler:
85
+ return "Error: Critical failure. The default 'ZURVY' scaler could not be found."
86
 
87
+ # Scale the data and make a prediction
88
+ try:
89
+ scaled_data = scaler.transform(close_prices)
90
+ X_pred = scaled_data.reshape(1, input_width, 1) # Reshape for LSTM: [batch, timesteps, features]
91
 
92
+ prediction_scaled = loaded_model_lstm.predict(X_pred, verbose=0)[0][0]
93
+ prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
94
+ except Exception as e:
95
+ return f"An error occurred during model prediction: {e}"
96
 
97
+ # Format the final result
98
  last_close = recent_data['Close'].iloc[-1]
99
+ result_str = (
100
+ f"Forecast for: {ticker}\n"
101
+ f"Last Close Price: ${last_close:.2f}\n"
102
+ f"Predicted Next Day's Close: ${prediction_actual:.2f}"
103
  )
104
+ print(result_str)
105
+ return result_str
106
+
107
+ # --- 4. Create the Gradio Interface and API Endpoint ---
108
+ def predict_api(ticker_symbol: str) -> str:
109
+ """A simple wrapper for the main forecast function to be exposed as an API."""
110
+ return forecast_stock(ticker_symbol)
111
+
112
+ with gr.Blocks(title="Stock Forecaster Backend") as app:
113
+ gr.Markdown("## Stock Forecaster Backend\nThis Gradio app serves the API for the React frontend.")
114
+
115
+ # These components are not visible but are required to create the API endpoint
116
  ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
117
  output_text = gr.Textbox(label="Forecast", visible=False)
118
+
119
+ # This creates the API endpoint at /run/predict
120
+ ticker_input.submit(
121
+ fn=predict_api,
122
+ inputs=[ticker_input],
123
+ outputs=[output_text],
124
+ api_name="predict"
125
+ )
126
 
127
+ # --- 5. Mount the static React build directory to be served ---
128
+ # This requires a recent version of Gradio (e.g., 4.x), specified in README.md
129
  app = gr.mount_static_directory(app, "build")
130
 
131
+ # --- 6. Launch the Gradio App ---
132
  if __name__ == "__main__":
133
  app.launch()