munem420 commited on
Commit
4b9dc02
·
verified ·
1 Parent(s): d7298e9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import joblib
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import yfinance as yf
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # --- 1. Download Model and Scalers from Hugging Face Hub ---
11
+ # This is better than manually uploading them. The Space will fetch them automatically.
12
+ MODEL_REPO = "munem420/stock-forecaster-lstm"
13
+ MODEL_FILENAME = "model_lstm.h5"
14
+ SCALER_FILENAME = "scalers.joblib"
15
+
16
+ print("--- Downloading model and scalers ---")
17
+ try:
18
+ model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
19
+ scalers_path = hf_hub_download(repo_id=MODEL_REPO, filename=SCALER_FILENAME)
20
+ print("✅ Files downloaded successfully.")
21
+ except Exception as e:
22
+ print(f"❌ Error downloading files: {e}")
23
+ model_path, scalers_path = None, None
24
+
25
+ # --- 2. Load the Model and Scalers ---
26
+ loaded_model_lstm = None
27
+ loaded_scalers = None
28
+
29
+ if model_path and os.path.exists(model_path):
30
+ try:
31
+ loaded_model_lstm = tf.keras.models.load_model(model_path)
32
+ print("✅ Model loaded successfully.")
33
+ except Exception as e:
34
+ print(f"❌ Error loading model: {e}")
35
+
36
+ if scalers_path and os.path.exists(scalers_path):
37
+ try:
38
+ loaded_scalers = joblib.load(scalers_path)
39
+ print("✅ Scalers loaded successfully.")
40
+ except Exception as e:
41
+ print(f"❌ Error loading scalers: {e}")
42
+
43
+ # This dictionary is part of the original model's logic.
44
+ # A more robust solution would fetch this dynamically or store it better.
45
+ ticker_to_name = {'ZURVY': 'Zurich Insurance Group AG'}
46
+
47
+ def get_ticker_from_input(input_name):
48
+ # Simplified version for this app
49
+ return input_name.upper()
50
+
51
+ # --- 3. The Main Forecasting Function (Adapted from your code) ---
52
+ def forecast_stock(input_name, model, scalers_dict, input_width=60):
53
+ if not model or not scalers_dict:
54
+ return "Error: Model or scalers not loaded."
55
+
56
+ ticker = get_ticker_from_input(input_name)
57
+ if not ticker:
58
+ return "Error: Invalid stock ticker."
59
+
60
+ print(f"\n--- Generating forecast for {ticker} ---")
61
+
62
+ # Fetch recent data using yfinance
63
+ try:
64
+ data_df = yf.download(ticker, period="1y", progress=False)
65
+ if data_df.empty:
66
+ return f"Error: No data found for ticker {ticker}. It may be delisted or invalid."
67
+ except Exception as e:
68
+ return f"Error fetching data for {ticker}: {e}"
69
+
70
+
71
+ if len(data_df) < input_width:
72
+ return f"Error: Not enough historical data for {ticker}. Need {input_width} days, but only have {len(data_df)}."
73
+
74
+ recent_data = data_df.tail(input_width)
75
+ close_prices = recent_data['Close'].values.reshape(-1, 1)
76
+
77
+ # Note: The original scalers were trained on specific stocks.
78
+ # Using a scaler for a different stock (e.g., AAPL) on a new ticker might not be accurate.
79
+ # For this example, we'll try to find a matching scaler or default to a common one.
80
+ scaler = scalers_dict.get(ticker)
81
+ if not scaler:
82
+ print(f"Warning: No specific scaler found for {ticker}. Using ZURVY's scaler as a fallback.")
83
+ scaler = scalers_dict.get('ZURVY')
84
+ if not scaler:
85
+ return "Error: Default scaler 'ZURVY' not found."
86
+
87
+ scaled_data = scaler.transform(close_prices)
88
+ X_pred = scaled_data.reshape(1, input_width, 1)
89
+
90
+ prediction_scaled = model.predict(X_pred, verbose=0)[0][0]
91
+ prediction_actual = scaler.inverse_transform(np.array([[prediction_scaled]]))[0][0]
92
+
93
+ last_close = recent_data['Close'].iloc[-1]
94
+
95
+ result = (
96
+ f"Last known close for {ticker}: ${last_close:.2f}\n"
97
+ f"Predicted next day's close price: ${prediction_actual:.2f}"
98
+ )
99
+ print(result)
100
+ return result
101
+
102
+ # --- 4. Create the Gradio Interface ---
103
+ # We create a simple function that Gradio can expose as an API endpoint.
104
+ def predict_api(ticker_symbol):
105
+ return forecast_stock(ticker_symbol, loaded_model_lstm, loaded_scalers)
106
+
107
+ # We use a dummy Gradio interface because we only need its backend API capabilities.
108
+ # The `gr.Blocks()` allows us to run the server without displaying a UI.
109
+ with gr.Blocks() as app:
110
+ gr.Markdown("This is the backend for the React Stock Forecaster App.")
111
+ # This creates an API endpoint at `/run/predict`
112
+ ticker_input = gr.Textbox(label="Stock Ticker", visible=False)
113
+ output_text = gr.Textbox(label="Forecast", visible=False)
114
+
115
+ # The Gradio API function must be tied to an event
116
+ # We will call this endpoint from our React app.
117
+ ticker_input.submit(predict_api, inputs=[ticker_input], outputs=[output_text], api_name="predict")
118
+
119
+ # --- 5. Mount and Serve the React App's Static Files ---
120
+ # Before running this, you must build your React app using `npm run build`.
121
+ # This will create a `build` directory with static files.
122
+ # Gradio will serve the `index.html` from this directory.
123
+ app.mount_static_directory("./build")
124
+
125
+ # Launch the server
126
+ app.launch()