munem420 commited on
Commit
dfb42e6
ยท
verified ยท
1 Parent(s): 4ef68bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -69
app.py CHANGED
@@ -3,92 +3,76 @@ import gradio as gr
3
  import tensorflow as tf
4
  import joblib
5
  import numpy as np
6
- import pandas as pd
7
- from huggingface_hub import hf_hub_download
8
-
9
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
10
 
 
 
 
11
  MODEL_REPO = "munem420/stock-forecaster-lstm"
12
  MODEL_FILENAME = "model_lstm.h5"
13
  SCALER_FILENAME = "scalers.joblib"
14
 
15
- print("--- Downloading model and scalers ---")
16
- try:
17
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
18
- scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
19
- print("โœ… Files downloaded successfully.")
20
- except Exception as e:
21
- print(f"โŒ Error downloading files: {e}")
22
- model_path, scalers_path = None, None
23
-
24
- loaded_model_lstm = None
25
- loaded_scalers = None
26
-
27
- if model_path and os.path.exists(model_path):
28
- try:
29
- loaded_model_lstm = tf.keras.models.load_model(
30
- model_path,
31
- custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
32
- )
33
- print("โœ… Model loaded successfully.")
34
- except Exception as e:
35
- print(f"โŒ Error loading model: {e}")
36
 
37
- if scalers_path and os.path.exists(scalers_path):
38
- try:
39
- loaded_scalers = joblib.load(scalers_path)
40
- print("โœ… Scalers loaded successfully.")
41
- except Exception as e:
42
- print(f"โŒ Error loading scalers: {e}")
43
 
44
- ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
45
-
46
- # Example placeholder DataFrame (replace with your actual data)
47
- data_df = pd.DataFrame({
48
- "Date": pd.date_range(start="2024-01-01", periods=100),
49
- "Close": np.linspace(100, 200, 100)
50
- })
51
-
52
- def get_ticker_from_input(input_name):
53
- return input_name.upper().strip()
54
-
55
- def forecast_stock(input_name):
56
- model = loaded_model_lstm
57
- scalers_dict = loaded_scalers
58
- input_width = 60
 
 
 
 
 
59
 
60
- if not model or not scalers_dict:
61
- return "Error: Model or scalers not loaded."
 
 
 
 
62
 
63
- ticker = get_ticker_from_input(input_name)
64
- if len(data_df) < input_width:
65
- return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
66
 
67
- recent_data = data_df.tail(input_width)
68
- close_prices = recent_data['Close'].values.reshape(-1, 1)
69
- scaler = scalers_dict.get(ticker)
 
 
70
 
71
- if not scaler:
72
- print(f"โš ๏ธ No specific scaler for {ticker}. Using fallback.")
73
- scaler = scalers_dict.get('ZURVY')
74
- if not scaler:
75
- return "Error: No default scaler found."
76
 
77
- scaled_data = scaler.transform(close_prices)
78
- X_pred = scaled_data.reshape(1, input_width, 1)
79
- prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
80
- prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
81
- last_close = recent_data['Close'].iloc[-1]
82
 
83
- return f"Last close for {ticker}: ${last_close:.2f}\nPredicted next day close: ${prediction_actual:.2f}"
84
 
85
- # โœ… Simple Gradio interface
 
 
86
  iface = gr.Interface(
87
  fn=forecast_stock,
88
  inputs=gr.Textbox(label="Enter Ticker or Company Name"),
89
- outputs=gr.Textbox(label="Predicted Next Day Close"),
90
- title="Stock Price Forecaster (LSTM)",
91
- description="Enter a stock ticker or company name to predict the next day's close price."
92
  )
93
 
94
  if __name__ == "__main__":
 
3
  import tensorflow as tf
4
  import joblib
5
  import numpy as np
 
 
 
 
6
 
7
+ # -------------------------------------------------------
8
+ # CONFIG
9
+ # -------------------------------------------------------
10
  MODEL_REPO = "munem420/stock-forecaster-lstm"
11
  MODEL_FILENAME = "model_lstm.h5"
12
  SCALER_FILENAME = "scalers.joblib"
13
 
14
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
+ # -------------------------------------------------------
17
+ # LOAD MODEL AND SCALERS
18
+ # -------------------------------------------------------
19
+ print("๐Ÿ“ฆ Loading model and scalers...")
 
 
20
 
21
+ try:
22
+ model_path = tf.keras.utils.get_file(
23
+ MODEL_FILENAME,
24
+ f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILENAME}"
25
+ )
26
+ scalers_path = tf.keras.utils.get_file(
27
+ SCALER_FILENAME,
28
+ f"https://huggingface.co/{MODEL_REPO}/resolve/main/{SCALER_FILENAME}"
29
+ )
30
+
31
+ model = tf.keras.models.load_model(
32
+ model_path,
33
+ custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
34
+ )
35
+ scalers = joblib.load(scalers_path)
36
+
37
+ print("โœ… Model and scalers loaded successfully.")
38
+ except Exception as e:
39
+ print(f"โŒ Error loading model or scalers: {e}")
40
+ model, scalers = None, None
41
 
42
+ # -------------------------------------------------------
43
+ # FORECAST FUNCTION
44
+ # -------------------------------------------------------
45
+ def forecast_stock(ticker):
46
+ if not model or not scalers:
47
+ return "โŒ Model or scalers not loaded properly."
48
 
49
+ ticker = ticker.strip().upper()
50
+ if ticker not in scalers:
51
+ return f"โš ๏ธ No scaler found for ticker '{ticker}'. Please check spelling."
52
 
53
+ # Dummy inference example (replace with actual data fetching or preprocessing)
54
+ # Here we just simulate 60 normalized close prices for inference
55
+ scaler = scalers[ticker]
56
+ dummy_data = np.linspace(0.9, 1.1, 60).reshape(-1, 1)
57
+ X_pred = dummy_data.reshape(1, 60, 1)
58
 
59
+ # Predict scaled value
60
+ pred_scaled = model.predict(X_pred, verbose=0)[0][0]
 
 
 
61
 
62
+ # Inverse transform prediction
63
+ pred_actual = scaler.inverse_transform(np.array([[pred_scaled]]))[0][0]
 
 
 
64
 
65
+ return f"๐Ÿ”ฎ Predicted next day close for **{ticker}**: ${pred_actual:.2f}"
66
 
67
+ # -------------------------------------------------------
68
+ # GRADIO INTERFACE
69
+ # -------------------------------------------------------
70
  iface = gr.Interface(
71
  fn=forecast_stock,
72
  inputs=gr.Textbox(label="Enter Ticker or Company Name"),
73
+ outputs=gr.Markdown(label="Prediction Result"),
74
+ title="๐Ÿ“Š Stock Price Forecaster (LSTM)",
75
+ description="Enter a stock ticker or company name to predict the next day's closing price using a trained LSTM model."
76
  )
77
 
78
  if __name__ == "__main__":