dummydj2633 commited on
Commit
14a7e0f
·
verified ·
1 Parent(s): cca870e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -76
app.py CHANGED
@@ -1,91 +1,91 @@
1
- # import streamlit as st
2
- # import yfinance as yf
3
- # import pandas as pd
4
- # import matplotlib.pyplot as plt
5
- # import numpy as np
6
- # from sklearn.preprocessing import MinMaxScaler
7
- # from tensorflow.keras.models import Sequential
8
- # from tensorflow.keras.layers import LSTM, Dense
9
 
10
- # def fetch_stock_data(symbol):
11
- # stock = yf.Ticker(symbol)
12
- # hist = stock.history(period="6mo")
13
- # return hist
14
 
15
- # def prepare_data(data, time_steps=10):
16
- # scaler = MinMaxScaler(feature_range=(0,1))
17
- # scaled_data = scaler.fit_transform(data[['Close']].values)
18
 
19
- # X, y = [], []
20
- # for i in range(time_steps, len(scaled_data)):
21
- # X.append(scaled_data[i-time_steps:i, 0])
22
- # y.append(scaled_data[i, 0])
23
 
24
- # X, y = np.array(X), np.array(y)
25
- # X = np.reshape(X, (X.shape[0], X.shape[1], 1))
26
- # return X, y, scaler
27
 
28
- # def build_lstm_model():
29
- # model = Sequential([
30
- # LSTM(50, return_sequences=True, input_shape=(10,1)),
31
- # LSTM(50, return_sequences=False),
32
- # Dense(25),
33
- # Dense(1) ])
34
- # model.compile(optimizer='adam', loss='mean_squared_error')
35
- # return model
36
 
37
- # def predict_prices(data, days=7):
38
- # X, y, scaler = prepare_data(data)
39
- # model = build_lstm_model()
40
- # model.fit(X, y, epochs=10, batch_size=16, verbose=0)
41
 
42
- # future_input = X[-1].reshape(1, 10, 1)
43
- # predictions = []
44
- # for _ in range(days):
45
- # pred = model.predict(future_input)[0, 0]
46
- # predictions.append(pred)
47
- # future_input = np.append(future_input[:, 1:, :], [[[pred]]], axis=1)
48
- # return np.array(predictions).reshape(-1, 1), scaler
49
 
50
- # def main():
51
- # st.title("Stock Data Viewer with LSTM Prediction")
52
 
53
- # symbol = st.text_input("Enter Stock Symbol (e.g., AAPL, TSLA):")
54
 
55
- # if symbol:
56
- # try:
57
- # data = fetch_stock_data(symbol)
58
- # if data.empty:
59
- # st.error("Invalid stock symbol or no data available.")
60
- # else:
61
- # st.subheader("Stock Price History")
62
- # st.write(data)
63
 
64
- # st.subheader("Stock Price Chart")
65
- # fig, ax = plt.subplots()
66
- # ax.plot(data.index, data['Close'], label='Close Price')
67
- # ax.set_xlabel("Date")
68
- # ax.set_ylabel("Price (USD)")
69
- # ax.legend()
70
- # st.pyplot(fig)
71
 
72
- # days_to_predict = st.slider("Days to Predict", 1, 30, 7)
73
- # predictions, scaler = predict_prices(data, days_to_predict)
74
- # predicted_prices = scaler.inverse_transform(predictions)
75
 
76
- # st.subheader("Predicted Stock Prices")
77
- # fig, ax = plt.subplots()
78
- # ax.plot(data.index, data['Close'], label='Historical Close Price')
79
- # ax.plot(pd.date_range(data.index[-1], periods=days_to_predict+1, freq='D')[1:], predicted_prices, linestyle='dashed', label='Predicted Price')
80
- # ax.set_xlabel("Date")
81
- # ax.set_ylabel("Price (USD)")
82
- # ax.legend()
83
- # st.pyplot(fig)
84
 
85
- # csv = data.to_csv().encode('utf-8')
86
- # st.download_button("Download CSV", csv, f"{symbol}_data.csv", "text/csv")
87
- # except Exception as e:
88
- # st.error(f"An error occurred: {e}")
89
 
