SailajaS commited on
Commit
abc4c00
·
verified ·
1 Parent(s): 6880b2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  # Import required libraries
2
  import yfinance as yf
3
  import pandas as pd
@@ -7,6 +8,8 @@ from sklearn.preprocessing import MinMaxScaler
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
9
  from datetime import datetime
 
 
10
 
11
  # Step 1: Fetch stock data
12
  def fetch_stock_data(ticker, start_date, end_date):
@@ -118,6 +121,9 @@ def plot_predictions(data, predicted_prices, scaler):
118
  data (pd.DataFrame): DataFrame containing historical stock data.
119
  predicted_prices (list): Predicted stock prices for future dates.
120
  scaler (MinMaxScaler): Scaler to inverse transform the predicted prices.
 
 
 
121
  """
122
  last_60_days = data['Close'][-60:].values.reshape(-1, 1)
123
  predicted_prices = np.array(predicted_prices).reshape(-1, 1)
@@ -134,7 +140,17 @@ def plot_predictions(data, predicted_prices, scaler):
134
  plt.xlabel("Days")
135
  plt.ylabel("Stock Price")
136
  plt.legend()
137
- plt.show()
 
 
 
 
 
 
 
 
 
 
138
 
139
  # Step 7: Gradio Interface Function
140
  def stock_prediction_app(ticker, start_date_str, end_date_str):
@@ -148,7 +164,7 @@ def stock_prediction_app(ticker, start_date_str, end_date_str):
148
  end_date_str (str): End date selected by the user (YYYY-MM-DD).
149
 
150
  Returns:
151
- None (Displays a plot of historical and predicted stock prices).
152
  """
153
  # Convert date strings to datetime objects
154
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
@@ -168,7 +184,9 @@ def stock_prediction_app(ticker, start_date_str, end_date_str):
168
  predicted_prices = predict_future(model, scaled_data)
169
 
170
  # Plot historical and predicted prices
171
- plot_predictions(data, predicted_prices, scaler)
 
 
172
 
173
  # Step 8: Gradio UI Setup
174
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
@@ -181,7 +199,7 @@ ui = gr.Interface(
181
  gr.Textbox(label="Start Date (YYYY-MM-DD)"), # Use textbox instead
182
  gr.Textbox(label="End Date (YYYY-MM-DD)") # Use textbox instead
183
  ],
184
- outputs="plot",
185
  title="Stock Prediction App",
186
  description="Select a stock ticker and date range to predict future prices."
187
  )
 
1
+
2
  # Import required libraries
3
  import yfinance as yf
4
  import pandas as pd
 
8
  import matplotlib.pyplot as plt
9
  import gradio as gr
10
  from datetime import datetime
11
+ from io import BytesIO
12
+ import PIL.Image
13
 
14
  # Step 1: Fetch stock data
15
  def fetch_stock_data(ticker, start_date, end_date):
 
121
  data (pd.DataFrame): DataFrame containing historical stock data.
122
  predicted_prices (list): Predicted stock prices for future dates.
123
  scaler (MinMaxScaler): Scaler to inverse transform the predicted prices.
124
+
125
+ Returns:
126
+ PIL.Image: The image of the plot saved in memory.
127
  """
128
  last_60_days = data['Close'][-60:].values.reshape(-1, 1)
129
  predicted_prices = np.array(predicted_prices).reshape(-1, 1)
 
140
  plt.xlabel("Days")
141
  plt.ylabel("Stock Price")
142
  plt.legend()
143
+
144
+ # Save the plot to a bytes buffer
145
+ buf = BytesIO()
146
+ plt.savefig(buf, format='png')
147
+ buf.seek(0)
148
+ image = PIL.Image.open(buf)
149
+
150
+ # Clear the plot so it doesn’t overlap
151
+ plt.clf()
152
+
153
+ return image
154
 
155
  # Step 7: Gradio Interface Function
156
  def stock_prediction_app(ticker, start_date_str, end_date_str):
 
164
  end_date_str (str): End date selected by the user (YYYY-MM-DD).
165
 
166
  Returns:
167
+ PIL.Image: The plot showing historical and predicted stock prices.
168
  """
169
  # Convert date strings to datetime objects
170
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
 
184
  predicted_prices = predict_future(model, scaled_data)
185
 
186
  # Plot historical and predicted prices
187
+ image = plot_predictions(data, predicted_prices, scaler)
188
+
189
+ return image
190
 
191
  # Step 8: Gradio UI Setup
192
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
 
199
  gr.Textbox(label="Start Date (YYYY-MM-DD)"), # Use textbox instead
200
  gr.Textbox(label="End Date (YYYY-MM-DD)") # Use textbox instead
201
  ],
202
+ outputs=gr.Image(), # Updated output to return an image
203
  title="Stock Prediction App",
204
  description="Select a stock ticker and date range to predict future prices."
205
  )