Spaces:
Running
Running
| import datetime as dt | |
| import logging | |
| import matplotlib.pyplot as plt | |
| import matplotlib.dates as mdates | |
| from matplotlib.axes import Axes | |
| import os | |
| import pandas as pd | |
| from ta.volume import volume_weighted_average_price | |
| from ta.momentum import RSIIndicator, StochasticOscillator | |
| from ta.trend import MACD | |
| import yfinance as yf | |
class TechnicalAnalysis():
    """
    Fetch daily price data for one ticker, compute technical-analysis
    indicators on it and optionally plot the result.
    Typical usage:
        df, fig = TechnicalAnalysis('GOOG').run()
    """

    # Line colors assigned to the plotted columns in order; reused cyclically
    # when a panel has more columns than colors.
    _COLOR_CYCLE = ('black', 'blue', 'red', 'green', 'orange')

    def __init__(
            self,
            ticker: str,
            fetchperiodinweeks: int = 12,
            plot_ta: bool = True,
            savefig: bool = False,
            debug: bool = False):
        """
        Initialize TechnicalAnalysis object.
        Args:
            ticker : str
                stock ticker to analyze
            fetchperiodinweeks : int, optional, default: 12
                number of weeks to fetch historical data from YahooFinance
            plot_ta : bool, optional, default: True
                whether to generate plots of technical analysis metrics
            savefig : bool, optional, default: False
                whether the generated plot is written to plots/{ticker}.png
                (and closed) instead of being returned to the caller
            debug : bool, optional, default: False
                whether run in debug mode, so that logging should be produced at debug level
        """
        # set up logging
        self.logger_level = logging.DEBUG if debug else logging.INFO
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=self.logger_level)  # filename='TechnicalAnalysis.log',
        # basicConfig() does nothing once the root logger has handlers, so the
        # requested level is also set on this logger to make `debug` effective.
        self.logger.setLevel(self.logger_level)
        # input arguments
        self.ticker = ticker
        self.fetchperiodinweeks = fetchperiodinweeks
        self.plot_ta = plot_ta
        self.savefig = savefig
        # done initializing
        self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}')

    def run(
            self
    ) -> 'tuple[pd.DataFrame, plt.Figure | None]':
        """
        Main entry point for the TechnicalAnalysis object.
        This method:
        - fetches historical data from YahooFinance,
        - computes the technical analysis metrics
        - plots the price and TA metrics.
        Returns:
            tuple (df, fig)
                df : pd.DataFrame
                    the fetched data with the indicator columns added; empty
                    if fetching failed
                fig : matplotlib Figure or None
                    the metrics figure, or an error figure if fetching failed;
                    None if plotting is disabled or the figure was saved to file
        """
        # fetch data from yf
        self.df = self.fetch_data()
        if self.df.shape[0] == 0:
            # nothing to analyze: hand back an error image when plots are wanted
            fig = None
            if self.plot_ta:
                fig = self.get_fetcherror_fig(message='failed fetching data')
            return self.df, fig
        # add the features based on technical analysis
        self.df = self.tech_analysis()
        if not self.plot_ta:
            return self.df, None
        # plot the results
        fig = self.plot_stock_metrics(
            self.df,
            datasets={
                'Volume': ['Volume'],
                'Prices': ['Close', 'VWAP'],  # 'High','Low',
                'Indices': ['RSI', 'StochOsc'],
                'Trend': ['MACD', 'MACDsig', 'MACDdif']},
            savefig=self.savefig
        )
        return self.df, fig

    @staticmethod
    def _drop_ticker_level(df: pd.DataFrame) -> pd.DataFrame:
        """Drop the 'Ticker' column level in place if present; flat columns are left untouched."""
        columns = df.columns
        if isinstance(columns, pd.MultiIndex) and 'Ticker' in columns.names:
            df.columns = columns.droplevel('Ticker')
        return df

    def fetch_data(
            self
    ) -> pd.DataFrame:
        """
        Fetches historical stock price data from Yahoo Finance.
        This method downloads daily price data for the configured ticker,
        covering the last `fetchperiodinweeks` weeks up to now. If rows were
        received, the redundant 'Ticker' column level is removed and the
        result is logged.
        Any exception raised by the download is logged and swallowed; in that
        case an empty DataFrame is returned, so callers must check for an
        empty result rather than catch an exception.
        Returns:
            pd.DataFrame
                The downloaded price data, or an empty DataFrame on failure.
        """
        # take "now" once so that start and end refer to the same instant
        period_end = dt.datetime.now()
        period_start = period_end - dt.timedelta(weeks=self.fetchperiodinweeks)
        self.logger.info(f'Fetching price data for {self.ticker}')
        try:
            df = yf.download(
                self.ticker,
                start=period_start,
                end=period_end,
                interval='1d'
            )
        except Exception as e:
            # deliberate best effort: report the failure and continue with an
            # empty frame so that run() can show the error figure
            self.logger.error(f'{e}')
            df = pd.DataFrame()
        if df.shape[0] > 0:
            # get rid of the redundant ticker column level (only if it exists)
            df = self._drop_ticker_level(df)
            self.logger.debug(df.head(10))
            self.logger.info(f'Fetched {df.shape[0]} rows for {self.ticker}')
        return df

    def tech_analysis(
            self
    ) -> pd.DataFrame:
        """
        Calculates technical analysis indicators for the fetched stock price data.
        The indicators are computed from the 'High', 'Low', 'Close' and
        'Volume' columns of `self.df` and written into new columns of that
        same DataFrame (it is modified in place):
        - Additional Price Indicators:
            - 'VWAP': Volume-Weighted Average Price
        - Momentum Indicators:
            - 'RSI': Relative Strength Index (window 14)
            - 'StochOsc': Stochastic Oscillator (window 14)
        - Trend Indicators:
            - 'MACD', 'MACDsig', 'MACDdif': Moving Average Convergence
              Divergence line, signal line and their difference (26/12/9)
        Returns:
            pd.DataFrame
                The DataFrame with the calculated technical analysis indicators.
        """
        df = self.df
        # Price Indicators
        # Volume-Weighted Average Price (VWAP)
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
        df['VWAP'] = volume_weighted_average_price(
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            volume=df['Volume'],
        )
        # Indices
        # RSI:
        # https://www.investopedia.com/terms/r/rsi.asp
        df['RSI'] = RSIIndicator(
            df['Close'],
            window=14).rsi()
        # Stochastic Oscillator:
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
        df['StochOsc'] = StochasticOscillator(
            df['High'],
            df['Low'],
            df['Close'],
            window=14).stoch()
        # Trend signals
        # Moving Average Convergence Divergence (MACD):
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
        macd = MACD(
            df['Close'],
            window_slow=26,
            window_fast=12,
            window_sign=9)
        df['MACD'] = macd.macd()
        df['MACDsig'] = macd.macd_signal()
        df['MACDdif'] = macd.macd_diff()
        return df

    def plot_stock_metrics(
            self,
            df,
            datasets=None,
            savefig=False
    ) -> 'plt.Figure | None':
        """
        Plots the given stock metrics datasets as subplots.
        This method takes in a DataFrame and a dictionary of datasets, where
        each key is a dataset name and the value is a list of column names.
        The method creates a figure with one subplot per dataset, stacked
        vertically, and plots the corresponding columns of the DataFrame.
        Args:
            df (pd.DataFrame)
                The DataFrame to plot
            datasets (dict, optional)
                A dictionary of datasets, where each key is a dataset name and
                the value is a list of column names to be plotted.
                Defaults to {'Volume': ['Volume'], 'Price': ['Close']}.
            savefig (bool)
                Whether to save the figure to plots/{ticker}.png, located next
                to the parent directory of this file
        Returns:
            The figure if `savefig` is False. If `savefig` is True the figure
            is written to file and closed, and None is returned.
        """
        if datasets is None:
            # built per call, so no dict is shared between invocations
            datasets = {
                'Volume': ['Volume'],
                'Price': ['Close']  # 'High','Low'
            }
        numax = len(datasets)
        # squeeze=False always yields a 2D array of axes, so a single dataset
        # can be iterated the same way as several
        fig, axes = plt.subplots(
            nrows=numax,
            ncols=1,
            figsize=(10, 5*numax),
            squeeze=False)
        for ax, (dataset, colstoplot) in zip(axes.flat, datasets.items()):
            self.plot_stock_metrics_ax(
                ax,
                dataset,
                df,
                colstoplot)
        fig.tight_layout()
        if not savefig:
            return fig
        rootdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        plotdir = os.path.join(rootdir, 'plots')
        # create the directory the file is actually written to
        os.makedirs(plotdir, exist_ok=True)
        fname = os.path.join(plotdir, f'{self.ticker}.png')
        fig.savefig(fname)
        self.logger.info(f'Saved figure into: {fname}')
        plt.close(fig)
        return None

    def plot_stock_metrics_ax(
            self,
            ax: 'Axes',
            dataset: str,
            df: pd.DataFrame,
            colstoplot: list) -> None:
        """
        Plots specified stock metrics on the provided Axes object.
        This function plots the specified columns from the DataFrame `df`
        against its index. It formats the x-axis with major ticks set to
        every Monday and minor ticks set to every day. The plot includes a title,
        a y label, a grid, and a legend if more than one column is plotted.
        Additional shaded regions are added for certain datasets.
        Args:
            ax (matplotlib.axes.Axes): The axes to plot on.
            dataset (str): The name of the dataset, used for the title and y-label.
            df (pd.DataFrame): The DataFrame containing the data to plot.
            colstoplot (list): A list of column names to plot from the DataFrame.
        Note:
            - If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100]
              and a shaded region between y=30 and y=70 is added.
            - If `dataset` is 'Price' or 'Prices', a shaded region between 'Low' and
              'High' columns is added.
        """
        self.logger.debug(f'plotting {colstoplot} in {dataset}')
        colors = self._COLOR_CYCLE
        for i, col in enumerate(colstoplot):
            ax.plot(
                df.index,
                df[col],
                # wrap around instead of failing when columns outnumber colors
                color=colors[i % len(colors)],
                label=col,
                linewidth=2)
        # Set major ticks (every Monday, labelled with month and day)
        ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
        # Set minor ticks (every day, but without labels)
        ax.xaxis.set_minor_locator(mdates.DayLocator())
        ax.set_title(dataset)
        # ax.set_xlabel('Date')
        ax.set_ylabel(dataset)
        if len(colstoplot) > 1:
            ax.legend()
        if dataset in ['Index', 'Indices']:
            ax.set_ylim([0, 100])
            # Add a transparent shaded region between y=30 and y=70
            ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3)
        if dataset in ['Price', 'Prices']:
            # Add a transparent shaded region between the 'Low' and 'High' prices
            ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3)
        ax.grid(True, linestyle='--', alpha=0.7)

    def get_fetcherror_fig(
            self,
            message
    ) -> 'plt.Figure':
        """
        Fetches images/plot_error.png, annotates it and returns it as a matplotlib.pyplot.Figure object
        The image is looked up relative to the parent directory of this file.
        If it does not exist, a warning is logged and the figure contains the
        message only.
        Args:
            message (str): message to be annotated on the displayed image
        Returns:
            plt.Figure: figure object containing the annotated image
        """
        fig, ax = plt.subplots(
            figsize=(5, 5)
        )
        # Load and display the image
        parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        fname = os.path.join(parentdir, 'images', 'plot_error.png')
        if os.path.isfile(fname):
            ax.imshow(plt.imread(fname))
        else:
            # this is already the error path; do not fail a second time
            self.logger.warning(f'Error image not found: {fname}')
        ax.axis('off')  # removes both ticks and axes lines completely
        ax.text(0.5, 0.05, message, fontsize=20, ha='center', va='center', transform=ax.transAxes)
        return fig
if __name__ == '__main__':
    # Demo run: analyze one ticker, write the plot to file and list the
    # columns of the resulting DataFrame.
    ticker = 'GOOG'
    analysis = TechnicalAnalysis(ticker, plot_ta=True, savefig=True, debug=False)
    df, fig = analysis.run()
    print(f'columns: {df.columns}')