munem420 commited on
Commit
2a568a4
·
verified ·
1 Parent(s): 018d39f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -17
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
-
3
  import gradio as gr
4
  import tensorflow as tf
5
  import joblib
@@ -12,9 +11,8 @@ from huggingface_hub import hf_hub_download
12
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
 
14
  # --- 1. Download Model and Scalers from Hugging Face Hub ---
15
- # FIX #1: Ensured the repository name is 100% correct.
16
  MODEL_REPO = "munem420/stock-forecaster-lstm"
17
- MODEL_FILENAME = "model_lstm.h5"
18
  SCALER_FILENAME = "scalers.joblib"
19
 
20
  print("--- Downloading model and scalers ---")
@@ -32,7 +30,11 @@ loaded_scalers = None
32
 
33
  if model_path and os.path.exists(model_path):
34
  try:
35
- loaded_model_lstm = tf.keras.models.load_model(model_path)
 
 
 
 
36
  print("✅ Model loaded successfully.")
37
  except Exception as e:
38
  print(f"❌ Error loading model: {e}")
@@ -44,20 +46,19 @@ if scalers_path and os.path.exists(scalers_path):
44
  except Exception as e:
45
  print(f"❌ Error loading scalers: {e}")
46
 
 
 
47
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
48
 
49
  def get_ticker_from_input(input_name):
50
  return input_name.upper()
51
 
52
- # --- 3. The Main Forecasting Function ---
53
  def forecast_stock(input_name, model, scalers_dict, input_width=60):
54
  if not model or not scalers_dict:
55
  return "Error: Model or scalers not loaded. The backend may have failed to start."
56
- # ... (The rest of the function is the same)
57
  ticker = get_ticker_from_input(input_name)
58
  if not ticker:
59
  return "Error: Invalid stock ticker."
60
-
61
  print(f"\n--- Generating forecast for {ticker} ---")
62
  try:
63
  data_df = yf.download(ticker, period="1y", progress=False)
@@ -65,26 +66,20 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
65
  return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
66
  except Exception as e:
67
  return f"Error fetching data for {ticker}: {e}"
68
-
69
  if len(data_df) < input_width:
70
  return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
71
-
72
  recent_data = data_df.tail(input_width)
73
  close_prices = recent_data['Close'].values.reshape(-1, 1)
74
-
75
  scaler = scalers_dict.get(ticker)
76
  if not scaler:
77
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
78
  scaler = scalers_dict.get('ZURVY')
79
  if not scaler:
80
  return "Error: Default scaler 'ZURVY' not found."
81
-
82
  scaled_data = scaler.transform(close_prices)
83
  X_pred = scaled_data.reshape(1, input_width, 1)
84
-
85
  prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
86
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
87
-
88
  last_close = recent_data['Close'].iloc[-1]
89
  result = (
90
  f"Last known close for {ticker}: ${last_close:.2f}\n"
@@ -93,7 +88,6 @@ def forecast_stock(input_name, model, scalers_dict, input_width=60):
93
  print(result)
94
  return result
95
 
96
- # --- 4. Create the Gradio Interface ---
97
  def predict_api(ticker_symbol):
98
  return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
99
 
@@ -103,10 +97,7 @@ with gr.Blocks() as app:
103
  output_text = gr.Textbox(label="Forecast", visible=False)
104
  ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
105
 
106
- # --- 5. Mount and Serve the React App's Static Files ---
107
- # This function requires a modern version of Gradio, specified in README.md
108
  app = gr.mount_static_directory(app, "build")
109
 
110
- # Launch the server
111
  if __name__ == "__main__":
112
  app.launch()
 
1
  import os
 
2
  import gradio as gr
3
  import tensorflow as tf
4
  import joblib
 
11
  os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
12
 
13
  # --- 1. Download Model and Scalers from Hugging Face Hub ---
 
14
  MODEL_REPO = "munem420/stock-forecaster-lstm"
15
+ MODEL_FILENAME = "model_lstm.keras"
16
  SCALER_FILENAME = "scalers.joblib"
17
 
18
  print("--- Downloading model and scalers ---")
 
30
 
31
  if model_path and os.path.exists(model_path):
32
  try:
33
+ # FIX #1: Added custom_objects to handle the 'mse' metric during loading
34
+ loaded_model_lstm = tf.keras.models.load_model(
35
+ model_path,
36
+ custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
37
+ )
38
  print("✅ Model loaded successfully.")
39
  except Exception as e:
40
  print(f"❌ Error loading model: {e}")
 
46
  except Exception as e:
47
  print(f"❌ Error loading scalers: {e}")
48
 
49
+ # ... (The rest of the file is unchanged) ...
50
+
51
  ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
52
 
53
  def get_ticker_from_input(input_name):
54
  return input_name.upper()
55
 
 
56
  def forecast_stock(input_name, model, scalers_dict, input_width=60):
57
  if not model or not scalers_dict:
58
  return "Error: Model or scalers not loaded. The backend may have failed to start."
 
59
  ticker = get_ticker_from_input(input_name)
60
  if not ticker:
61
  return "Error: Invalid stock ticker."
 
62
  print(f"\n--- Generating forecast for {ticker} ---")
63
  try:
64
  data_df = yf.download(ticker, period="1y", progress=False)
 
66
  return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
67
  except Exception as e:
68
  return f"Error fetching data for {ticker}: {e}"
 
69
  if len(data_df) < input_width:
70
  return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
 
71
  recent_data = data_df.tail(input_width)
72
  close_prices = recent_data['Close'].values.reshape(-1, 1)
 
73
  scaler = scalers_dict.get(ticker)
74
  if not scaler:
75
  print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
76
  scaler = scalers_dict.get('ZURVY')
77
  if not scaler:
78
  return "Error: Default scaler 'ZURVY' not found."
 
79
  scaled_data = scaler.transform(close_prices)
80
  X_pred = scaled_data.reshape(1, input_width, 1)
 
81
  prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
82
  prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
 
83
  last_close = recent_data['Close'].iloc[-1]
84
  result = (
85
  f"Last known close for {ticker}: ${last_close:.2f}\n"
 
88
  print(result)
89
  return result
90
 
 
91
  def predict_api(ticker_symbol):
92
  return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
93
 
 
97
  output_text = gr.Textbox(label="Forecast", visible=False)
98
  ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
99
 
 
 
100
  app = gr.mount_static_directory(app, "build")
101
 
 
102
  if __name__ == "__main__":
103
  app.launch()