90
- # if __name__ == "main":
91
- # main()
 
1
+ import streamlit as st
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ from sklearn.preprocessing import MinMaxScaler
7
+ from tensorflow.keras.models import Sequential
8
+ from tensorflow.keras.layers import LSTM, Dense
9
 
10
+ def fetch_stock_data(symbol):
11
+ stock = yf.Ticker(symbol)
12
+ hist = stock.history(period="6mo")
13
+ return hist
14
 
15
+ def prepare_data(data, time_steps=10):
16
+ scaler = MinMaxScaler(feature_range=(0,1))
17
+ scaled_data = scaler.fit_transform(data[['Close']].values)
18
 
19
+ X, y = [], []
20
+ for i in range(time_steps, len(scaled_data)):
21
+ X.append(scaled_data[i-time_steps:i, 0])
22
+ y.append(scaled_data[i, 0])
23
 
24
+ X, y = np.array(X), np.array(y)
25
+ X = np.reshape(X, (X.shape[0], X.shape[1], 1))
26
+ return X, y, scaler
27
 
28
+ def build_lstm_model():
29
+ model = Sequential([
30
+ LSTM(50, return_sequences=True, input_shape=(10,1)),
31
+ LSTM(50, return_sequences=False),
32
+ Dense(25),
33
+ Dense(1) ])
34
+ model.compile(optimizer='adam', loss='mean_squared_error')
35
+ return model
36
 
37
+ def predict_prices(data, days=7):
38
+ X, y, scaler = prepare_data(data)
39
+ model = build_lstm_model()
40
+ model.fit(X, y, epochs=10, batch_size=16, verbose=0)
41
 
42
+ future_input = X[-1].reshape(1, 10, 1)
43
+ predictions = []
44
+ for _ in range(days):
45
+ pred = model.predict(future_input)[0, 0]
46
+ predictions.append(pred)
47
+ future_input = np.append(future_input[:, 1:, :], [[[pred]]], axis=1)
48
+ return np.array(predictions).reshape(-1, 1), scaler
49
 
50
+ def main():
51
+ st.title("Stock Data Viewer with LSTM Prediction")
52
 
53
+ symbol = st.text_input("Enter Stock Symbol (e.g., AAPL, TSLA):")
54
 
55
+ if symbol:
56
+ try:
57
+ data = fetch_stock_data(symbol)
58
+ if data.empty:
59
+ st.error("Invalid stock symbol or no data available.")
60
+ else:
61
+ st.subheader("Stock Price History")
62
+ st.write(data)
63
 
64
+ st.subheader("Stock Price Chart")
65
+ fig, ax = plt.subplots()
66
+ ax.plot(data.index, data['Close'], label='Close Price')
67
+ ax.set_xlabel("Date")
68
+ ax.set_ylabel("Price (USD)")
69
+ ax.legend()
70
+ st.pyplot(fig)
71
 
72
+ days_to_predict = st.slider("Days to Predict", 1, 30, 7)
73
+ predictions, scaler = predict_prices(data, days_to_predict)
74
+ predicted_prices = scaler.inverse_transform(predictions)
75
 
76
+ st.subheader("Predicted Stock Prices")
77
+ fig, ax = plt.subplots()
78
+ ax.plot(data.index, data['Close'], label='Historical Close Price')
79
+ ax.plot(pd.date_range(data.index[-1], periods=days_to_predict+1, freq='D')[1:], predicted_prices, linestyle='dashed', label='Predicted Price')
80
+ ax.set_xlabel("Date")
81
+ ax.set_ylabel("Price (USD)")
82
+ ax.legend()
83
+ st.pyplot(fig)
84
 
85
+ csv = data.to_csv().encode('utf-8')
86
+ st.download_button("Download CSV", csv, f"{symbol}_data.csv", "text/csv")
87
+ except Exception as e:
88
+ st.error(f"An error occurred: {e}")
89
 
90
+ if __name__ == "main":
91
+ main()