StockAnalysisAgent / src /technical_analysis.py
OnurKerimoglu's picture
technical_analysis: robust figure saving path
6d0b28e
raw
history blame
12.5 kB
import datetime as dt
import logging
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.axes import Axes
import os
import pandas as pd
from ta.volume import volume_weighted_average_price
from ta.momentum import RSIIndicator, StochasticOscillator
from ta.trend import MACD
import yfinance as yf
class TechnicalAnalysis():
    """
    Fetch daily price data for a ticker from Yahoo Finance, compute technical
    analysis indicators (VWAP, RSI, Stochastic Oscillator, MACD) with the `ta`
    package, and optionally plot them.
    """

    def __init__(
            self,
            ticker: str,
            fetchperiodinweeks: int = 12,
            plot_ta: bool = True,
            savefig: bool = False,
            debug: bool = False):
        """
        Initialize TechnicalAnalysis object.
        Args:
            ticker : str
                stock ticker to analyze
            fetchperiodinweeks : int, optional, default: 12
                number of weeks of historical data to fetch from YahooFinance
            plot_ta : bool, optional, default: True
                whether to generate plots of technical analysis metrics
            savefig : bool, optional, default: False
                whether to save the generated plot as plots/{ticker}.png, where the
                plots directory sits next to the directory containing this module
                (it is created if missing)
            debug : bool, optional, default: False
                whether to run in debug mode, so that logging is produced at debug level
        """
        # set up logging
        self.logger_level = logging.DEBUG if debug else logging.INFO
        self.logger = logging.getLogger(__name__)
        # NOTE: basicConfig is a no-op if the root logger already has handlers,
        # so `debug` only takes effect when nothing configured logging earlier.
        logging.basicConfig(level=self.logger_level)  # filename='TechnicalAnalysis.log',
        # input arguments
        self.ticker = ticker
        self.fetchperiodinweeks = fetchperiodinweeks
        self.plot_ta = plot_ta
        self.savefig = savefig
        # done initializing
        self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}')

    def run(
            self
    ) -> 'tuple[pd.DataFrame, plt.Figure | None]':
        """
        Main entry point for the TechnicalAnalysis object.
        This method:
        - fetches historical data from YahooFinance,
        - computes the technical analysis metrics,
        - plots the price and TA metrics (if plot_ta is set).
        Returns:
            (df, fig)
                df: the fetched data with indicator columns added (empty if the fetch failed)
                fig: the plot figure; None if plot_ta is False or if the figure was
                saved (saving closes it). On a failed fetch with plot_ta set, an
                error placeholder figure is returned instead.
        """
        # fetch data from yf
        self.df = self.fetch_data()
        if self.df.shape[0] == 0:
            # nothing to analyze: optionally hand back a figure that reports the failure
            fig = self.get_fetcherror_fig(message='failed fetching data') if self.plot_ta else None
            return self.df, fig
        # add the features based on technical analysis
        self.df = self.tech_analysis()
        fig = None
        if self.plot_ta:
            # the output directory (if saving) is created by plot_stock_metrics itself,
            # at the same location it writes to
            fig = self.plot_stock_metrics(
                self.df,
                datasets={
                    'Volume': ['Volume'],
                    'Prices': ['Close', 'VWAP'],  # 'High','Low',
                    'Indices': ['RSI', 'StochOsc'],
                    'Trend': ['MACD', 'MACDsig', 'MACDdif']},
                savefig=self.savefig
            )
        return self.df, fig

    def fetch_data(
            self
    ) -> pd.DataFrame:
        """
        Fetches historical stock price data from Yahoo Finance.
        Downloads daily-interval data for the configured ticker, covering the last
        `fetchperiodinweeks` weeks up to now. If yfinance returns MultiIndex columns
        with a 'Ticker' level, that redundant level is dropped.
        Errors are not propagated: any exception raised by the download is logged
        and an empty DataFrame is returned instead, so callers should check for
        an empty result.
        Returns:
            pd.DataFrame
                The historical price data (columns as returned by yfinance, e.g.
                Open/High/Low/Close/Volume), or an empty DataFrame on failure.
        """
        now = dt.datetime.now()
        period_start = now - dt.timedelta(weeks=self.fetchperiodinweeks)
        self.logger.info(f'Fetching price data for {self.ticker}')
        try:
            df = yf.download(
                self.ticker,
                start=period_start,
                end=now,
                interval='1d'
            )
        except Exception as e:
            # best-effort: log and fall back to an empty frame so run() can report the failure
            self.logger.error(f'{e}')
            df = pd.DataFrame()
        if df is None:
            # defensive: treat a missing result like an empty download
            df = pd.DataFrame()
        if df.shape[0] > 0:
            # get rid of the redundant ticker column level, only if it is actually present
            if isinstance(df.columns, pd.MultiIndex) and 'Ticker' in df.columns.names:
                df.columns = df.columns.droplevel('Ticker')
            self.logger.debug(df.head(10))
        self.logger.info(f'Fetched {df.shape[0]} rows for {self.ticker}')
        return df

    def tech_analysis(
            self
    ) -> pd.DataFrame:
        """
        Calculates technical analysis indicators for the fetched stock price data.
        Reads the 'High', 'Low', 'Close' and 'Volume' columns of self.df and adds:
        - Additional Price Indicators:
            - Volume-Weighted Average Price (VWAP)           -> 'VWAP'
        - Momentum Indicators:
            - Relative Strength Index (RSI, window 14)       -> 'RSI'
            - Stochastic Oscillator (window 14)              -> 'StochOsc'
        - Trend Indicators:
            - MACD (fast 12, slow 26, signal 9)              -> 'MACD', 'MACDsig', 'MACDdif'
        The columns are added to self.df in place.
        Returns:
            pd.DataFrame
                The same DataFrame object (self.df) with the indicator columns added.
        """
        df = self.df
        # Price Indicators
        # Volume-Weighted Average Price (VWAP)
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
        df['VWAP'] = volume_weighted_average_price(
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            volume=df['Volume'],
        )
        # Indices
        # RSI:
        # https://www.investopedia.com/terms/r/rsi.asp
        df['RSI'] = RSIIndicator(
            df['Close'],
            window=14).rsi()
        # Stochastic Oscillator:
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
        df['StochOsc'] = StochasticOscillator(
            df['High'],
            df['Low'],
            df['Close'],
            window=14).stoch()
        # Trend signals
        # Moving Average Convergence Divergence (MACD):
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
        macd = MACD(
            df['Close'],
            window_slow=26,
            window_fast=12,
            window_sign=9)
        df['MACD'] = macd.macd()
        df['MACDsig'] = macd.macd_signal()
        df['MACDdif'] = macd.macd_diff()
        return df

    def plot_stock_metrics(
            self,
            df: pd.DataFrame,
            datasets: 'dict[str, list[str]] | None' = None,
            savefig: bool = False
    ) -> 'plt.Figure | None':
        """
        Plots the given stock metrics datasets as vertically stacked subplots.
        Each key of `datasets` is a dataset (subplot) name and its value is the list
        of DataFrame columns drawn on that subplot.
        If `savefig` is set, the figure is written as png to
        <parent of this module's directory>/plots/{ticker}.png (the directory is
        created if needed), then closed.
        Args:
            df (pd.DataFrame)
                The DataFrame to plot
            datasets (dict, optional)
                Mapping dataset name -> list of column names to be plotted.
                Defaults to {'Volume': ['Volume'], 'Price': ['Close']}.
            savefig (bool)
                Whether to save the figure to a file
        Returns:
            The matplotlib Figure, or None if it was saved (and therefore closed).
        """
        if datasets is None:
            datasets = {
                'Volume': ['Volume'],
                'Price': ['Close']  # 'High','Low'
            }
        numax = len(datasets)
        # squeeze=False always yields a 2-D array of Axes, even for a single dataset
        fig, axes = plt.subplots(
            nrows=numax,
            ncols=1,
            figsize=(10, 5*numax),
            squeeze=False)
        for ax, (dataset, colstoplot) in zip(axes.flat, datasets.items()):
            self.plot_stock_metrics_ax(
                ax,
                dataset,
                df,
                colstoplot)
        fig.tight_layout()
        if savefig:
            # resolve relative to this module (not the CWD) and make sure the target dir exists
            rootdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            plotdir = os.path.join(rootdir, 'plots')
            os.makedirs(plotdir, exist_ok=True)
            fname = os.path.join(plotdir, f'{self.ticker}.png')
            fig.savefig(fname)
            self.logger.info(f'Saved figure into: {fname}')
            # release the figure once it is on disk
            plt.close(fig)
            fig = None
        return fig

    def plot_stock_metrics_ax(
            self,
            ax: Axes,
            dataset: str,
            df: pd.DataFrame,
            colstoplot: list) -> None:
        """
        Plots specified stock metrics on the provided Axes object.
        Plots the columns `colstoplot` of `df` against its index on `ax`. The x-axis
        gets labelled major ticks every Monday and unlabelled minor ticks every day.
        The plot includes a title, a y label and, if more than one column is
        plotted, a legend. Additional shaded regions are added for certain datasets.
        Args:
            ax (matplotlib.axes.Axes): The axes to plot on.
            dataset (str): The name of the dataset, used for the title and y-label.
            df (pd.DataFrame): The DataFrame containing the data to plot.
            colstoplot (list): A list of column names to plot from the DataFrame.
        Note:
            - If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100]
              and a shaded region between y=30 and y=70 is added.
            - If `dataset` is 'Price' or 'Prices', a shaded region between the 'Low'
              and 'High' columns is added (both columns must exist in `df`).
        """
        self.logger.debug(f'plotting {colstoplot} in {dataset}')
        colorcycle = ['black', 'blue', 'red', 'green', 'orange']
        for i, col in enumerate(colstoplot):
            ax.plot(
                df.index,
                df[col],
                # wrap around instead of raising IndexError for more than 5 columns
                color=colorcycle[i % len(colorcycle)],
                label=col,
                linewidth=2)
        # Set major ticks (every Monday, labelled as month + day)
        ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
        # Set minor ticks (every day, but without labels)
        ax.xaxis.set_minor_locator(mdates.DayLocator())
        ax.set_title(dataset)
        ax.set_ylabel(dataset)
        if len(colstoplot) > 1:
            ax.legend()
        if dataset in ['Index', 'Indices']:
            ax.set_ylim([0, 100])
            # Add a transparent shaded region between y=30 and y=70
            ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3)
        if dataset in ['Price', 'Prices']:
            # Add a transparent shaded region between the daily Low and High prices
            ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3)
        ax.grid(True, linestyle='--', alpha=0.7)

    def get_fetcherror_fig(
            self,
            message: str
    ) -> plt.Figure:
        """
        Builds a figure showing images/plot_error.png annotated with `message`.
        The image path is resolved relative to this module's parent directory.
        If the image cannot be loaded, a warning is logged and the figure shows
        only the message.
        Args:
            message (str): message to be annotated on the displayed image
        Returns:
            plt.Figure: figure object containing the annotated image
        """
        fig, ax = plt.subplots(
            figsize=(5, 5)
        )
        # Load and display the image
        parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        fname = os.path.join(parentdir, 'images', 'plot_error.png')
        try:
            img = plt.imread(fname)
        except OSError as e:
            # best-effort: this is already the error path, so don't crash on a missing image
            self.logger.warning(f'Could not load error image {fname}: {e}')
        else:
            ax.imshow(img)
        ax.axis('off')  # removes both ticks and axes lines completely
        ax.text(0.5, 0.05, message, fontsize=20, ha='center', va='center', transform=ax.transAxes)
        return fig
if __name__ == '__main__':
    # Manual smoke run: analyze one ticker, save its plot, and list the resulting columns.
    ticker = 'GOOG'
    analysis = TechnicalAnalysis(ticker=ticker, plot_ta=True, savefig=True, debug=False)
    df, fig = analysis.run()
    print('columns: {}'.format(df.columns))