Spaces:
Runtime error
Runtime error
| import datetime | |
| from pydantic import BaseModel, Field | |
| from typing import Dict, List, Optional | |
| import yfinance as yf | |
| import plotly.graph_objs as go | |
| import plotly.express as px | |
| from prophet import Prophet | |
| from workcell.integrations.types import PlotlyPlot | |
class Input(BaseModel):
    # Request schema for stock_predictor: the Yahoo Finance symbol to forecast.
    # (Comments rather than a class docstring: pydantic would publish a
    # docstring as the model's schema description.)
    ticker: str = Field(
        "AAPL",
        description="A ticker value, like `AAPL`, etc...",
    )
def load_data(ticker, start=None, end=None):
    """Download daily adjusted-close prices for ``ticker`` from Yahoo Finance.

    e.g. ticker = 'AAPL'|'AMZN'|'GOOG'

    Args:
        ticker: Yahoo Finance symbol.
        start: First date to download. Defaults to 2022-01-01 (the original
            hard-coded window).
        end: End of the download window. Defaults to the current time.

    Returns:
        Series of adjusted close prices indexed by date, named 'Adj Close'.

    Raises:
        ValueError: if Yahoo Finance returned no rows (e.g. unknown ticker,
            network failure, or an empty date window).
    """
    if start is None:
        start = datetime.datetime(2022, 1, 1)
    if end is None:
        end = datetime.datetime.now()  # latest
    # auto_adjust=False keeps the 'Adj Close' column: recent yfinance releases
    # default auto_adjust to True, which removes it and made the lookup below
    # raise KeyError.
    data = yf.download(
        ticker, start=start, end=end, interval='1d', auto_adjust=False
    )
    # yfinance reports a bad symbol by returning an empty frame rather than
    # raising, which would otherwise surface later as an opaque Prophet error.
    if data is None or data.empty:
        raise ValueError(f"No price data returned for ticker {ticker!r}")
    close = data['Adj Close']
    # Recent yfinance versions may return (Price, Ticker) MultiIndex columns
    # even for one ticker, making this selection a one-column DataFrame.
    if getattr(close, 'ndim', 1) == 2:
        close = close.iloc[:, 0]
    # Normalise the name so downstream code can rely on 'Adj Close'.
    return close.rename('Adj Close')
def preprocess_data(df):
    """Reshape a date-indexed price series into the frame Prophet expects.

    The date index becomes column ``ds`` and the price becomes column ``y``.
    The price column is 'Adj Close' when present; otherwise it is the single
    remaining column (e.g. a column named after the ticker symbol).
    Timezone information is stripped from ``ds`` because Prophet rejects
    timezone-aware timestamps.

    Args:
        df: Series or DataFrame indexed by date, e.g. the output of load_data.

    Returns:
        A new DataFrame with ``ds`` and ``y`` columns (any other columns of
        ``df`` are carried over unchanged).

    Raises:
        ValueError: if no unique price column can be identified.
    """
    df_processed = df.reset_index()
    columns = list(df_processed.columns)
    # reset_index() puts the former index first; its name varies
    # ('Date', 'Datetime', 'index', ...), so prefer 'Date' but fall back.
    date_col = 'Date' if 'Date' in columns else columns[0]
    if 'Adj Close' in columns:
        value_col = 'Adj Close'
    else:
        remaining = [c for c in columns if c != date_col]
        if len(remaining) != 1:
            raise ValueError(
                "Cannot identify the price column; expected 'Adj Close' or a "
                f"single value column, got {columns}"
            )
        value_col = remaining[0]
    df_processed = df_processed.rename(columns={date_col: 'ds', value_col: 'y'})
    # Only datetime-with-timezone dtypes carry a .tz attribute.
    if getattr(df_processed['ds'].dtype, 'tz', None) is not None:
        df_processed['ds'] = df_processed['ds'].dt.tz_localize(None)
    return df_processed
