munem420 commited on
Commit
6d7a718
·
verified ·
1 Parent(s): f0c6972

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -69
app.py CHANGED
@@ -1,76 +1,120 @@
1
  import os
2
- import pandas as pd
3
- import numpy as np
4
- import joblib
5
- import tensorflow as tf
6
- from fuzzywuzzy import process
7
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # --- Load data ---
10
- DATA_PATH = "/kaggle/working/all_stocks_long.csv" # update if needed
11
- MODEL_DIR = "/kaggle/working/stock-forecaster-lstm"
12
- MODEL_FILE = "model_lstm.h5"
13
- SCALER_FILE = "scalers.joblib"
14
-
15
- combined_fe = pd.read_csv(DATA_PATH, parse_dates=['date'])
16
- combined_fe['ticker'] = combined_fe['ticker'].str.upper()
17
-
18
- # --- Load model ---
19
- model_path = os.path.join(MODEL_DIR, MODEL_FILE)
20
- scaler_path = os.path.join(MODEL_DIR, SCALER_FILE)
21
-
22
- loaded_model = tf.keras.models.load_model(model_path)
23
- loaded_scalers = joblib.load(scaler_path)
24
-
25
- # --- Build ticker <-> company mappings ---
26
- top_tickers = combined_fe['ticker'].unique()
27
- ticker_to_name = {t: t for t in top_tickers} # can update with real names if available
28
- name_to_ticker = {v: k for k,v in ticker_to_name.items()}
29
-
30
- # --- Prediction helpers ---
31
- def get_ticker_from_input(input_str):
32
- if input_str.upper() in ticker_to_name:
33
- return input_str.upper()
34
- if input_str in name_to_ticker:
35
- return name_to_ticker[input_str]
36
- best_match, score = process.extractOne(input_str, name_to_ticker.keys())
37
- if score > 80:
38
- return name_to_ticker[best_match]
39
- return None
40
-
41
- def make_windows(series, input_width=60, horizon=1):
42
- arr = series.values.astype(np.float32)
43
- X, y = [], []
44
- for i in range(input_width, len(arr)-horizon+1):
45
- X.append(arr[i-input_width:i])
46
- y.append(arr[i + (horizon-1)])
47
- return np.array(X), np.array(y)
48
-
49
- def forecast_stock(input_name):
50
  ticker = get_ticker_from_input(input_name)
51
  if not ticker:
52
- return f"Could not find ticker for '{input_name}'"
53
-
54
- recent_data = combined_fe[combined_fe['ticker']==ticker].sort_values('date').tail(60)
55
- if len(recent_data) < 60:
56
- return f"Not enough historical data for {ticker}."
57
 
58
- close_prices = recent_data['close'].values.reshape(-1, 1)
59
- scaler = loaded_scalers[ticker]
 
 
 
 
 
 
 
 
60
  scaled_data = scaler.transform(close_prices)
61
- X_pred = scaled_data.reshape(1, 60, 1)
62
-
63
- prediction_scaled = loaded_model.predict(X_pred, verbose=0)[0][0]
64
- prediction_actual = scaler.inverse_transform([[prediction_scaled]])[0][0]
65
- return round(prediction_actual, 2)
66
-
67
- # --- Gradio Interface ---
68
- iface = gr.Interface(
69
- fn=forecast_stock,
70
- inputs=gr.Textbox(label="Enter Ticker or Company Name"),
71
- outputs=gr.Textbox(label="Predicted Next Day Close"),
72
- title="Stock Price Forecaster (LSTM)",
73
- description="Enter a stock ticker or company name to predict the next day's close price."
74
- )
75
-
76
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
 
 
2
  import gradio as gr
3
+ import tensorflow as tf
4
+ import joblib
5
+ import numpy as np
6
+ import pandas as pd
7
+ from huggingface_hub import hf_hub_download
8
+
9
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
10
+
11
+
12
+
13
+ MODEL_REPO = "munem420/stock-forecaster-lstm"
14
+ MODEL_FILENAME = "model_lstm.h5"
15
+ SCALER_FILENAME = "scalers.joblib"
16
+
17
+
18
+
19
+ print("--- Downloading model and scalers ---")
20
+ try:
21
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
22
+ scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
23
+ print("✅ Files downloaded successfully.")
24
+ except Exception as e:
25
+ print(f"❌ Error downloading files: {e}")
26
+ model_path, scalers_path = None, None
27
+
28
+ loaded_model_lstm = None
29
+ loaded_scalers = None
30
+
31
+
32
+ if model_path and os.path.exists(model_path):
33
+ try:
34
+ loaded_model_lstm = tf.keras.models.load_model(
35
+ model_path,
36
+ custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
37
+ )
38
+ print("✅ Model loaded successfully.")
39
+ except Exception as e:
40
+ print(f"❌ Error loading model: {e}")
41
+
42
+ if scalers_path and os.path.exists(scalers_path):
43
+ try:
44
+ loaded_scalers = joblib.load(scalers_path)
45
+ print("✅ Scalers loaded successfully.")
46
+ except Exception as e:
47
+ print(f"❌ Error loading scalers: {e}")
48
+
49
+ ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
50
+
51
+
52
+
53
+
54
+
55
+
56
 
57
+
58
+
59
+
60
+ def get_ticker_from_input(input_name):
61
+ return input_name.upper()
62
+
63
+
64
+
65
+
66
+
67
+
68
+ def forecast_stock(input_name, model, scalers_dict, input_width=60):
69
+ if not model or not scalers_dict:
70
+ return "Error: Model or scalers not loaded. The backend may have failed to start."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ticker = get_ticker_from_input(input_name)
72
  if not ticker:
73
+ return "Error: Invalid stock ticker."
74
+ print(f"\n--- Generating forecast for {ticker} ---")
75
+
76
+
77
+
78
 
79
+ if len(data_df) < input_width:
80
+ return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
81
+ recent_data = data_df.tail(input_width)
82
+ close_prices = recent_data['Close'].values.reshape(-input, 1)
83
+ scaler = scalers_dict.get(ticker)
84
+ if not scaler:
85
+ print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
86
+ scaler = scalers_dict.get('ZURVY')
87
+ if not scaler:
88
+ return "Error: Default scaler 'ZURVY' not found."
89
  scaled_data = scaler.transform(close_prices)
90
+ X_pred = scaled_data.reshape(1, input_width, 1)
91
+ prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
92
+ prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
93
+ last_close = recent_data['Close'].iloc[-1]
94
+ result = (
95
+ f"Last known close for {ticker}: ${last_close:.2f}\n"
96
+ f"Predicted next day's close price: ${prediction_actual:.2f}"
97
+ )
98
+ print(result)
99
+ return result
100
+
101
+ def predict_api(ticker_symbol):
102
+ return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
103
+
104
+ with gr.Blocks() as app:
105
+ gr.Markdown("This is the backend for the React Stock Forecaster App.")
106
+ ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
107
+ output_text = gr.Textbox(label="Forecast", visible=False)
108
+ ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
109
+
110
+ app = gr.mount_static_directory(app, "build")
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+ if __name__ == "__main__":
120
+ app.launch()