SailajaS commited on
Commit
63a496e
·
verified ·
1 Parent(s): 480b75f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -27
app.py CHANGED
@@ -6,7 +6,7 @@ import tensorflow as tf
6
  from sklearn.preprocessing import MinMaxScaler
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
9
- from datetime import datetime
10
 
11
  # Step 1: Fetch stock data from yfinance
12
  def fetch_stock_data(ticker, start_date, end_date):
@@ -45,12 +45,12 @@ def train_model(model, train_data, epochs=5):
45
  model.fit(X_train, y_train, epochs=epochs, batch_size=32, verbose=0)
46
  return model
47
 
48
- # Step 5: Predict future stock prices
49
- def predict_future(model, last_data, scaler, steps=90):
50
  predictions = []
51
  input_data = last_data[-60:].reshape(1, -1)
52
 
53
- # Generate predictions for the future
54
  for _ in range(steps):
55
  input_reshaped = input_data.reshape(1, 60, 1)
56
  predicted_price = model.predict(input_reshaped, verbose=0)
@@ -63,55 +63,57 @@ def predict_future(model, last_data, scaler, steps=90):
63
 
64
  # Step 6: Plot historical and predicted stock prices
65
  def plot_predictions(data, predicted_prices):
 
66
  last_60_days = data['Close'][-60:].values
67
 
68
- # Create a plot
69
  plt.figure(figsize=(14, 6))
70
 
71
- # Plot historical prices
72
  plt.plot(data['Date'], data['Close'], label="Historical Prices", color='blue')
73
 
74
- # Future dates for predicted prices
75
  future_dates = pd.date_range(start=data['Date'].iloc[-1], periods=len(predicted_prices) + 1, freq='D')[1:]
76
-
77
- # Plot predicted prices
78
  plt.plot(future_dates, predicted_prices, label="Predicted Prices", color='orange')
79
 
80
- plt.title("Stock Price Prediction")
 
81
  plt.xlabel("Date")
82
  plt.ylabel("Stock Price (USD)")
83
  plt.legend()
84
- plt.grid()
85
 
86
- # Save the plot to a file for Gradio to display
87
  plt.savefig("stock_prediction.png")
88
  plt.close()
89
-
90
- return "stock_prediction.png" # Return the file path for Gradio
91
 
92
  # Step 7: Gradio interface function
93
  def stock_prediction_app(ticker, start_date_str, end_date_str):
94
- # Convert date strings to datetime objects
95
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
96
  end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
97
 
98
- # Fetch stock data
99
  data = fetch_stock_data(ticker, start_date, end_date)
100
 
101
- # Prepare data for the LSTM model
102
  scaled_data, scaler = prepare_data(data)
103
 
104
  # Build and train the LSTM model
105
  model = build_model((60, 1))
106
  model = train_model(model, scaled_data)
107
 
108
- # Predict future stock prices for the next 90 days
109
- predicted_prices = predict_future(model, scaled_data, scaler)
110
 
111
- # Plot historical and predicted stock prices
112
- plot_path = plot_predictions(data, predicted_prices)
113
 
114
- return plot_path # Return the plot file path
115
 
116
  # Step 8: Gradio UI setup
117
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
@@ -124,11 +126,13 @@ ui = gr.Interface(
124
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
125
  gr.Textbox(label="End Date (YYYY-MM-DD)")
126
  ],
127
- outputs=gr.Image(type="filepath"), # Return the image file path for the plot
128
- title="Stock Prediction App",
129
- description="Select a stock ticker and date range to predict future prices."
 
 
 
130
  )
131
 
132
- # Launch the app
133
  ui.launch()
134
-
 
6
  from sklearn.preprocessing import MinMaxScaler
7
  import matplotlib.pyplot as plt
8
  import gradio as gr
9
+ from datetime import datetime, timedelta
10
 
11
  # Step 1: Fetch stock data from yfinance
12
  def fetch_stock_data(ticker, start_date, end_date):
 
45
  model.fit(X_train, y_train, epochs=epochs, batch_size=32, verbose=0)
46
  return model
47
 
48
+ # Step 5: Predict stock prices
49
+ def predict_future(model, last_data, scaler, steps=1):
50
  predictions = []
51
  input_data = last_data[-60:].reshape(1, -1)
52
 
53
+ # Generate predictions for the specified number of future days
54
  for _ in range(steps):
55
  input_reshaped = input_data.reshape(1, 60, 1)
56
  predicted_price = model.predict(input_reshaped, verbose=0)
 
63
 
64
  # Step 6: Plot historical and predicted stock prices
65
  def plot_predictions(data, predicted_prices):
66
+ # Fetch the last 60 days of data to plot before the prediction starts
67
  last_60_days = data['Close'][-60:].values
68
 
69
+ # Create a figure for the plot
70
  plt.figure(figsize=(14, 6))
71
 
72
+ # Plot historical stock prices
73
  plt.plot(data['Date'], data['Close'], label="Historical Prices", color='blue')
74
 
75
+ # Generate future dates for predicted prices
76
  future_dates = pd.date_range(start=data['Date'].iloc[-1], periods=len(predicted_prices) + 1, freq='D')[1:]
77
+
78
+ # Plot predicted stock prices
79
  plt.plot(future_dates, predicted_prices, label="Predicted Prices", color='orange')
80
 
81
+ # Adding labels and title to the graph
82
+ plt.title("Stock Price Prediction for Tomorrow")
83
  plt.xlabel("Date")
84
  plt.ylabel("Stock Price (USD)")
85
  plt.legend()
86
+ plt.grid(True)
87
 
88
+ # Save the plot as an image file for Gradio to display
89
  plt.savefig("stock_prediction.png")
90
  plt.close()
91
+
92
+ return "stock_prediction.png" # Return the path to the saved image
93
 
94
  # Step 7: Gradio interface function
95
  def stock_prediction_app(ticker, start_date_str, end_date_str):
96
+ # Convert input strings to datetime objects
97
  start_date = datetime.strptime(start_date_str, "%Y-%m-%d").date()
98
  end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
99
 
100
+ # Fetch stock data from yfinance
101
  data = fetch_stock_data(ticker, start_date, end_date)
102
 
103
+ # Prepare data for LSTM model
104
  scaled_data, scaler = prepare_data(data)
105
 
106
  # Build and train the LSTM model
107
  model = build_model((60, 1))
108
  model = train_model(model, scaled_data)
109
 
110
+ # Predict stock price for tomorrow (1 day)
111
+ predicted_price = predict_future(model, scaled_data, scaler, steps=1)
112
 
113
+ # Generate and return the plot with historical and predicted prices
114
+ plot_path = plot_predictions(data, predicted_price)
115
 
116
+ return plot_path, predicted_price[0][0] # Return the path of the plot image and tomorrow's predicted price
117
 
118
  # Step 8: Gradio UI setup
119
  tickers = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "META", "NFLX", "NVDA", "BABA", "BA"]
 
126
  gr.Textbox(label="Start Date (YYYY-MM-DD)"),
127
  gr.Textbox(label="End Date (YYYY-MM-DD)")
128
  ],
129
+ outputs=[
130
+ gr.Image(type="filepath"), # Return the file path for the generated graph
131
+ gr.Number(label="Predicted Price for Tomorrow (USD)")
132
+ ],
133
+ title="Stock Price Prediction App",
134
+ description="Predict future stock price for tomorrow based on historical data."
135
  )
136
 
137
+ # Launch the Gradio app
138
  ui.launch()