def predict_data(df, periods=30, freq='D'):
    """Fit a Prophet model to ``df`` and forecast ``periods`` steps ahead.

    e.g. df = preprocess_data(df)

    Args:
        df: DataFrame with ``ds`` (timestamps) and ``y`` (values) columns,
            as produced by preprocess_data.
        periods: Number of future rows to forecast (default 30).
        freq: pandas frequency alias for the future rows. The default 'D'
            (calendar days) keeps the original behaviour; 'B' restricts the
            forecast to business days, matching when stock prices exist.

    Returns:
        DataFrame with columns ``ds``, ``yhat``, ``yhat_lower`` and
        ``yhat_upper`` for the rows produced by make_future_dataframe.
    """
    # init prophet model
    model = Prophet()
    # fit
    model.fit(df)
    # predict data
    future_prices = model.make_future_dataframe(periods=periods, freq=freq)
    forecast = model.predict(future_prices)
    # Return an independent frame rather than a column-slice view of
    # `forecast`, so callers can modify it without pandas copy warnings.
    return forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].copy()
def visualization(df_processed, df_forecast, ticker):
    """Build an interactive Plotly chart of observed prices and the forecast.

    Args:
        df_processed: Observed data with ``ds``/``y`` columns
            (see preprocess_data).
        df_forecast: Forecast with ``ds``, ``yhat``, ``yhat_lower`` and
            ``yhat_upper`` columns (see predict_data).
        ticker: Symbol shown in the chart title.

    Returns:
        A ``go.Figure`` with a range slider and range-selector buttons.
    """
    band_color = "#57b8ff"
    trace_forecast = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat"],
        mode='lines',
        name="Forecast",
    )
    # Draw the lower bound without fill, then fill the upper bound down to it
    # ('tonexty' fills to the previous trace), so exactly the
    # yhat_lower..yhat_upper band is shaded once. The old order
    # (yhat -> upper -> lower, both filled) shaded the upper half twice.
    trace_lower = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_lower"],
        mode='lines',
        line={"color": band_color},
        name="Lower uncertainty interval",
    )
    trace_upper = go.Scatter(
        x=df_forecast["ds"],
        y=df_forecast["yhat_upper"],
        mode='lines',
        fill="tonexty",
        line={"color": band_color},
        name="Higher uncertainty interval",
    )
    trace_actual = go.Scatter(
        x=df_processed["ds"],
        y=df_processed["y"],
        name="Data values",
    )
    # Forecast first and actual data last keep the same default colours as
    # before (plotly assigns default colours by trace position).
    data = [trace_forecast, trace_lower, trace_upper, trace_actual]
    # The title previously hard-coded "Repsol" regardless of the ticker.
    layout = go.Layout(
        title="Stock Price Forecast for: {}".format(ticker),
        xaxis_rangeslider_visible=True,
    )
    fig = go.Figure(data=data, layout=layout)
    fig.update_xaxes(
        rangeslider_visible=True,
        rangeselector=dict(
            buttons=[
                dict(count=1, label="1m", step="month", stepmode="backward"),
                dict(count=6, label="6m", step="month", stepmode="backward"),
                dict(count=1, label="YTD", step="year", stepmode="todate"),
                dict(count=1, label="1y", step="year", stepmode="backward"),
                dict(step="all"),
            ]
        ),
    )
    fig.update_layout(
        hovermode="x",
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
        ),
    )
    return fig
def stock_predictor(input: Input) -> PlotlyPlot:
    """Input ticker, predict stocks price in 30 days by prophet. Data from yahoo finance."""
    symbol = input.ticker
    # 1. Fetch the adjusted-close series and reshape it into Prophet's ds/y frame.
    history = preprocess_data(load_data(symbol))
    # 2. Fit Prophet and forecast ahead (predict_data's default horizon).
    forecast = predict_data(history)
    # 3. Render the chart, then 4. wrap it in the workcell output type.
    figure = visualization(history, forecast, symbol)
    return PlotlyPlot(data=figure)