munem420 commited on
Commit
daa61d7
Β·
verified Β·
1 Parent(s): 1018815

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -27
app.py CHANGED
@@ -8,9 +8,9 @@ import yfinance as yf
8
  from huggingface_hub import hf_hub_download
9
 
10
  # --- 1. Download Model and Scalers from Hugging Face Hub ---
11
- # This is better than manually uploading them. The Space will fetch them automatically.
12
  MODEL_REPO = "munem420/stock-forecaster-lstm"
13
- MODEL_FILENAME = "model_lstm.h5"
 
14
  SCALER_FILENAME = "scalers.joblib"
15
 
16
  print("--- Downloading model and scalers ---")
@@ -20,6 +20,7 @@ try:
20
  print("βœ… Files downloaded successfully.")
21
  except Exception as e:
22
  print(f"❌ Error downloading files: {e}")
 
23
  model_path, scalers_path = None, None
24
 
25
  # --- 2. Load the Model and Scalers ---
@@ -28,6 +29,7 @@ loaded_scalers = None
28
 
29
  if model_path and os.path.exists(model_path):
30
  try:
 
31
  loaded_model_lstm = tf.keras.models.load_model(model_path)
32
  print("βœ… Model loaded successfully.")
33
  except Exception as e:
@@ -40,26 +42,21 @@ if scalers_path and os.path.exists(scalers_path):
40
  except Exception as e:
41
  print(f"❌ Error loading scalers: {e}")
42
 
43
- # This dictionary is part of the original model's logic.
44
- # A more robust solution would fetch this dynamically or store it better.
45
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
46
 
47
  def get_ticker_from_input(input_name):
48
- # Simplified version for this app
49
  return input_name.upper()
50
 
51
- # --- 3. The Main Forecasting Function (Adapted from your code) ---
52
  def forecast_stock(input_name, model, scalers_dict, input_width=60):
53
  if not model or not scalers_dict:
54
- return "Error: Model or scalers not loaded."
55
 
56
  ticker = get_ticker_from_input(input_name)
57
  if not ticker:
58
  return "Error: Invalid stock ticker."
59
 
60
  print(f"\n--- Generating forecast for {ticker} ---")
61
-
62
- # Fetch recent data using yfinance
63
  try:
64
  data_df = yf.download(ticker, period="1y", progress=False)
65
  if data_df.empty:
@@ -67,16 +64,12 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
67
  except Exception as e:
68
  return f"Error fetching data for {ticker}: {e}"
69
 
70
-
71
  if len(data_df) < input_width:
72
- return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only have {len(data_df)}."
73
 
74
  recent_data = data_df.tail(input_width)
75
  close_prices = recent_data['Close'].values.reshape(-1, 1)
76
 
77
- # Note: The original scalers were trained on specific stocks.
78
- # Using a scaler for a different stock (e.g., AAPL) on a new ticker might not be accurate.
79
- # For this example, we'll try to find a matching scaler or default to a common one.
80
  scaler = scalers_dict.get(ticker)
81
  if not scaler:
82
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
@@ -91,7 +84,6 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
91
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
92
 
93
  last_close = recent_data['Close'].iloc[-1]
