| import pandas as pd |
| from datetime import datetime |
| from datetime import timedelta |
| import numpy as np |
| import statsmodels.api as sm |
|
|
| import plotly.express as px |
| import plotly.graph_objects as go |
|
|
| import warnings |
# Suppress every warning for the whole process. This filter is global, so it
# also hides warnings raised by pandas, statsmodels and plotly.
# NOTE(review): presumably meant to quiet SARIMAX fitting noise -- consider
# narrowing it to the specific warning categories.
warnings.filterwarnings(action="ignore")

# Semicolon-separated daily share-price table, loaded once at import time.
df = pd.read_csv(filepath_or_buffer='us-shareprices-daily.csv', sep=';')
|
|
def get_model_accuracy(data, ticker_symbol, *, train_size=0.85, order=(1, 1, 1)):
    """Backtest a SARIMAX model on one ticker's closing prices (walk-forward).

    The ticker's rows are split by position. The first ``train_size`` fraction
    seeds the model history, and the remaining rows form the test set. For each
    test point, a fresh SARIMAX model is fit on every value seen so far and a
    one-step-ahead forecast is recorded. The true test value is then appended
    to the history before the next step.

    Args:
        data: DataFrame with at least ``'Ticker'`` and ``'Close'`` columns.
        ticker_symbol: Value of ``'Ticker'`` selecting the rows to evaluate.
        train_size: Fraction of the ticker's rows used as the initial history
            (default 0.85, the original hard-coded split).
        order: SARIMAX ``(p, d, q)`` order (default ``(1, 1, 1)``).

    Returns:
        The string ``'Testing Mean Squared Error is <mse>'``.

    Raises:
        ValueError: If the split leaves the training or the test portion
            empty (e.g. an unknown ticker or too few rows).
    """
    stock_data = data[data['Ticker'] == ticker_symbol]
    # NOTE(review): rows are used in their existing order, which assumes they
    # are already sorted by date for each ticker -- confirm against the CSV.
    closes = stock_data['Close'].values
    split = int(len(closes) * train_size)
    training_data = closes[:split]
    test_data = closes[split:]
    if len(training_data) == 0 or len(test_data) == 0:
        raise ValueError(
            f"Not enough data for ticker {ticker_symbol!r}: "
            f"{len(training_data)} training and {len(test_data)} test rows"
        )

    history = list(training_data)
    model_predictions = []
    for true_test_value in test_data:
        # Re-fit on the full, growing history at every step (walk-forward).
        model_fit = sm.tsa.statespace.SARIMAX(history, order=order).fit(disp=0)
        model_predictions.append(model_fit.forecast()[0])
        history.append(true_test_value)

    # sklearn's mean_squared_error was never imported; compute MSE with NumPy.
    errors = np.asarray(test_data, dtype=float) - np.asarray(model_predictions, dtype=float)
    mse_error = float(np.mean(errors ** 2))
    return 'Testing Mean Squared Error is {}'.format(mse_error)
|
|
|
|
def arima_chart(tickers, *, csv_path='data_and_sp500.csv', order=(21, 1, 7),
                steps=7, begin_date='2021-10-22'):
    """Plot price history for ``tickers`` and overlay a SARIMAX forecast per ticker.

    Args:
        tickers: List of column names in the CSV to plot (a single string is
            also accepted and treated as a one-element list).
        csv_path: CSV with a ``'Date'`` column plus one column per ticker
            (default ``'data_and_sp500.csv'``, the original hard-coded file).
        order: SARIMAX ``(p, d, q)`` order (default ``(21, 1, 7)``).
        steps: Number of forecast steps per ticker (default 7).
        begin_date: ``'YYYY-MM-DD'`` date at which the first forecast step is
            drawn. The x-axis is zoomed to 120 days before it through 10 days
            after it.
            NOTE(review): the default is the original hard-coded date,
            presumably the day after the CSV's last row -- confirm.

    Returns:
        The plotly figure, which is also rendered with ``st.plotly_chart``.
    """
    if isinstance(tickers, str):
        tickers = [tickers]
    tickers = list(tickers)

    df = pd.read_csv(csv_path)
    df = df[['Date'] + tickers]
    # Plot only the ticker columns. Passing df.columns would also draw the
    # 'Date' column as a series against itself.
    fig = px.line(df, x='Date', y=tickers)

    # Parsed before the loop so the axis range below also works for an empty
    # ticker list (previously an UnboundLocalError).
    start = datetime.strptime(begin_date, '%Y-%m-%d')
    for ticker in tickers:
        model_fit = sm.tsa.statespace.SARIMAX(df[ticker], order=order).fit(disp=-1)
        # NOTE(review): forecast() returns point forecasts only; alpha is
        # presumably ignored here (intervals come from get_forecast) -- confirm.
        forecast = model_fit.forecast(steps, alpha=0.05)
        # Forecast step k is drawn k calendar days after `start`. The old
        # `i - 1258` offset assumed the CSV had exactly 1258 rows, because the
        # forecast index continues the series' RangeIndex.
        # NOTE(review): steps are presumably trading days, yet they are spaced
        # as calendar days here.
        forecast_dates = [start + timedelta(days=k) for k in range(len(forecast))]
        fig.add_trace(go.Scatter(x=forecast_dates, y=forecast.to_list(),
                                 mode='lines',
                                 name='{} forecast'.format(ticker)))

    fig.update_xaxes(range=[start - timedelta(days=120), start + timedelta(days=10)])
    # NOTE(review): `st` (streamlit) is not imported anywhere in the visible
    # file; this line raises NameError unless `import streamlit as st` exists.
    st.plotly_chart(fig, use_container_width=True)
    return fig
|
|