SailajaS commited on
Commit
d93f225
·
verified ·
1 Parent(s): 2ec78ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -86
app.py CHANGED
@@ -12,33 +12,12 @@ import PIL.Image
12
 
13
  # Step 1: Fetch stock data
14
  def fetch_stock_data(ticker, start_date, end_date):
15
- """
16
- Fetch historical stock data from Yahoo Finance using the yfinance library.
17
-
18
- Args:
19
- ticker (str): Stock ticker symbol (e.g., 'AAPL').
20
- start_date (str): Start date for fetching stock data (YYYY-MM-DD).
21
- end_date (str): End date for fetching stock data (YYYY-MM-DD).
22
-
23
- Returns:
24
- pd.DataFrame: Stock data including Date, Open, High, Low, Close, and Volume.
25
- """
26
  stock_data = yf.download(ticker, start=start_date, end=end_date)
27
  stock_data.reset_index(inplace=True)
28
  return stock_data
29
 
30
  # Step 2: Prepare data for the LSTM model
31
  def prepare_data(df):
32
- """
33
- Prepares the stock data for the LSTM model by scaling the 'Close' price.
34
-
35
- Args:
36
- df (pd.DataFrame): DataFrame containing the stock data.
37
-
38
- Returns:
39
- scaled_data (np.array): Normalized stock prices for training the model.
40
- scaler (MinMaxScaler): Scaler used to normalize and later denormalize predictions.
41
- """
42
  scaler = MinMaxScaler(feature_range=(0, 1))
43
  close_prices = df['Close'].values.reshape(-1, 1)
44
  scaled_data = scaler.fit_transform(close_prices)
@@ -46,15 +25,6 @@ def prepare_data(df):
46
 
47
  # Step 3: Build the LSTM model
48
  def build_model(input_shape):