94
-
95
  result = (
96
  f"Last known close for {ticker}: ${last_close:.2f}\n"
97
  f"Predicted next day's close price: ${prediction_actual:.2f}"
@@ -100,27 +92,19 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
100
  return result
101
 
102
  # --- 4. Create the Gradio Interface ---
103
- # We create a simple function that Gradio can expose as an API endpoint.
104
  def predict_api(ticker_symbol):
105
  return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
106
 
107
- # We use a dummy Gradio interface because we only need its backend API capabilities.
108
- # The `gr.Blocks()` allows us to run the server without displaying a UI.
109
  with gr.Blocks() as app:
110
  gr.Markdown("This is the backend for the React Stock Forecaster App.")
111
- # This creates an API endpoint at `/run/predict`
112
  ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
113
  output_text = gr.Textbox(label="Forecast", visible=False)
114
-
115
- # The Gradio API function must be tied to an event
116
- # We will call this endpoint from our React app.
117
  ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
118
 
119
  # --- 5. Mount and Serve the React App's Static Files ---
120
- # Before running this, you must build your React app using `npm run build`.
121
- # This will create a `build` directory with static files.
122
- # Gradio will serve the `index.html` from this directory.
123
- app.mount_static_directory("./build")
124
 
125
  # Launch the server
126
- app.launch()
 
 
8
  from huggingface_hub import hf_hub_download
9
 
10
  # --- 1. Download Model and Scalers from Hugging Face Hub ---
 
11
  MODEL_REPO = "munem420/stock-forecaster-lstm"
12
+ # FIX #1: Corrected model filename from .h5 to .keras
13
+ MODEL_FILENAME = "model_lstm.keras"
14
  SCALER_FILENAME = "scalers.joblib"
15
 
16
  print("--- Downloading model and scalers ---")
 
20
  print("βœ… Files downloaded successfully.")
21
  except Exception as e:
22
  print(f"❌ Error downloading files: {e}")
23
+ # Exit gracefully if files can't be downloaded
24
  model_path, scalers_path = None, None
25
 
26
  # --- 2. Load the Model and Scalers ---
 
29
 
30
  if model_path and os.path.exists(model_path):
31
  try:
32
+ # No custom_objects needed for the .keras format in this case
33
  loaded_model_lstm = tf.keras.models.load_model(model_path)
34
  print("βœ… Model loaded successfully.")
35
  except Exception as e:
 
42
  except Exception as e:
43
  print(f"❌ Error loading scalers: {e}")
44
 
 
 
45
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
46
 
47
  def get_ticker_from_input(input_name):
 
48
  return input_name.upper()
49
 
50
+ # --- 3. The Main Forecasting Function ---
51
  def forecast_stock(input_name, model, scalers_dict, input_width=60):
52
  if not model or not scalers_dict:
53
+ return "Error: Model or scalers not loaded. The backend may have failed to start."
54
 
55
  ticker = get_ticker_from_input(input_name)
56
  if not ticker:
57
  return "Error: Invalid stock ticker."
58
 
59
  print(f"\n--- Generating forecast for {ticker} ---")
 
 
60
  try:
61
  data_df = yf.download(ticker, period="1y", progress=False)
62
  if data_df.empty:
 
64
  except Exception as e:
65
  return f"Error fetching data for {ticker}: {e}"
66
 
 
67
  if len(data_df) < input_width:
68
+ return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
69
 
70
  recent_data = data_df.tail(input_width)
71
  close_prices = recent_data['Close'].values.reshape(-1, 1)
72
 
 
 
 
73
  scaler = scalers_dict.get(ticker)
74
  if not scaler:
75
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
 
84
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
85
 
86
  last_close = recent_data['Close'].iloc[-1]
 
87
  result = (
88
  f"Last known close for {ticker}: ${last_close:.2f}\n"
89
  f"Predicted next day's close price: ${prediction_actual:.2f}"
 
92
  return result
93
 
94
  # --- 4. Create the Gradio Interface ---
 
95
  def predict_api(ticker_symbol):
96
  return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
97
 
 
 
98
  with gr.Blocks() as app:
99
  gr.Markdown("This is the backend for the React Stock Forecaster App.")
 
100
  ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
101
  output_text = gr.Textbox(label="Forecast", visible=False)
 
 
 
102
  ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
103
 
104
  # --- 5. Mount and Serve the React App's Static Files ---
105
+ # FIX #2: Changed to the correct function call format for Gradio
106
+ app = gr.mount_static_directory(app, "build")
 
 
107
 
108
  # Launch the server
109
+ if __name__ == "__main__":
110
+ app.launch()