munem420 commited on
Commit
cd6bba8
·
verified ·
1 Parent(s): af01c67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -75
app.py CHANGED
@@ -1,91 +1,76 @@
1
  import os
2
- import gradio as gr
3
- import tensorflow as tf
4
- import joblib
5
- import numpy as np
6
  import pandas as pd
7
- from huggingface_hub import hf_hub_download
8
-
9
- os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
 
 
10
 
11
- MODEL_REPO = "munem420/stock-forecaster-lstm"
12
- MODEL_FILENAME = "model_lstm.h5"
13
- SCALER_FILENAME = "scalers.joblib"
 
 
14
 
15
- print("--- Downloading model and scalers ---")
16
- try:
17
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
18
- scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
19
- print("✅ Files downloaded successfully.")
20
- except Exception as e:
21
- print(f"❌ Error downloading files: {e}")
22
- model_path, scalers_path = None, None
23
 
24
- loaded_model_lstm = None
25
- loaded_scalers = None
 
26
 
27
- if model_path and os.path.exists(model_path):
28
- try:
29
- loaded_model_lstm = tf.keras.models.load_model(
30
- model_path,
31
- custom_objects={"mse": tf.keras.losses.MeanSquaredError()}
32
- )
33
- print("✅ Model loaded successfully.")
34
- except Exception as e:
35
- print(f"❌ Error loading model: {e}")
36
 
37
- if scalers_path and os.path.exists(scalers_path):
38
- try:
39
- loaded_scalers = joblib.load(scalers_path)
40
- print("✅ Scalers loaded successfully.")
41
- except Exception as e:
42
- print(f"❌ Error loading scalers: {e}")
43
 
44
- ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
 
 
 
 
 
 
 
 
 
45
 
46
- def get_ticker_from_input(input_name):
47
- return input_name.upper()
 
 
 
 
 
48
 
49
- def forecast_stock(input_name, model, scalers_dict, input_width=60):
50
- if not model or not scalers_dict:
51
- return "Error: Model or scalers not loaded. The backend may have failed to start."
52
  ticker = get_ticker_from_input(input_name)
53
  if not ticker:
54
- return "Error: Invalid stock ticker."
55
- print(f"\n--- Generating forecast for {ticker} ---")
 
 
 
56
 
57
- if len(data_df) < input_width:
58
- return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
59
- recent_data = data_df.tail(input_width)
60
- close_prices = recent_data['Close'].values.reshape(-input, 1)
61
- scaler = scalers_dict.get(ticker)
62
- if not scaler:
63
- print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
64
- scaler = scalers_dict.get('ZURVY')
65
- if not scaler:
66
- return "Error: Default scaler 'ZURVY' not found."
67
  scaled_data = scaler.transform(close_prices)
68
- X_pred = scaled_data.reshape(1, input_width, 1)
69
- prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
70
- prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
71
- last_close = recent_data['Close'].iloc[-1]
72
- result = (
73
- f"Last known close for {ticker}: ${last_close:.2f}\n"
74
- f"Predicted next day's close price: ${prediction_actual:.2f}"
75
- )
76
- print(result)
77
- return result
78
-
79
- def predict_api(ticker_symbol):
80
- return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
81
-
82
- with gr.Blocks() as app:
83
- gr.Markdown("This is the backend for the React Stock Forecaster App.")
84
- ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
85
- output_text = gr.Textbox(label="Forecast", visible=False)
86
- ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
87
 
88
- app = gr.mount_static_directory(app, "build")
 
 
 
 
 
 
 
89
 
90
- if __name__ == "__main__":
91
- app.launch()
 
1
  import os
 
 
 
 
2
  import pandas as pd
3
+ import numpy as np
4
+ import joblib
5
+ import tensorflow as tf
6
+ from fuzzywuzzy import process
7
+ import gradio as gr
8
 
9
+ # --- Load data ---
10
+ DATA_PATH = "/kaggle/working/all_stocks_long.csv" # update if needed
11
+ MODEL_DIR = "/kaggle/working/stock-forecaster-lstm"
12
+ MODEL_FILE = "model_lstm.h5"
13
+ SCALER_FILE = "scalers.joblib"
14
 
15
+ combined_fe = pd.read_csv(DATA_PATH, parse_dates=['date'])
16
+ combined_fe['ticker'] = combined_fe['ticker'].str.upper()
 
 
 
 
 
 
17
 
18
+ # --- Load model ---
19
+ model_path = os.path.join(MODEL_DIR, MODEL_FILE)
20
+ scaler_path = os.path.join(MODEL_DIR, SCALER_FILE)
21
 
22
+ loaded_model = tf.keras.models.load_model(model_path)
23
+ loaded_scalers = joblib.load(scaler_path)
 
 
 
 
 
 
 
24
 
25
+ # --- Build ticker <-> company mappings ---
26
+ top_tickers = combined_fe['ticker'].unique()
27
+ ticker_to_name = {t: t for t in top_tickers} # can update with real names if available
28
+ name_to_ticker = {v: k for k,v in ticker_to_name.items()}
 
 
29
 
30
+ # --- Prediction helpers ---
31
+ def get_ticker_from_input(input_str):
32
+ if input_str.upper() in ticker_to_name:
33
+ return input_str.upper()
34
+ if input_str in name_to_ticker:
35
+ return name_to_ticker[input_str]
36
+ best_match, score = process.extractOne(input_str, name_to_ticker.keys())
37
+ if score > 80:
38
+ return name_to_ticker[best_match]
39
+ return None
40
 
41
+ def make_windows(series, input_width=60, horizon=1):
42
+ arr = series.values.astype(np.float32)
43
+ X, y = [], []
44
+ for i in range(input_width, len(arr)-horizon+1):
45
+ X.append(arr[i-input_width:i])
46
+ y.append(arr[i + (horizon-1)])
47
+ return np.array(X), np.array(y)
48
 
49
+ def forecast_stock(input_name):
 
 
50
  ticker = get_ticker_from_input(input_name)
51
  if not ticker:
52
+ return f"Could not find ticker for '{input_name}'"
53
+
54
+ recent_data = combined_fe[combined_fe['ticker']==ticker].sort_values('date').tail(60)
55
+ if len(recent_data) < 60:
56
+ return f"Not enough historical data for {ticker}."
57
 
58
+ close_prices = recent_data['close'].values.reshape(-1, 1)
59
+ scaler = loaded_scalers[ticker]
 
 
 
 
 
 
 
 
60
  scaled_data = scaler.transform(close_prices)
61
+ X_pred = scaled_data.reshape(1, 60, 1)
62
+
63
+ prediction_scaled = loaded_model.predict(X_pred, verbose=0)[0][0]
64
+ prediction_actual = scaler.inverse_transform([[prediction_scaled]])[0][0]
65
+ return round(prediction_actual, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ # --- Gradio Interface ---
68
+ iface = gr.Interface(
69
+ fn=forecast_stock,
70
+ inputs=gr.Textbox(label="Enter Ticker or Company Name"),
71
+ outputs=gr.Textbox(label="Predicted Next Day Close"),
72
+ title="Stock Price Forecaster (LSTM)",
73
+ description="Enter a stock ticker or company name to predict the next day's close price."
74
+ )
75
 
76
+ iface.launch()