49
- """
50
- Builds and compiles the LSTM model for stock price prediction.
51
-
52
- Args:
53
- input_shape (tuple): Shape of the input data for the LSTM model.
54
-
55
- Returns:
56
- tf.keras.Model: Compiled LSTM model.
57
- """
58
  model = tf.keras.Sequential([
59
  tf.keras.layers.LSTM(50, return_sequences=True, input_shape=input_shape),
60
  tf.keras.layers.LSTM(50, return_sequences=False),
@@ -66,17 +36,6 @@ def build_model(input_shape):
66
 
67
  # Step 4: Train the LSTM model
68
  def train_model(model, train_data, epochs=5):
69
- """
70
- Trains the LSTM model using the scaled stock price data.
71
-
72
- Args:
73
- model (tf.keras.Model): The LSTM model to be trained.
74
- train_data (np.array): Scaled stock price data for training.
75
- epochs (int): Number of training epochs.
76
-
77
- Returns:
78
- model (tf.keras.Model): The trained LSTM model.
79
- """
80
  X_train, y_train = [], []
81
  for i in range(60, len(train_data)):
82
  X_train.append(train_data[i-60:i, 0])
@@ -90,17 +49,6 @@ def train_model(model, train_data, epochs=5):
90
 
91
  # Step 5: Predict future stock prices
92
  def predict_future(model, last_data, steps=90):
93
- """
94
- Predicts future stock prices using the trained LSTM model.
95
-
96
- Args:
97
- model (tf.keras.Model): The trained LSTM model.
98
- last_data (np.array): The last 60 days of stock price data.
99
- steps (int): Number of future days to predict.
100
-
101
- Returns:
102
- predictions (list): Predicted stock prices for the future.
103
- """
104
  predictions = []
105
  input_data = last_data[-60:].reshape(1, -1)
106
 
@@ -113,58 +61,34 @@ def predict_future(model, last_data, steps=90):
113
 
114
  # Step 6: Plot historical and predicted stock prices
115
  def plot_predictions(data, predicted_prices, scaler):
116
- """
117
- Plots the historical stock prices and the predicted future stock prices.
118
-
119
- Args:
120
- data (pd.DataFrame): DataFrame containing historical stock data.
121
- predicted_prices (list): Predicted stock prices for future dates.
122
- scaler (MinMaxScaler): Scaler to inverse transform the predicted prices.
123
-
124
- Returns:
125
- PIL.Image: The image of the plot saved in memory.
126
- """
127
  last_60_days = data['Close'][-60:].values.reshape(-1, 1)
128
  predicted_prices = np.array(predicted_prices).reshape(-1, 1)
129
  predicted_prices = scaler.inverse_transform(predicted_prices)
130
 
131
- # Plot historical data
132
  plt.figure(figsize=(14,6))
133
- plt.plot(data['Close'], label="Historical Prices")
 
 
 
134
 
135
- # Plot predicted data
136
- future_days = range(len(data), len(data) + len(predicted_prices))
137
- plt.plot(future_days, predicted_prices, label="Predicted Prices")
138
  plt.title("Stock Price Prediction")
139
- plt.xlabel("Days")
140
  plt.ylabel("Stock Price")
141
  plt.legend()
142
 
143
- # Save the plot to a bytes buffer
144
  buf = BytesIO()
145
  plt.savefig(buf, format='png')
146
  buf.seek(0)
147
  image = PIL.Image.open(buf)
148
 
149
- # Clear the plot so it doesn’t overlap
150
- plt.clf()
151
 
152
  return image
153
 
154
  # Step 7: Gradio Interface Function
155
  def stock_prediction_app(ticker, start_date_str, end_date_str):
156
- """
157
- The core function for the Gradio app. Fetches stock data, trains the LSTM model,
158
- predicts future prices, and visualizes the results.
159
-
160
- Args:
161
- ticker (str): Stock ticker symbol selected by the user.
162
- start_date_str (str): Start date selected by the user (YYYY-MM-DD).
163
- end_date_str (str): End date selected by the user (YYYY-MM-DD).
164
-
165
- Returns:
166
- PIL.Image: The plot showing historical and predicted stock prices.
167
- """
168
  # Convert date strings to datetime objects
169
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
170
  end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
@@ -195,8 +119,8 @@ ui = gr.Interface(
195
  fn=stock_prediction_app,
196
  inputs=[
197
  gr.Dropdown(tickers, label="Select Stock Ticker"),
198
- gr.Textbox(label="Start Date (YYYY-MM-DD)"), # Use textbox instead
199
- gr.Textbox(label="End Date (YYYY-MM-DD)") # Use textbox instead
200
  ],
201
  outputs=gr.Image(), # Updated output to return an image
202
  title="Stock Prediction App",
 
12
 
13
  # Step 1: Fetch stock data
14
  def fetch_stock_data(ticker, start_date, end_date):
 
 
 
 
 
 
 
 
 
 
 
15
  stock_data = yf.download(ticker, start=start_date, end=end_date)
16
  stock_data.reset_index(inplace=True)
17
  return stock_data
18
 
19
  # Step 2: Prepare data for the LSTM model
20
  def prepare_data(df):
 
 
 
 
 
 
 
 
 
 
21
  scaler = MinMaxScaler(feature_range=(0, 1))
22
  close_prices = df['Close'].values.reshape(-1, 1)
23
  scaled_data = scaler.fit_transform(close_prices)
 
25
 
26
  # Step 3: Build the LSTM model
27
  def build_model(input_shape):
 
 
 
 
 
 
 
 
 
28
  model = tf.keras.Sequential([
29
  tf.keras.layers.LSTM(50, return_sequences=True, input_shape=input_shape),
30
  tf.keras.layers.LSTM(50, return_sequences=False),
 
36
 
37
  # Step 4: Train the LSTM model
38
  def train_model(model, train_data, epochs=5):
 
 
 
 
 
 
 
 
 
 
 
39
  X_train, y_train = [], []
40
  for i in range(60, len(train_data)):
41
  X_train.append(train_data[i-60:i, 0])
 
49
 
50
  # Step 5: Predict future stock prices
51
  def predict_future(model, last_data, steps=90):
 
 
 
 
 
 
 
 
 
 
 
52
  predictions = []
53
  input_data = last_data[-60:].reshape(1, -1)
54
 
 
61
 
62
  # Step 6: Plot historical and predicted stock prices
63
  def plot_predictions(data, predicted_prices, scaler):
 
 
 
 
 
 
 
 
 
 
 
64
  last_60_days = data['Close'][-60:].values.reshape(-1, 1)
65
  predicted_prices = np.array(predicted_prices).reshape(-1, 1)
66
  predicted_prices = scaler.inverse_transform(predicted_prices)
67
 
68
+ # Create the plot
69
  plt.figure(figsize=(14,6))
70
+ plt.plot(data['Date'], data['Close'], label="Historical Prices", color='blue')
71
+
72
+ future_dates = pd.date_range(start=data['Date'].iloc[-1], periods=len(predicted_prices)+1)[1:]
73
+ plt.plot(future_dates, predicted_prices, label="Predicted Prices", color='orange')
74
 
 
 
 
75
  plt.title("Stock Price Prediction")
76
+ plt.xlabel("Date")
77
  plt.ylabel("Stock Price")
78
  plt.legend()
79
 
80
+ # Save plot to in-memory buffer
81
  buf = BytesIO()
82
  plt.savefig(buf, format='png')
83
  buf.seek(0)
84
  image = PIL.Image.open(buf)
85
 
86
+ plt.close() # Clear the plot to avoid overlap
 
87
 
88
  return image
89
 
90
  # Step 7: Gradio Interface Function
91
  def stock_prediction_app(ticker, start_date_str, end_date_str):
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Convert date strings to datetime objects
93
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
94
  end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
 
119
  fn=stock_prediction_app,
120
  inputs=[
121
  gr.Dropdown(tickers, label="Select Stock Ticker"),
122
+ gr.Textbox(label="Start Date (YYYY-MM-DD)"),
123
+ gr.Textbox(label="End Date (YYYY-MM-DD)")
124
  ],
125
  outputs=gr.Image(), # Updated output to return an image
126
  title="Stock Prediction App",