Spaces:
Running
Running
| import logging | |
| import os | |
| import dotenv | |
| import matplotlib.pyplot as plt | |
| import matplotlib.dates as mdates | |
| from matplotlib.axes import Axes | |
| import numpy as np | |
| import pandas as pd | |
| from ta.volume import volume_weighted_average_price | |
| from ta.momentum import RSIIndicator, StochasticOscillator | |
| from ta.trend import MACD | |
| class TechnicalAnalysis(): | |
| def __init__( | |
| self, | |
| ticker: str, | |
| df_hist: pd.DataFrame, | |
| df_past=None, | |
| df_fcst=None, | |
| plot_ta:bool=True, | |
| savefig:bool=False, | |
| debug=False): | |
| # input arguments | |
| """ | |
| Initialize TechnicalAnalysis object. | |
| Args: | |
| ticker : str | |
| stock ticker to analyze | |
| df_hist: pd.DataFrame | |
| historical price data for ticker | |
| df_past: pd.DataFrame, optional, default: None | |
| Closeing price of the ticker for the past few days | |
| df_fcst: pd.DataFrame, optional, default: None | |
| Forecasted closing price and relative returns nextf few days | |
| plot_ta : bool, optional, default: True | |
| whether to generate plots of technical analysis metrics. Plot will be created under plots/{ticker}.png | |
| debug : bool, optional, default: False | |
| whether run in debug mode, so that logging should be produced at debug level | |
| """ | |
| # set up logging | |
| if debug: | |
| self.logger_level = logging.DEBUG | |
| else: | |
| self.logger_level = logging.INFO | |
| self.logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=self.logger_level) # filename='TechnicalAnalysis.log', | |
| # input arguments | |
| self.ticker = ticker | |
| self.df_hist = df_hist | |
| self.df_past = df_past | |
| self.df_fcst = df_fcst | |
| self.plot_ta = plot_ta | |
| self.savefig = savefig | |
| # done initializing | |
| self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}') | |
| def run( | |
| self | |
| ) -> None: | |
| """ | |
| Main entry point for the TechnicalAnalysis object. | |
| This method: | |
| - computes the technical analysis metrics | |
| - plots the price and TA metrics. | |
| """ | |
| df = self.df_hist | |
| # add the features based on technical analysis | |
| if df.shape[0] > 0: | |
| df = self.tech_analysis(df) | |
| # Merge with forecast data | |
| df_merged = self.merge_hist_with_forecast(df, self.df_past, self.df_fcst) | |
| # plot the results | |
| if self.plot_ta: | |
| os.makedirs('plots', exist_ok=True) | |
| fig = self.plot_stock_metrics( | |
| df_merged, | |
| datasets={ | |
| 'Volume': ['Volume'], | |
| 'Indices': ['RSI', 'StochOsc'], | |
| 'Trend': ['MACD', 'MACDsig', 'MACDdif'], | |
| 'Prices': ['Close', 'VWAP'] # 'High','Low', | |
| }, | |
| savefig=self.savefig | |
| ) | |
| else: | |
| fig = None | |
| else: | |
| if self.plot_ta: | |
| fig = self.get_fetcherror_fig(message='failed fetching data') | |
| else: | |
| fig = None | |
| return df, fig | |
| def tech_analysis( | |
| self, | |
| df: pd.DataFrame | |
| ) -> pd.DataFrame: | |
| """ | |
| Calculates technical analysis indicators for the fetched stock price data. | |
| This method takes the fetched stock price data and calculates several | |
| technical analysis indicators. The following indicators are calculated: | |
| - Additional Price Indicators: | |
| - Volume-Weighted Average Price (VWAP) | |
| - Momentum Indicators: | |
| - Relative Strength Index (RSI) | |
| - Stochastic Oscillator | |
| - Trend Indicators: | |
| - Moving Average Convergence Divergence (MACD) | |
| The calculated indicators are added to the DataFrame as new columns. | |
| Args: | |
| df: pd.DataFrame | |
| The DataFrame containing the fetched stock price data. | |
| Returns: | |
| pd.DataFrame | |
| The DataFrame with the calculated technical analysis indicators. | |
| """ | |
| # Price Indicators | |
| # Volume-Weighted Average Price (VWAP) | |
| # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap | |
| df['VWAP'] = volume_weighted_average_price( | |
| high=df['High'], | |
| low=df['Low'], | |
| close=df['Close'], | |
| volume=df['Volume'], | |
| ) | |
| # Indices | |
| # RSI: | |
| # https://www.investopedia.com/terms/r/rsi.asp | |
| df['RSI'] = RSIIndicator( | |
| df['Close'], | |
| window=14).rsi() | |
| # Stochastic Oscillator: | |
| # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full | |
| df['StochOsc'] = StochasticOscillator( | |
| df['High'], | |
| df['Low'], | |
| df['Close'], | |
| window=14).stoch() | |
| # Trend signals | |
| # Moving Average Convergence Divergence (MACD): | |
| # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator | |
| macd = MACD( | |
| df['Close'], | |
| window_slow=26, | |
| window_fast=12, | |
| window_sign=9) | |
| df['MACD'] = macd.macd() | |
| df['MACDsig'] = macd.macd_signal() | |
| df['MACDdif'] = macd.macd_diff() | |
| return df | |
| def merge_hist_with_forecast(self, df_hist: pd.DataFrame, df_past: pd.DataFrame | None, df_fcst: pd.DataFrame | None) -> pd.DataFrame: | |
| # make sure we are merging the right thing | |
| """ | |
| Merge historical data with forecast data. If forecast data is available, merge it with historical data based on date. | |
| If forecast data is not available, return the historical data as is. | |
| Args: | |
| df_hist: pd.DataFrame | |
| Historical data | |
| df_past: pd.DataFrame | |
| Recent data used for comparison | |
| df_fcst: pd.DataFrame or None | |
| Forecast data | |
| Returns: | |
| df_merged: pd.DataFrame | |
| Merged data | |
| """ | |
| if df_fcst is not None: | |
| # Make sure that the previous hist close price is matching to that of the past close price | |
| assert df_hist.Close.iloc[-2] == df_past.Close.iloc[-2] | |
| df_hist.reset_index(inplace=True) | |
| # in case there are overlapping dates, make sure to remove them | |
| df_fcst = df_fcst.loc[~df_fcst["Date"].isin(df_hist["Date"]), ["Date", "Close"]] | |
| date_diff = df_fcst.Date.iloc[0] - df_hist.Date.iloc[-1] | |
| if date_diff > pd.Timedelta('3 days'): | |
| self.logger.warning(f'Date diff between the first forecast and the last hist is {date_diff}') | |
| df_merged = pd.concat([df_hist, df_fcst], ignore_index=True) | |
| df_merged.set_index("Date", inplace=True) | |
| else: | |
| df_merged = df_hist | |
| return df_merged | |
| def plot_stock_metrics( | |
| self, | |
| df, | |
| datasets, | |
| savefig=False | |
| ) -> None: | |
| """ | |
| Plots the given stock metrics datasets as subplots. | |
| This method takes in a DataFrame and a dictionary of datasets, where | |
| each key is a dataset name and the value is a list of column names. | |
| The method creates a figure with subplots for each dataset and plots | |
| the corresponding columns of the DataFrame. | |
| The figure is then saved to a file in the 'plots' directory in png format | |
| with the ticker symbol as the filename. | |
| Args: | |
| df (pd.DataFrame) | |
| The DataFrame to plot | |
| datasets (dict) | |
| A dictionary of datasets, where each key is a dataset name and | |
| the value is a list of column names to be plotted | |
| savefig (bool) | |
| Whether to save the figure to a file | |
| """ | |
| numax = len(datasets) | |
| fig, axes = plt.subplots( | |
| nrows=numax, | |
| ncols=1, | |
| figsize=(6, 3*numax)) | |
| for i, ax in enumerate(axes.flat): | |
| dataset = list(datasets.keys())[i] | |
| colstoplot = datasets[dataset] | |
| self.plot_stock_metrics_ax( | |
| ax, | |
| dataset, | |
| df, | |
| colstoplot) | |
| plt.tight_layout() | |
| if savefig: | |
| rootdir = os.path.dirname(os.path.dirname(__file__)) | |
| fname = os.path.join(rootdir, 'plots', f'{self.ticker}.png') | |
| plt.savefig(fname) | |
| self.logger.info(f'Saved figure into: {fname}') | |
| plt.close() | |
| fig = None | |
| return fig | |
| def plot_stock_metrics_ax( | |
| self, | |
| ax:Axes, | |
| dataset:str, | |
| df:pd.DataFrame, | |
| colstoplot:list) -> None: | |
| """ | |
| Plots specified stock metrics on the provided Axes object. | |
| This function takes in an Axes object and plots the specified columns from | |
| the DataFrame `df` on it. It formats the x-axis with major ticks set to | |
| every Monday and minor ticks set to every day. The plot includes a title, | |
| x and y labels, and optionally a legend if more than one column is plotted. | |
| Additional shaded regions are added for certain datasets. | |
| Args: | |
| ax (matplotlib.axes.Axes): The flattened axes to plot on. | |
| dataset (str): The name of the dataset, used for the title and y-label. | |
| df (pd.DataFrame): The DataFrame containing the data to plot. | |
| colstoplot (list): A list of column names to plot from the DataFrame. | |
| Note: | |
| - If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100] | |
| and a shaded region between y=30 and y=70 is added. | |
| - If `dataset` is 'Price' or 'Prices', a shaded region between 'Low' and | |
| 'High' columns is added. | |
| """ | |
| self.logger.info(f'plotting {colstoplot} in {dataset}') | |
| colorcycle = ['black', 'blue', 'green', 'orange'] | |
| for i, col in enumerate(colstoplot): | |
| ax.plot( | |
| df.index, | |
| df[col], | |
| color=colorcycle[i], | |
| label=col, | |
| linewidth=2) | |
| if dataset in ['Index', 'Indices']: | |
| ax.set_ylim([0, 100]) | |
| # Add a transparent shaded region between y=30 and y=70 | |
| ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3, label='30-70 Range') | |
| elif dataset in ['Price', 'Prices']: | |
| # Add a transparent shaded region daily lows and highs | |
| ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3, label='Price Range') | |
| # extract the Price rows for which 'High's are NaN | |
| nanind = np.where(df.High.isna()) | |
| df_fcst = df['Close'].iloc[nanind] | |
| if df_fcst.shape[0] > 0: | |
| # append the last day before the forecasts | |
| nonnanind = np.where(df.High.notna()) | |
| df_now = df['Close'].iloc[[nonnanind[0][-1]]] | |
| df_now_fcst = pd.concat([df_now, df_fcst]) | |
| # connect the last available day and forecasts with a red line | |
| ax.plot(df_now_fcst.index, df_now_fcst, color='red') | |
| # plot only the forecasts with markers | |
| ax.plot(df_fcst.index, df_fcst, color='red', marker='*', label='Forecast') | |
| # else: | |
| # # plot a transparent line across the full df.index just to make sure that x-axis limits are identical for all panels | |
| # ax.plot(df.index, df[col], color='gray', alpha=0.0) | |
| # # ax.fill_between(df.index, 30, 70, color='gray', alpha=0.0) | |
| ax.set_xlim([df.index.min()-pd.Timedelta(days=5), df.index.max()+pd.Timedelta(days=5)]) | |
| # Format major ticks with year | |
| # Set major ticks (every Monday with labels) | |
| ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO)) | |
| # ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d')) | |
| ax.xaxis.set_major_formatter(mdates.DateFormatter('%m.%d')) | |
| # Set minor ticks (every day, but without labels) | |
| ax.xaxis.set_minor_locator(mdates.DayLocator()) | |
| plt.setp(ax.get_xticklabels(), rotation=90, ha='center') | |
| ax.set_ylabel(dataset) | |
| ax.grid(True, linestyle='--', alpha=0.7) | |
| ax.set_title(dataset) | |
| # ax.set_xlabel('Date') | |
| if len(colstoplot) > 1: | |
| ax.legend(loc='upper left') | |
| def get_fetcherror_fig( | |
| self, | |
| message | |
| ) -> plt.Figure: | |
| """ | |
| Fetches images/plot_error.png, annotates it and returns it as a matplotlib.pyplot.Figure object | |
| Args: | |
| message (str): message to be annotated on the displayed image | |
| Returns: | |
| plt.Figure: figure object containing the annotated image | |
| """ | |
| fig, ax = plt.subplots( | |
| figsize=(5, 5) | |
| ) | |
| # Load and display the image | |
| parentdir = os.path.dirname(os.path.dirname(__file__)) | |
| fname = os.path.join(parentdir, 'images', 'plot_error.png') | |
| img = plt.imread(fname) | |
| ax.imshow(img) | |
| # Remove axes ticks and labels | |
| # ax.set_xticks([]) | |
| # ax.set_yticks([]) | |
| ax.axis('off') # removes both ticks and axes lines completely | |
| ax.text(0.5, 0.05, message, fontsize=20, ha='center', va='center', transform=ax.transAxes) | |
| return fig | |
if __name__ == '__main__':
    # Manual smoke test: fetch ~12 weeks of AAPL history and its forecast,
    # run the technical analysis, and save the resulting figure to disk.
    from src.fetch_forecast import FetchForecast
    from src.fetch_data import FetchData

    symbol = 'AAPL'
    dotenv.load_dotenv(dotenv.find_dotenv())
    history = FetchData(symbol, fetchperiodinweeks=12).run()
    recent, forecast = FetchForecast(symbol, history).run()
    analysis = TechnicalAnalysis(
        symbol,
        df_hist=history,
        df_past=recent,
        df_fcst=forecast,
        plot_ta=True,
        savefig=True,
        debug=False,
    )
    df, fig = analysis.run()