"""Technical-analysis indicators and plots for a single stock ticker."""
import logging
import os

import dotenv
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.axes import Axes
import numpy as np
import pandas as pd
from ta.volume import volume_weighted_average_price
from ta.momentum import RSIIndicator, StochasticOscillator
from ta.trend import MACD


class TechnicalAnalysis:
    """
    Compute technical-analysis (TA) indicators for a ticker and plot them,
    optionally together with forecasted closing prices.
    """

    def __init__(
            self,
            ticker: str,
            df_hist: pd.DataFrame,
            df_past: pd.DataFrame | None = None,
            df_fcst: pd.DataFrame | None = None,
            plot_ta: bool = True,
            savefig: bool = False,
            debug: bool = False):
        """
        Initialize TechnicalAnalysis object.

        Args:
            ticker : str
                stock ticker to analyze
            df_hist: pd.DataFrame
                historical price data for the ticker; the columns 'High',
                'Low', 'Close' and 'Volume' are used
            df_past: pd.DataFrame, optional, default: None
                closing prices of the ticker for the past few days; required
                whenever `df_fcst` is given, to sanity-check the merge
            df_fcst: pd.DataFrame, optional, default: None
                forecasted closing prices and relative returns for the next
                few days; only its 'Date' and 'Close' columns are used
            plot_ta : bool, optional, default: True
                whether to generate plots of the technical analysis metrics
            savefig : bool, optional, default: False
                whether to save the plot as plots/{ticker}.png under the
                parent of this file's directory; a saved figure is closed
                and therefore not returned by `run`
            debug : bool, optional, default: False
                whether to log at DEBUG level instead of INFO
        """
        # set up logging
        self.logger_level = logging.DEBUG if debug else logging.INFO
        self.logger = logging.getLogger(__name__)
        # basicConfig is a no-op once the root logger is configured, so also
        # set this module's logger level for `debug` to always take effect.
        logging.basicConfig(level=self.logger_level)
        self.logger.setLevel(self.logger_level)

        # input arguments
        self.ticker = ticker
        self.df_hist = df_hist
        self.df_past = df_past
        self.df_fcst = df_fcst
        self.plot_ta = plot_ta
        self.savefig = savefig

        self.logger.info('Initialized TechnicalAnalysis object for ticker: %s', ticker)

    def run(self) -> tuple[pd.DataFrame, plt.Figure | None]:
        """
        Main entry point for the TechnicalAnalysis object.

        This method:
        - computes the technical analysis metrics
        - plots the price and TA metrics (if `plot_ta` is set).

        Returns:
            df: pd.DataFrame
                the historical data with the TA indicator columns added, or
                the unchanged input if it has no rows
            fig: plt.Figure or None
                the plot; an error placeholder figure if the input has no
                rows; None if plotting is disabled or the figure was saved
                (and closed)
        """
        df = self.df_hist

        if df.shape[0] == 0:
            # nothing to analyse: show a placeholder instead of the metrics
            fig = self.get_fetcherror_fig(message='failed fetching data') if self.plot_ta else None
            return df, fig

        # add the features based on technical analysis
        df = self.tech_analysis(df)

        # merge with forecast data (does not modify `df`)
        df_merged = self.merge_hist_with_forecast(df, self.df_past, self.df_fcst)

        fig = None
        if self.plot_ta:
            fig = self.plot_stock_metrics(
                df_merged,
                datasets={
                    'Volume': ['Volume'],
                    'Indices': ['RSI', 'StochOsc'],
                    'Trend': ['MACD', 'MACDsig', 'MACDdif'],
                    'Prices': ['Close', 'VWAP'],
                },
                savefig=self.savefig,
            )
        return df, fig

    def tech_analysis(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Calculates technical analysis indicators for the fetched stock price data.

        The following indicators are calculated:
        - Additional Price Indicators:
            - Volume-Weighted Average Price (VWAP)
        - Momentum Indicators:
            - Relative Strength Index (RSI, 14-period window)
            - Stochastic Oscillator (14-period window)
        - Trend Indicators:
            - Moving Average Convergence Divergence (MACD 12/26/9)

        The indicators are added IN PLACE to `df` as the new columns 'VWAP',
        'RSI', 'StochOsc', 'MACD', 'MACDsig' and 'MACDdif'.

        Args:
            df: pd.DataFrame
                price data with 'High', 'Low', 'Close' and 'Volume' columns.

        Returns:
            pd.DataFrame
                the same DataFrame object, with the indicator columns added.
        """
        # Price Indicators
        # Volume-Weighted Average Price (VWAP)
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
        df['VWAP'] = volume_weighted_average_price(
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            volume=df['Volume'],
        )

        # Indices
        # RSI:
        # https://www.investopedia.com/terms/r/rsi.asp
        df['RSI'] = RSIIndicator(df['Close'], window=14).rsi()

        # Stochastic Oscillator:
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
        df['StochOsc'] = StochasticOscillator(
            df['High'], df['Low'], df['Close'], window=14).stoch()

        # Trend signals
        # Moving Average Convergence Divergence (MACD):
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
        macd = MACD(df['Close'], window_slow=26, window_fast=12, window_sign=9)
        df['MACD'] = macd.macd()
        df['MACDsig'] = macd.macd_signal()
        df['MACDdif'] = macd.macd_diff()

        return df

    def merge_hist_with_forecast(
            self,
            df_hist: pd.DataFrame,
            df_past: pd.DataFrame | None,
            df_fcst: pd.DataFrame | None) -> pd.DataFrame:
        """
        Merge historical data with forecast data.

        If forecast data is available, its 'Date' and 'Close' columns are
        appended after the historical rows, skipping dates that already exist
        in the history. The forecast rows therefore have NaN in every other
        column (e.g. 'High'). If no forecast data is available, or none of
        its dates are new, the historical data is returned as is.

        The inputs are not modified.

        Args:
            df_hist: pd.DataFrame
                historical data, indexed by a 'Date' index
            df_past: pd.DataFrame or None
                recent data used for comparison; required if `df_fcst` is given
            df_fcst: pd.DataFrame or None
                forecast data with 'Date' and 'Close' columns

        Returns:
            df_merged: pd.DataFrame
                merged data indexed by 'Date' (or `df_hist` itself)

        Raises:
            ValueError: if `df_fcst` is given but `df_past` is missing, has
                fewer than two rows, or its previous close does not match
                that of `df_hist`.
        """
        if df_fcst is None:
            return df_hist

        # Make sure we are merging the right thing: the previous historical
        # close must match the previous close of the past data.
        if df_past is None:
            raise ValueError('df_past is required when df_fcst is given')
        if len(df_hist) < 2 or len(df_past) < 2:
            raise ValueError('df_hist and df_past need at least two rows to be compared')
        hist_prev = df_hist['Close'].iloc[-2]
        past_prev = df_past['Close'].iloc[-2]
        # tolerant float comparison instead of exact equality
        if not np.isclose(hist_prev, past_prev):
            raise ValueError(
                f'Previous close mismatch between hist ({hist_prev}) and past ({past_prev})')

        # work on a copy: an in-place reset would alter the caller's frame
        hist = df_hist.reset_index()

        # in case there are overlapping dates, make sure to remove them
        fcst = df_fcst.loc[~df_fcst['Date'].isin(hist['Date']), ['Date', 'Close']]
        if fcst.empty:
            self.logger.warning('No forecast dates beyond the historical data; nothing to merge')
            return df_hist

        date_diff = fcst['Date'].min() - hist['Date'].max()
        if date_diff > pd.Timedelta('3 days'):
            self.logger.warning(
                'Date diff between the first forecast and the last hist is %s', date_diff)

        df_merged = pd.concat([hist, fcst], ignore_index=True)
        return df_merged.set_index('Date')

    def plot_stock_metrics(
            self,
            df: pd.DataFrame,
            datasets: dict[str, list[str]],
            savefig: bool = False) -> plt.Figure | None:
        """
        Plots the given stock metrics datasets as vertically stacked subplots.

        Each key of `datasets` becomes one panel (in dict order) that shows the
        listed columns of `df`.

        Args:
            df (pd.DataFrame)
                The DataFrame to plot; its index is used as the x-axis (dates)
            datasets (dict)
                A dictionary of datasets, where each key is a dataset name and
                the value is a list of column names to be plotted
            savefig (bool)
                Whether to save the figure as plots/{ticker}.png under the
                parent of this file's directory. The directory is created if
                needed. A saved figure is closed and None is returned.

        Returns:
            plt.Figure or None: the figure, or None if it was saved.

        Raises:
            ValueError: if `datasets` is empty.
        """
        if not datasets:
            raise ValueError('datasets must contain at least one entry to plot')

        numax = len(datasets)
        # squeeze=False keeps `axes` a 2-D array even for a single panel
        fig, axes = plt.subplots(
            nrows=numax, ncols=1, figsize=(6, 3 * numax), squeeze=False)

        for ax, (dataset, colstoplot) in zip(axes.flat, datasets.items()):
            self.plot_stock_metrics_ax(ax, dataset, df, colstoplot)

        fig.tight_layout()

        if not savefig:
            return fig

        rootdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        plotdir = os.path.join(rootdir, 'plots')
        os.makedirs(plotdir, exist_ok=True)
        fname = os.path.join(plotdir, f'{self.ticker}.png')
        fig.savefig(fname)
        self.logger.info('Saved figure into: %s', fname)
        plt.close(fig)
        return None

    def plot_stock_metrics_ax(
            self,
            ax: Axes,
            dataset: str,
            df: pd.DataFrame,
            colstoplot: list) -> None:
        """
        Plots specified stock metrics on the provided Axes object.

        The columns `colstoplot` of `df` are drawn against `df.index`. The
        x-axis gets major ticks every Monday (labelled 'mm.dd') and minor ticks
        every day, and is padded by 5 days on each side. The plot includes a
        title, x and y labels, and a legend if more than one column is plotted.

        Args:
            ax (matplotlib.axes.Axes): The axes to plot on.
            dataset (str): The name of the dataset, used for the title and y-label.
            df (pd.DataFrame): The DataFrame containing the data to plot.
            colstoplot (list): A list of column names to plot from the DataFrame.

        Note:
            - If `dataset` is 'Index' or 'Indices', the y-axis is limited to
              [0, 100] and a shaded region between y=30 and y=70 is added.
            - If `dataset` is 'Price' or 'Prices', a shaded region between the
              'Low' and 'High' columns is added. Rows with a NaN 'High' are
              treated as forecasts: their 'Close' values are drawn in red with
              '*' markers and connected to the last row that has a 'High'.
        """
        self.logger.info('plotting %s in %s', colstoplot, dataset)
        colorcycle = ['black', 'blue', 'green', 'orange']
        for i, col in enumerate(colstoplot):
            # wrap around so more than four columns do not raise IndexError
            ax.plot(df.index, df[col], color=colorcycle[i % len(colorcycle)],
                    label=col, linewidth=2)

        if dataset in ('Index', 'Indices'):
            ax.set_ylim([0, 100])
            # Add a transparent shaded region between y=30 and y=70
            ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3, label='30-70 Range')
        elif dataset in ('Price', 'Prices'):
            # Add a transparent shaded region between daily lows and highs
            ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3,
                            label='Price Range')
            # positional mask (numpy array) so duplicate index labels are harmless
            is_fcst = df['High'].isna().to_numpy()
            close = df['Close']
            fcst_close = close[is_fcst]
            known_close = close[~is_fcst]
            if not fcst_close.empty:
                if not known_close.empty:
                    # connect the last available day and the forecasts with a red line
                    now_fcst = pd.concat([known_close.iloc[[-1]], fcst_close])
                    ax.plot(now_fcst.index, now_fcst, color='red')
                # plot only the forecasts with markers
                ax.plot(fcst_close.index, fcst_close, color='red', marker='*', label='Forecast')

        ax.set_xlim([df.index.min() - pd.Timedelta(days=5),
                     df.index.max() + pd.Timedelta(days=5)])

        # Set major ticks (every Monday with labels)
        ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%m.%d'))
        # Set minor ticks (every day, but without labels)
        ax.xaxis.set_minor_locator(mdates.DayLocator())
        plt.setp(ax.get_xticklabels(), rotation=90, ha='center')

        ax.set_ylabel(dataset)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.set_title(dataset)
        ax.set_xlabel('Date')
        if len(colstoplot) > 1:
            ax.legend(loc='upper left')

    def get_fetcherror_fig(self, message: str) -> plt.Figure:
        """
        Loads images/plot_error.png, annotates it and returns it as a
        matplotlib Figure.

        If the image file is missing, a warning is logged and the figure
        contains only the message.

        Args:
            message (str): message to be annotated on the displayed image

        Returns:
            plt.Figure: figure object containing the annotated image
        """
        fig, ax = plt.subplots(figsize=(5, 5))

        # Load and display the image (located relative to this file's parent dir)
        parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        fname = os.path.join(parentdir, 'images', 'plot_error.png')
        try:
            img = plt.imread(fname)
        except FileNotFoundError:
            # the placeholder must not crash the caller; fall back to text only
            self.logger.warning('Error image not found: %s', fname)
        else:
            ax.imshow(img)

        ax.axis('off')  # removes both ticks and axes lines completely
        ax.text(0.5, 0.05, message, fontsize=20, ha='center', va='center',
                transform=ax.transAxes)
        return fig


if __name__ == '__main__':
    # manual smoke test against live data
    from src.fetch_forecast import FetchForecast
    from src.fetch_data import FetchData

    ticker = 'AAPL'
    dotenv.load_dotenv(dotenv.find_dotenv())
    df_hist = FetchData(ticker, fetchperiodinweeks=12).run()
    df_past, df_fcst = FetchForecast(ticker, df_hist).run()
    df, fig = TechnicalAnalysis(
        ticker,
        df_hist=df_hist,
        df_past=df_past,
        df_fcst=df_fcst,
        plot_ta=True,
        savefig=True,
        debug=False,
    ).run()