OnurKerimoglu commited on
Commit
f680f62
·
1 Parent(s): 8ba4914

introduced src.technical_analysis.py

Browse files
Files changed (1) hide show
  1. src/technical_analysis.py +277 -0
src/technical_analysis.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime as dt
2
+ import logging
3
+
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.dates as mdates
6
+ from matplotlib.axes import Axes
7
+ import os
8
+ import pandas as pd
9
+ from ta.volume import volume_weighted_average_price
10
+ from ta.momentum import RSIIndicator, StochasticOscillator
11
+ from ta.trend import MACD
12
+ import yfinance as yf
13
+
14
+
15
+ class TechnicalAnalysis():
16
+ def __init__(
17
+ self,
18
+ ticker:str,
19
+ fetchperiodinweeks:int=8,
20
+ plot_ta:bool=True,
21
+ debug=False):
22
+ # input arguments
23
+ """
24
+ Initialize TechnicalAnalysis object.
25
+ Args:
26
+ ticker : str
27
+ stock ticker to analyze
28
+ fetchperiodinweeks : int, optional, default: 8
29
+ number of weeks to fetch historical data from YahooFinance
30
+ plot_ta : bool, optional, default: True
31
+ whether to generate plots of technical analysis metrics. Plot will be created under plots/{ticker}.png
32
+ debug : bool, optional, default: False
33
+ whether run in debug mode, so that logging should be produced at debug level
34
+ """
35
+ self.ticker = ticker
36
+ self.fetchperiodinweeks = fetchperiodinweeks
37
+ self.plot_ta = plot_ta
38
+
39
+ # set up logging
40
+ if debug:
41
+ self.logger_level = logging.DEBUG
42
+ else:
43
+ self.logger_level = logging.INFO
44
+ self.logger = logging.getLogger(__name__)
45
+ logging.basicConfig(level=self.logger_level) # filename='TechnicalAnalysis.log',
46
+
47
+ def run(
48
+ self
49
+ ) -> None:
50
+ # fetch data from yf
51
+ """
52
+ Main entry point for the TechnicalAnalysis object.
53
+ This method:
54
+ - fetches historical data from YahooFinance,
55
+ - computes the technical analysis metrics
56
+ - plots the price and TA metrics.
57
+ """
58
+ self.df = self.fetch_data()
59
+ # add the features based on technical analysis
60
+ self.df = self.tech_analysis()
61
+ # plot the results
62
+ if self.plot_ta:
63
+ os.makedirs('plots', exist_ok=True)
64
+ self.plot_stock_metrics(
65
+ self.df,
66
+ datasets={
67
+ 'Volume': ['Volume'],
68
+ 'Prices': ['Close', 'VWAP'], # 'High','Low',
69
+ 'Indices': ['RSI', 'StochOsc'],
70
+ 'Trend': ['MACD', 'MACDsig', 'MACDdif']}
71
+ )
72
+
73
+
74
+ def fetch_data(
75
+ self
76
+ ) -> pd.DataFrame:
77
+ """
78
+ Fetches historical stock price data from Yahoo Finance.
79
+ This method downloads historical stock price data for the specified
80
+ ticker over a given period of weeks. The data is fetched on a daily
81
+ interval and stored in a pandas DataFrame. If the download is successful,
82
+ redundant ticker columns are removed, and logging information is
83
+ recorded. In case of failure, an empty DataFrame is returned and an
84
+ exception is raised.
85
+ Returns:
86
+ pd.DataFrame
87
+ A DataFrame containing the historical stock price data with columns
88
+ for open, high, low, close, volume, and adjusted close prices.
89
+ Raises:
90
+ Exception
91
+ If the data fetching fails, an exception is raised with an error
92
+ message.
93
+ """
94
+
95
+ period_start = dt.datetime.now() - dt.timedelta(weeks=self.fetchperiodinweeks)
96
+ self.logger.info(f'Fetching price data for {ticker}')
97
+ try:
98
+ df = yf.download(
99
+ self.ticker,
100
+ start=period_start,
101
+ end=dt.datetime.now(),
102
+ interval='1d'
103
+ )
104
+ except Exception as e:
105
+ self.logger.error(f'{e}')
106
+ # create empty df
107
+ df = pd.DataFrame()
108
+
109
+ if df.shape[0] > 0:
110
+ # get rid of the redundant ticker column
111
+ df.columns = df.columns.droplevel('Ticker')
112
+ self.logger.debug(df.head(10))
113
+ self.logger.info(f'Fetched {df.shape[0]} rows for {ticker}')
114
+ else:
115
+ raise Exception(f'Failed to fetch data for {ticker}')
116
+
117
+ return df
118
+
119
+ def tech_analysis(
120
+ self
121
+ ) -> pd.DataFrame:
122
+ """
123
+ Calculates technical analysis indicators for the fetched stock price data.
124
+ This method takes the fetched stock price data and calculates several
125
+ technical analysis indicators. The following indicators are calculated:
126
+ - Additional Price Indicators:
127
+ - Volume-Weighted Average Price (VWAP)
128
+ - Momentum Indicators:
129
+ - Relative Strength Index (RSI)
130
+ - Stochastic Oscillator
131
+ - Trend Indicators:
132
+ - Moving Average Convergence Divergence (MACD)
133
+ The calculated indicators are added to the DataFrame as new columns.
134
+ Returns:
135
+ pd.DataFrame
136
+ The DataFrame with the calculated technical analysis indicators.
137
+ """
138
+ df = self.df
139
+
140
+ # Price Indicators
141
+ # Volume-Weighted Average Price (VWAP)
142
+ # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
143
+ df['VWAP'] = volume_weighted_average_price(
144
+ high=df['High'],
145
+ low=df['Low'],
146
+ close=df['Close'],
147
+ volume=df['Volume'],
148
+ )
149
+
150
+ # Indices
151
+ # RSI:
152
+ # https://www.investopedia.com/terms/r/rsi.asp
153
+ df['RSI'] = RSIIndicator(
154
+ df['Close'],
155
+ window=14).rsi()
156
+ # Stochastic Oscillator:
157
+ # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
158
+ df['StochOsc'] = StochasticOscillator(
159
+ df['High'],
160
+ df['Low'],
161
+ df['Close'],
162
+ window=14).stoch()
163
+
164
+ # Trend signals
165
+ # Moving Average Convergence Divergence (MACD):
166
+ # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
167
+ macd = MACD(
168
+ df['Close'],
169
+ window_slow=26,
170
+ window_fast=12,
171
+ window_sign=9)
172
+ df['MACD'] = macd.macd()
173
+ df['MACDsig'] = macd.macd_signal()
174
+ df['MACDdif'] = macd.macd_diff()
175
+
176
+ return df
177
+
178
+ def plot_stock_metrics(
179
+ self,
180
+ df,
181
+ datasets={
182
+ 'Volume': ['Volume'],
183
+ 'Price': ['Close'] # 'High','Low'
184
+ }
185
+ ) -> None:
186
+ """
187
+ Plots the given stock metrics datasets as subplots.
188
+ This method takes in a DataFrame and a dictionary of datasets, where
189
+ each key is a dataset name and the value is a list of column names.
190
+ The method creates a figure with subplots for each dataset and plots
191
+ the corresponding columns of the DataFrame.
192
+ The figure is then saved to a file in the 'plots' directory in png format
193
+ with the ticker symbol as the filename.
194
+ Args:
195
+ df (pd.DataFrame)
196
+ The DataFrame to plot
197
+ datasets (dict)
198
+ A dictionary of datasets, where each key is a dataset name and
199
+ the value is a list of column names to be plotted
200
+ """
201
+ numax = len(datasets)
202
+ fig, axes = plt.subplots(
203
+ nrows=numax,
204
+ ncols=1,
205
+ figsize=(10, 5*numax))
206
+ for i, ax in enumerate(axes.flat):
207
+ dataset = list(datasets.keys())[i]
208
+ colstoplot = datasets[dataset]
209
+ self.plot_stock_metrics_ax(
210
+ ax,
211
+ dataset,
212
+ df,
213
+ colstoplot)
214
+ plt.tight_layout()
215
+ plt.savefig(os.path.join('plots', f'{self.ticker}.png'))
216
+
217
+ def plot_stock_metrics_ax(
218
+ self,
219
+ ax:Axes,
220
+ dataset:str,
221
+ df:pd.DataFrame,
222
+ colstoplot:list) -> None:
223
+ """
224
+ Plots specified stock metrics on the provided Axes object.
225
+ This function takes in an Axes object and plots the specified columns from
226
+ the DataFrame `df` on it. It formats the x-axis with major ticks set to
227
+ every Monday and minor ticks set to every day. The plot includes a title,
228
+ x and y labels, and optionally a legend if more than one column is plotted.
229
+ Additional shaded regions are added for certain datasets.
230
+ Args:
231
+ ax (matplotlib.axes.Axes): The flattened axes to plot on.
232
+ dataset (str): The name of the dataset, used for the title and y-label.
233
+ df (pd.DataFrame): The DataFrame containing the data to plot.
234
+ colstoplot (list): A list of column names to plot from the DataFrame.
235
+ Note:
236
+ - If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100]
237
+ and a shaded region between y=30 and y=70 is added.
238
+ - If `dataset` is 'Price' or 'Prices', a shaded region between 'Low' and
239
+ 'High' columns is added.
240
+ """
241
+
242
+ print(f'plotting {colstoplot} in {dataset}')
243
+ colorcycle = ['black', 'blue', 'red', 'green', 'orange']
244
+ for i, col in enumerate(colstoplot):
245
+ ax.plot(
246
+ df.index,
247
+ df[col],
248
+ color=colorcycle[i],
249
+ label=col,
250
+ linewidth=2)
251
+ # Format major ticks with year
252
+ # Set major ticks (every Monday with labels)
253
+ ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
254
+ ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
255
+ # Set minor ticks (every day, but without labels)
256
+ ax.xaxis.set_minor_locator(mdates.DayLocator())
257
+
258
+ ax.set_title(dataset)
259
+ ax.set_xlabel('Date')
260
+ ax.set_ylabel(dataset)
261
+ if len(colstoplot) > 1:
262
+ ax.legend()
263
+ if dataset in ['Index', 'Indices']:
264
+ ax.set_ylim([0, 100])
265
+ # Add a transparent shaded region between y=30 and y=70
266
+ ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3)
267
+ if dataset in ['Price', 'Prices']:
268
+ # Add a transparent shaded region between y=30 and y=70
269
+ ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3)
270
+ ax.grid(True, linestyle='--', alpha=0.7)
271
+
272
+
273
+ if __name__ == '__main__':
274
+ ticker = 'GOOG'
275
+ ta = TechnicalAnalysis(ticker, debug=False).run()
276
+
277
+