SailajaS commited on
Commit
480b75f
·
verified ·
1 Parent(s): 070285a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -27
app.py CHANGED
@@ -8,7 +8,7 @@ import matplotlib.pyplot as plt
8
  import gradio as gr
9
  from datetime import datetime
10
 
11
- # Step 1: Fetch stock data
12
  def fetch_stock_data(ticker, start_date, end_date):
13
  stock_data = yf.download(ticker, start=start_date, end=end_date)
14
  stock_data.reset_index(inplace=True)
@@ -46,46 +46,50 @@ def train_model(model, train_data, epochs=5):
46
  return model
47
 
48
  # Step 5: Predict future stock prices
49
- def predict_future(model, last_data, steps=90):
50
  predictions = []
51
  input_data = last_data[-60:].reshape(1, -1)
52
-
 
53
  for _ in range(steps):
54
- predicted_price = model.predict(input_data.reshape(1, 60, 1), verbose=0)
 
55
  predictions.append(predicted_price[0][0])
56
  input_data = np.append(input_data[0][1:], predicted_price)
57
 
58
- return predictions
 
 
59
 
60
  # Step 6: Plot historical and predicted stock prices
61
- def plot_predictions(data, predicted_prices, scaler):
62
- last_60_days = data['Close'][-60:].values.reshape(-1, 1)
63
- predicted_prices = np.array(predicted_prices).reshape(-1, 1)
64
- predicted_prices = scaler.inverse_transform(predicted_prices)
65
 
66
- # Create the plot
67
- plt.figure(figsize=(14,6))
68
 
69
- # Plot historical stock prices
70
  plt.plot(data['Date'], data['Close'], label="Historical Prices", color='blue')
71
 
72
- # Plot future predictions
73
- future_dates = pd.date_range(start=data['Date'].iloc[-1], periods=len(predicted_prices)+1)[1:]
 
 
74
  plt.plot(future_dates, predicted_prices, label="Predicted Prices", color='orange')
75
 
76
  plt.title("Stock Price Prediction")
77
  plt.xlabel("Date")
78
- plt.ylabel("Stock Price")
79
  plt.legend()
80
  plt.grid()
81
 
82
- # Save plot to file to ensure consistent rendering
83
  plt.savefig("stock_prediction.png")
84
- plt.close() # Clear the plot to avoid overlap
85
 
86
- return "stock_prediction.png" # Return the file path for Gradio to display
87
 
88
- # Step 7: Gradio Interface Function
89
  def stock_prediction_app(ticker, start_date_str, end_date_str):
90
  # Convert date strings to datetime objects
91
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
@@ -101,18 +105,18 @@ def stock_prediction_app(ticker, start_date_str, end_date_str):
101
  model = build_model((60, 1))
102
  model = train_model(model, scaled_data)
103
 
104
- # Predict future prices for the next 90 days
105
- predicted_prices = predict_future(model, scaled_data)
106
 
107
- # Plot historical and predicted prices
108
- plot_path = plot_predictions(data, predicted_prices, scaler)
109
 
110
- return plot_path
111
 
112
- # Step 8: Gradio UI Setup
113
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
114
 
115
- # Create the Gradio interface using updated components
116
  ui = gr.Interface(
117
  fn=stock_prediction_app,
118
  inputs=[
@@ -120,10 +124,11 @@ ui = gr.Interface(
120
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
121
  gr.Textbox(label="End Date (YYYY-MM-DD)")
122
  ],
123
- outputs=gr.Image(type="filepath"), # Return the image file path (fixed here)
124
  title="Stock Prediction App",
125
  description="Select a stock ticker and date range to predict future prices."
126
  )
127
 
128
  # Launch the app
129
  ui.launch()
 
 
8
  import gradio as gr
9
  from datetime import datetime
10
 
11
+ # Step 1: Fetch stock data from yfinance
12
  def fetch_stock_data(ticker, start_date, end_date):
13
  stock_data = yf.download(ticker, start=start_date, end=end_date)
14
  stock_data.reset_index(inplace=True)
 
46
  return model
47
 
48
  # Step 5: Predict future stock prices
49
+ def predict_future(model, last_data, scaler, steps=90):
50
  predictions = []
51
  input_data = last_data[-60:].reshape(1, -1)
52
+
53
+ # Generate predictions for the future
54
  for _ in range(steps):
55
+ input_reshaped = input_data.reshape(1, 60, 1)
56
+ predicted_price = model.predict(input_reshaped, verbose=0)
57
  predictions.append(predicted_price[0][0])
58
  input_data = np.append(input_data[0][1:], predicted_price)
59
 
60
+ predicted_prices = np.array(predictions).reshape(-1, 1)
61
+ predicted_prices = scaler.inverse_transform(predicted_prices) # Reverse scaling
62
+ return predicted_prices
63
 
64
  # Step 6: Plot historical and predicted stock prices
65
+ def plot_predictions(data, predicted_prices):
66
+ last_60_days = data['Close'][-60:].values
 
 
67
 
68
+ # Create a plot
69
+ plt.figure(figsize=(14, 6))
70
 
71
+ # Plot historical prices
72
  plt.plot(data['Date'], data['Close'], label="Historical Prices", color='blue')
73
 
74
+ # Future dates for predicted prices
75
+ future_dates = pd.date_range(start=data['Date'].iloc[-1], periods=len(predicted_prices) + 1, freq='D')[1:]
76
+
77
+ # Plot predicted prices
78
  plt.plot(future_dates, predicted_prices, label="Predicted Prices", color='orange')
79
 
80
  plt.title("Stock Price Prediction")
81
  plt.xlabel("Date")
82
+ plt.ylabel("Stock Price (USD)")
83
  plt.legend()
84
  plt.grid()
85
 
86
+ # Save the plot to a file for Gradio to display
87
  plt.savefig("stock_prediction.png")
88
+ plt.close()
89
 
90
+ return "stock_prediction.png" # Return the file path for Gradio
91
 
92
+ # Step 7: Gradio interface function
93
  def stock_prediction_app(ticker, start_date_str, end_date_str):
94
  # Convert date strings to datetime objects
95
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
 
105
  model = build_model((60, 1))
106
  model = train_model(model, scaled_data)
107
 
108
+ # Predict future stock prices for the next 90 days
109
+ predicted_prices = predict_future(model, scaled_data, scaler)
110
 
111
+ # Plot historical and predicted stock prices
112
+ plot_path = plot_predictions(data, predicted_prices)
113
 
114
+ return plot_path # Return the plot file path
115
 
116
+ # Step 8: Gradio UI setup
117
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
118
 
119
+ # Define the Gradio interface
120
  ui = gr.Interface(
121
  fn=stock_prediction_app,
122
  inputs=[
 
124
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
125
  gr.Textbox(label="End Date (YYYY-MM-DD)")
126
  ],
127
+ outputs=gr.Image(type="filepath"), # Return the image file path for the plot
128
  title="Stock Prediction App",
129
  description="Select a stock ticker and date range to predict future prices."
130
  )
131
 
132
  # Launch the app
133
  ui.launch()
134
+