Spaces:
Running
Running
Commit
·
f680f62
1
Parent(s):
8ba4914
introduced src.technical_analysis.py
Browse files- src/technical_analysis.py +277 -0
src/technical_analysis.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime as dt
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.dates as mdates
|
| 6 |
+
from matplotlib.axes import Axes
|
| 7 |
+
import os
|
| 8 |
+
import pandas as pd
|
| 9 |
+
from ta.volume import volume_weighted_average_price
|
| 10 |
+
from ta.momentum import RSIIndicator, StochasticOscillator
|
| 11 |
+
from ta.trend import MACD
|
| 12 |
+
import yfinance as yf
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TechnicalAnalysis():
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
ticker:str,
|
| 19 |
+
fetchperiodinweeks:int=8,
|
| 20 |
+
plot_ta:bool=True,
|
| 21 |
+
debug=False):
|
| 22 |
+
# input arguments
|
| 23 |
+
"""
|
| 24 |
+
Initialize TechnicalAnalysis object.
|
| 25 |
+
Args:
|
| 26 |
+
ticker : str
|
| 27 |
+
stock ticker to analyze
|
| 28 |
+
fetchperiodinweeks : int, optional, default: 8
|
| 29 |
+
number of weeks to fetch historical data from YahooFinance
|
| 30 |
+
plot_ta : bool, optional, default: True
|
| 31 |
+
whether to generate plots of technical analysis metrics. Plot will be created under plots/{ticker}.png
|
| 32 |
+
debug : bool, optional, default: False
|
| 33 |
+
whether run in debug mode, so that logging should be produced at debug level
|
| 34 |
+
"""
|
| 35 |
+
self.ticker = ticker
|
| 36 |
+
self.fetchperiodinweeks = fetchperiodinweeks
|
| 37 |
+
self.plot_ta = plot_ta
|
| 38 |
+
|
| 39 |
+
# set up logging
|
| 40 |
+
if debug:
|
| 41 |
+
self.logger_level = logging.DEBUG
|
| 42 |
+
else:
|
| 43 |
+
self.logger_level = logging.INFO
|
| 44 |
+
self.logger = logging.getLogger(__name__)
|
| 45 |
+
logging.basicConfig(level=self.logger_level) # filename='TechnicalAnalysis.log',
|
| 46 |
+
|
| 47 |
+
def run(
|
| 48 |
+
self
|
| 49 |
+
) -> None:
|
| 50 |
+
# fetch data from yf
|
| 51 |
+
"""
|
| 52 |
+
Main entry point for the TechnicalAnalysis object.
|
| 53 |
+
This method:
|
| 54 |
+
- fetches historical data from YahooFinance,
|
| 55 |
+
- computes the technical analysis metrics
|
| 56 |
+
- plots the price and TA metrics.
|
| 57 |
+
"""
|
| 58 |
+
self.df = self.fetch_data()
|
| 59 |
+
# add the features based on technical analysis
|
| 60 |
+
self.df = self.tech_analysis()
|
| 61 |
+
# plot the results
|
| 62 |
+
if self.plot_ta:
|
| 63 |
+
os.makedirs('plots', exist_ok=True)
|
| 64 |
+
self.plot_stock_metrics(
|
| 65 |
+
self.df,
|
| 66 |
+
datasets={
|
| 67 |
+
'Volume': ['Volume'],
|
| 68 |
+
'Prices': ['Close', 'VWAP'], # 'High','Low',
|
| 69 |
+
'Indices': ['RSI', 'StochOsc'],
|
| 70 |
+
'Trend': ['MACD', 'MACDsig', 'MACDdif']}
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def fetch_data(
|
| 75 |
+
self
|
| 76 |
+
) -> pd.DataFrame:
|
| 77 |
+
"""
|
| 78 |
+
Fetches historical stock price data from Yahoo Finance.
|
| 79 |
+
This method downloads historical stock price data for the specified
|
| 80 |
+
ticker over a given period of weeks. The data is fetched on a daily
|
| 81 |
+
interval and stored in a pandas DataFrame. If the download is successful,
|
| 82 |
+
redundant ticker columns are removed, and logging information is
|
| 83 |
+
recorded. In case of failure, an empty DataFrame is returned and an
|
| 84 |
+
exception is raised.
|
| 85 |
+
Returns:
|
| 86 |
+
pd.DataFrame
|
| 87 |
+
A DataFrame containing the historical stock price data with columns
|
| 88 |
+
for open, high, low, close, volume, and adjusted close prices.
|
| 89 |
+
Raises:
|
| 90 |
+
Exception
|
| 91 |
+
If the data fetching fails, an exception is raised with an error
|
| 92 |
+
message.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
period_start = dt.datetime.now() - dt.timedelta(weeks=self.fetchperiodinweeks)
|
| 96 |
+
self.logger.info(f'Fetching price data for {ticker}')
|
| 97 |
+
try:
|
| 98 |
+
df = yf.download(
|
| 99 |
+
self.ticker,
|
| 100 |
+
start=period_start,
|
| 101 |
+
end=dt.datetime.now(),
|
| 102 |
+
interval='1d'
|
| 103 |
+
)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
self.logger.error(f'{e}')
|
| 106 |
+
# create empty df
|
| 107 |
+
df = pd.DataFrame()
|
| 108 |
+
|
| 109 |
+
if df.shape[0] > 0:
|
| 110 |
+
# get rid of the redundant ticker column
|
| 111 |
+
df.columns = df.columns.droplevel('Ticker')
|
| 112 |
+
self.logger.debug(df.head(10))
|
| 113 |
+
self.logger.info(f'Fetched {df.shape[0]} rows for {ticker}')
|
| 114 |
+
else:
|
| 115 |
+
raise Exception(f'Failed to fetch data for {ticker}')
|
| 116 |
+
|
| 117 |
+
return df
|
| 118 |
+
|
| 119 |
+
def tech_analysis(
|
| 120 |
+
self
|
| 121 |
+
) -> pd.DataFrame:
|
| 122 |
+
"""
|
| 123 |
+
Calculates technical analysis indicators for the fetched stock price data.
|
| 124 |
+
This method takes the fetched stock price data and calculates several
|
| 125 |
+
technical analysis indicators. The following indicators are calculated:
|
| 126 |
+
- Additional Price Indicators:
|
| 127 |
+
- Volume-Weighted Average Price (VWAP)
|
| 128 |
+
- Momentum Indicators:
|
| 129 |
+
- Relative Strength Index (RSI)
|
| 130 |
+
- Stochastic Oscillator
|
| 131 |
+
- Trend Indicators:
|
| 132 |
+
- Moving Average Convergence Divergence (MACD)
|
| 133 |
+
The calculated indicators are added to the DataFrame as new columns.
|
| 134 |
+
Returns:
|
| 135 |
+
pd.DataFrame
|
| 136 |
+
The DataFrame with the calculated technical analysis indicators.
|
| 137 |
+
"""
|
| 138 |
+
df = self.df
|
| 139 |
+
|
| 140 |
+
# Price Indicators
|
| 141 |
+
# Volume-Weighted Average Price (VWAP)
|
| 142 |
+
# https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
|
| 143 |
+
df['VWAP'] = volume_weighted_average_price(
|
| 144 |
+
high=df['High'],
|
| 145 |
+
low=df['Low'],
|
| 146 |
+
close=df['Close'],
|
| 147 |
+
volume=df['Volume'],
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# Indices
|
| 151 |
+
# RSI:
|
| 152 |
+
# https://www.investopedia.com/terms/r/rsi.asp
|
| 153 |
+
df['RSI'] = RSIIndicator(
|
| 154 |
+
df['Close'],
|
| 155 |
+
window=14).rsi()
|
| 156 |
+
# Stochastic Oscillator:
|
| 157 |
+
# https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
|
| 158 |
+
df['StochOsc'] = StochasticOscillator(
|
| 159 |
+
df['High'],
|
| 160 |
+
df['Low'],
|
| 161 |
+
df['Close'],
|
| 162 |
+
window=14).stoch()
|
| 163 |
+
|
| 164 |
+
# Trend signals
|
| 165 |
+
# Moving Average Convergence Divergence (MACD):
|
| 166 |
+
# https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
|
| 167 |
+
macd = MACD(
|
| 168 |
+
df['Close'],
|
| 169 |
+
window_slow=26,
|
| 170 |
+
window_fast=12,
|
| 171 |
+
window_sign=9)
|
| 172 |
+
df['MACD'] = macd.macd()
|
| 173 |
+
df['MACDsig'] = macd.macd_signal()
|
| 174 |
+
df['MACDdif'] = macd.macd_diff()
|
| 175 |
+
|
| 176 |
+
return df
|
| 177 |
+
|
| 178 |
+
def plot_stock_metrics(
|
| 179 |
+
self,
|
| 180 |
+
df,
|
| 181 |
+
datasets={
|
| 182 |
+
'Volume': ['Volume'],
|
| 183 |
+
'Price': ['Close'] # 'High','Low'
|
| 184 |
+
}
|
| 185 |
+
) -> None:
|
| 186 |
+
"""
|
| 187 |
+
Plots the given stock metrics datasets as subplots.
|
| 188 |
+
This method takes in a DataFrame and a dictionary of datasets, where
|
| 189 |
+
each key is a dataset name and the value is a list of column names.
|
| 190 |
+
The method creates a figure with subplots for each dataset and plots
|
| 191 |
+
the corresponding columns of the DataFrame.
|
| 192 |
+
The figure is then saved to a file in the 'plots' directory in png format
|
| 193 |
+
with the ticker symbol as the filename.
|
| 194 |
+
Args:
|
| 195 |
+
df (pd.DataFrame)
|
| 196 |
+
The DataFrame to plot
|
| 197 |
+
datasets (dict)
|
| 198 |
+
A dictionary of datasets, where each key is a dataset name and
|
| 199 |
+
the value is a list of column names to be plotted
|
| 200 |
+
"""
|
| 201 |
+
numax = len(datasets)
|
| 202 |
+
fig, axes = plt.subplots(
|
| 203 |
+
nrows=numax,
|
| 204 |
+
ncols=1,
|
| 205 |
+
figsize=(10, 5*numax))
|
| 206 |
+
for i, ax in enumerate(axes.flat):
|
| 207 |
+
dataset = list(datasets.keys())[i]
|
| 208 |
+
colstoplot = datasets[dataset]
|
| 209 |
+
self.plot_stock_metrics_ax(
|
| 210 |
+
ax,
|
| 211 |
+
dataset,
|
| 212 |
+
df,
|
| 213 |
+
colstoplot)
|
| 214 |
+
plt.tight_layout()
|
| 215 |
+
plt.savefig(os.path.join('plots', f'{self.ticker}.png'))
|
| 216 |
+
|
| 217 |
+
def plot_stock_metrics_ax(
|
| 218 |
+
self,
|
| 219 |
+
ax:Axes,
|
| 220 |
+
dataset:str,
|
| 221 |
+
df:pd.DataFrame,
|
| 222 |
+
colstoplot:list) -> None:
|
| 223 |
+
"""
|
| 224 |
+
Plots specified stock metrics on the provided Axes object.
|
| 225 |
+
This function takes in an Axes object and plots the specified columns from
|
| 226 |
+
the DataFrame `df` on it. It formats the x-axis with major ticks set to
|
| 227 |
+
every Monday and minor ticks set to every day. The plot includes a title,
|
| 228 |
+
x and y labels, and optionally a legend if more than one column is plotted.
|
| 229 |
+
Additional shaded regions are added for certain datasets.
|
| 230 |
+
Args:
|
| 231 |
+
ax (matplotlib.axes.Axes): The flattened axes to plot on.
|
| 232 |
+
dataset (str): The name of the dataset, used for the title and y-label.
|
| 233 |
+
df (pd.DataFrame): The DataFrame containing the data to plot.
|
| 234 |
+
colstoplot (list): A list of column names to plot from the DataFrame.
|
| 235 |
+
Note:
|
| 236 |
+
- If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100]
|
| 237 |
+
and a shaded region between y=30 and y=70 is added.
|
| 238 |
+
- If `dataset` is 'Price' or 'Prices', a shaded region between 'Low' and
|
| 239 |
+
'High' columns is added.
|
| 240 |
+
"""
|
| 241 |
+
|
| 242 |
+
print(f'plotting {colstoplot} in {dataset}')
|
| 243 |
+
colorcycle = ['black', 'blue', 'red', 'green', 'orange']
|
| 244 |
+
for i, col in enumerate(colstoplot):
|
| 245 |
+
ax.plot(
|
| 246 |
+
df.index,
|
| 247 |
+
df[col],
|
| 248 |
+
color=colorcycle[i],
|
| 249 |
+
label=col,
|
| 250 |
+
linewidth=2)
|
| 251 |
+
# Format major ticks with year
|
| 252 |
+
# Set major ticks (every Monday with labels)
|
| 253 |
+
ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
|
| 254 |
+
ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
|
| 255 |
+
# Set minor ticks (every day, but without labels)
|
| 256 |
+
ax.xaxis.set_minor_locator(mdates.DayLocator())
|
| 257 |
+
|
| 258 |
+
ax.set_title(dataset)
|
| 259 |
+
ax.set_xlabel('Date')
|
| 260 |
+
ax.set_ylabel(dataset)
|
| 261 |
+
if len(colstoplot) > 1:
|
| 262 |
+
ax.legend()
|
| 263 |
+
if dataset in ['Index', 'Indices']:
|
| 264 |
+
ax.set_ylim([0, 100])
|
| 265 |
+
# Add a transparent shaded region between y=30 and y=70
|
| 266 |
+
ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3)
|
| 267 |
+
if dataset in ['Price', 'Prices']:
|
| 268 |
+
# Add a transparent shaded region between y=30 and y=70
|
| 269 |
+
ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3)
|
| 270 |
+
ax.grid(True, linestyle='--', alpha=0.7)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
if __name__ == '__main__':
|
| 274 |
+
ticker = 'GOOG'
|
| 275 |
+
ta = TechnicalAnalysis(ticker, debug=False).run()
|
| 276 |
+
|
| 277 |
+
|