munem420 commited on
Commit
17ef946
Β·
1 Parent(s): ef87e36

app.py added

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import tensorflow as tf
5
+ import joblib
6
+ import numpy as np
7
+ import pandas as pd
8
+ import yfinance as yf
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # This forces TensorFlow to only use the CPU.
12
+ os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
13
+
14
+ # --- 1. Download Model and Scalers from Hugging Face Hub ---
15
+ # FIX #1: Ensured the repository name is 100% correct.
16
+ MODEL_REPO = "munem420/stock-forecaster-lstm"
17
+ MODEL_FILENAME = "model_lstm.h5"
18
+ SCALER_FILENAME = "scalers.joblib"
19
+
20
+ print("--- Downloading model and scalers ---")
21
+ try:
22
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
23
+ scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
24
+ print("βœ… Files downloaded successfully.")
25
+ except Exception as e:
26
+ print(f"❌ Error downloading files: {e}")
27
+ model_path, scalers_path = None, None
28
+
29
+ # --- 2. Load the Model and Scalers ---
30
+ loaded_model_lstm = None
31
+ loaded_scalers = None
32
+
33
+ if model_path and os.path.exists(model_path):
34
+ try:
35
+ loaded_model_lstm = tf.keras.models.load_model(model_path)
36
+ print("βœ… Model loaded successfully.")
37
+ except Exception as e:
38
+ print(f"❌ Error loading model: {e}")
39
+
40
+ if scalers_path and os.path.exists(scalers_path):
41
+ try:
42
+ loaded_scalers = joblib.load(scalers_path)
43
+ print("βœ… Scalers loaded successfully.")
44
+ except Exception as e:
45
+ print(f"❌ Error loading scalers: {e}")
46
+
47
+ ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
48
+
49
+ def get_ticker_from_input(input_name):
50
+ return input_name.upper()
51
+
52
+ # --- 3. The Main Forecasting Function ---
53
+ def forecast_stock(input_name, model, scalers_dict, input_width=60):
54
+ if not model or not scalers_dict:
55
+ return "Error: Model or scalers not loaded. The backend may have failed to start."
56
+ # ... (The rest of the function is the same)
57
+ ticker = get_ticker_from_input(input_name)
58
+ if not ticker:
59
+ return "Error: Invalid stock ticker."
60
+
61
+ print(f"\n--- Generating forecast for {ticker} ---")
62
+ try:
63
+ data_df = yf.download(ticker, period="1y", progress=False)
64
+ if data_df.empty:
65
+ return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
66
+ except Exception as e:
67
+ return f"Error fetching data for {ticker}: {e}"
68
+
69
+ if len(data_df) < input_width:
70
+ return f"Error: Not enough historical data. Need {input_width} days, but only have {len(data_df)}."
71
+
72
+ recent_data = data_df.tail(input_width)
73
+ close_prices = recent_data['Close'].values.reshape(-1, 1)
74
+
75
+ scaler = scalers_dict.get(ticker)
76
+ if not scaler:
77
+ print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
78
+ scaler = scalers_dict.get('ZURVY')
79
+ if not scaler:
80
+ return "Error: Default scaler 'ZURVY' not found."
81
+
82
+ scaled_data = scaler.transform(close_prices)
83
+ X_pred = scaled_data.reshape(1, input_width, 1)
84
+
85
+ prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
86
+ prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
87
+
88
+ last_close = recent_data['Close'].iloc[-1]
89
+ result = (
90
+ f"Last known close for {ticker}: ${last_close:.2f}\n"
91
+ f"Predicted next day's close price: ${prediction_actual:.2f}"
92
+ )
93
+ print(result)
94
+ return result
95
+
96
+ # --- 4. Create the Gradio Interface ---
97
+ def predict_api(ticker_symbol):
98
+ return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
99
+
100
+ with gr.Blocks() as app:
101
+ gr.Markdown("This is the backend for the React Stock Forecaster App.")
102
+ ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
103
+ output_text = gr.Textbox(label="Forecast", visible=False)
104
+ ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
105
+
106
+ # --- 5. Mount and Serve the React App's Static Files ---
107
+ # This function requires a modern version of Gradio, specified in README.md
108
+ app = gr.mount_static_directory(app, "build")
109
+
110
+ # Launch the server
111
+ if __name__ == "__main__":
112
+ app.launch()