File size: 12,451 Bytes
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac2093
f680f62
0975e8a
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac2093
5df6537
 
 
0975e8a
2ac2093
5df6537
 
f680f62
 
 
 
 
 
 
 
 
 
d730ea7
f680f62
 
2ac2093
 
 
 
 
0975e8a
2ac2093
 
 
 
 
0975e8a
 
 
 
 
 
d730ea7
 
 
 
0975e8a
 
d730ea7
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5df6537
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5df6537
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0975e8a
 
 
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
0975e8a
 
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0975e8a
6d0b28e
 
 
 
0975e8a
 
 
f680f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0975e8a
f680f62
 
 
 
 
 
 
 
 
 
 
 
d730ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f680f62
 
 
6d0b28e
5df6537
f680f62
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
import datetime as dt
import logging

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.axes import Axes
import os
import pandas as pd
from ta.volume import volume_weighted_average_price
from ta.momentum import RSIIndicator, StochasticOscillator
from ta.trend import MACD 
import yfinance as yf


class TechnicalAnalysis():
    """
    Fetch daily price data for one ticker from Yahoo Finance, add technical
    analysis indicator columns (VWAP, RSI, Stochastic Oscillator, MACD) and
    optionally plot them.
    """

    def __init__(
            self,
            ticker:str,
            fetchperiodinweeks:int=12,
            plot_ta:bool=True,
            savefig:bool=False,
            debug=False):
        """
        Initialize TechnicalAnalysis object.
        Args:
        ticker : str
            stock ticker to analyze
        fetchperiodinweeks : int, optional, default: 12
            number of weeks of daily history to fetch from YahooFinance
        plot_ta : bool, optional, default: True
            whether to generate plots of the price and technical analysis metrics
        savefig : bool, optional, default: False
            whether to save the plot to <parent of this file's directory>/plots/{ticker}.png
            (and close it) instead of returning the figure object
        debug : bool, optional, default: False
            whether to run in debug mode, so that logging is produced at debug level
        """
        # set up logging
        self.logger_level = logging.DEBUG if debug else logging.INFO
        self.logger = logging.getLogger(__name__)
        # basicConfig is a no-op once the root logger already has handlers, so the
        # level is also set on this module's logger; otherwise `debug` would only
        # take effect for the first instance created in a process.
        logging.basicConfig(level=self.logger_level)  # filename='TechnicalAnalysis.log',
        self.logger.setLevel(self.logger_level)

        # input arguments
        self.ticker = ticker
        self.fetchperiodinweeks = fetchperiodinweeks
        self.plot_ta = plot_ta
        self.savefig = savefig
        # done initializing
        self.logger.info(f'Initialized TechnicalAnalysis object for ticker: {ticker}')

    def run(
            self
            ) -> tuple:
        """
        Main entry point for the TechnicalAnalysis object.
        This method:
        - fetches historical data from YahooFinance,
        - computes the technical analysis metrics
        - plots the price and TA metrics.
        Returns:
        tuple (pd.DataFrame, matplotlib.figure.Figure or None)
            The price data with the TA columns added (empty if fetching failed)
            and the figure. The figure is None when plot_ta is False or when the
            figure was saved to disk; when fetching failed and plot_ta is True it
            is the error figure.
        """
        # fetch data from yf
        self.df = self.fetch_data()

        if self.df.empty:
            fig = self.get_fetcherror_fig(message='failed fetching data') if self.plot_ta else None
            return self.df, fig

        # add the features based on technical analysis
        self.df = self.tech_analysis()
        if not self.plot_ta:
            return self.df, None

        # plot the results (the output directory is created where the file is saved)
        fig = self.plot_stock_metrics(
            self.df,
            datasets={
                'Volume': ['Volume'],
                'Prices': ['Close', 'VWAP'],  # 'High','Low',
                'Indices': ['RSI', 'StochOsc'],
                'Trend': ['MACD', 'MACDsig', 'MACDdif']},
            savefig=self.savefig
            )
        return self.df, fig

    def fetch_data(
            self
            ) -> pd.DataFrame:
        """
        Fetches historical stock price data from Yahoo Finance.
        Downloads daily-interval data for `self.ticker` covering the last
        `self.fetchperiodinweeks` weeks. If the download succeeds and the columns
        carry a redundant 'Ticker' level, that level is dropped.
        Returns:
        pd.DataFrame
            The downloaded price data (columns as provided by yfinance, e.g.
            Open/High/Low/Close/Volume). If the download fails, the error is
            logged and an empty DataFrame is returned; no exception is raised.
        """
        # a single "now" so start and end are computed from the same instant
        period_end = dt.datetime.now()
        period_start = period_end - dt.timedelta(weeks=self.fetchperiodinweeks)
        self.logger.info(f'Fetching price data for {self.ticker}')
        try:
            df = yf.download(
                self.ticker,
                start=period_start,
                end=period_end,
                interval='1d'
            )
        except Exception as e:
            # network/service boundary: log and fall back to an empty frame so the
            # caller can show an error figure instead of crashing
            self.logger.error(f'Failed fetching price data for {self.ticker}: {e}')
            df = pd.DataFrame()

        if df is None:
            df = pd.DataFrame()

        if not df.empty:
            # yfinance may return (Price, Ticker) MultiIndex columns even for a single
            # ticker; only drop the Ticker level when it is actually there
            if isinstance(df.columns, pd.MultiIndex) and 'Ticker' in df.columns.names:
                df.columns = df.columns.droplevel('Ticker')
            self.logger.debug(df.head(10))
            self.logger.info(f'Fetched {df.shape[0]} rows for {self.ticker}')

        return df

    def tech_analysis(
            self
            ) -> pd.DataFrame:
        """
        Calculates technical analysis indicators for the fetched stock price data.
        Reads `self.df` (expects High, Low, Close and Volume columns) and adds:
        - Additional Price Indicators:
            - Volume-Weighted Average Price (VWAP)
        - Momentum Indicators:
            - Relative Strength Index (RSI), window 14
            - Stochastic Oscillator, window 14
        - Trend Indicators:
            - Moving Average Convergence Divergence (MACD) 12/26/9, with its signal
              line and the MACD-minus-signal difference
        The columns are added to `self.df` in place.
        Returns:
        pd.DataFrame
            The same DataFrame with the calculated technical analysis indicators.
        """
        df = self.df

        # Price Indicators
        # Volume-Weighted Average Price (VWAP)
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-overlays/volume-weighted-average-price-vwap
        df['VWAP'] = volume_weighted_average_price(
            high=df['High'],
            low=df['Low'],
            close=df['Close'],
            volume=df['Volume'],
        )

        # Indices
        # RSI:
        # https://www.investopedia.com/terms/r/rsi.asp
        df['RSI'] = RSIIndicator(
            df['Close'],
            window=14).rsi()
        # Stochastic Oscillator:
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/stochastic-oscillator-fast-slow-and-full
        df['StochOsc'] = StochasticOscillator(
            df['High'],
            df['Low'],
            df['Close'],
            window=14).stoch()

        # Trend signals
        # Moving Average Convergence Divergence (MACD):
        # https://chartschool.stockcharts.com/table-of-contents/technical-indicators-and-overlays/technical-indicators/macd-moving-average-convergence-divergence-oscillator
        macd = MACD(
            df['Close'],
            window_slow=26,
            window_fast=12,
            window_sign=9)
        df['MACD'] = macd.macd()
        df['MACDsig'] = macd.macd_signal()
        df['MACDdif'] = macd.macd_diff()

        return df

    def plot_stock_metrics(
            self,
            df,
            datasets=None,
            savefig=False
            ) -> 'plt.Figure | None':
        """
        Plots the given stock metrics datasets as stacked subplots.
        Creates one subplot per dataset (in dict order) and plots the listed
        columns of `df` on it via `plot_stock_metrics_ax`.
        Args:
        df (pd.DataFrame)
            The DataFrame to plot
        datasets (dict, optional)
            Maps a dataset name (subplot title / y-label) to the list of column
            names to plot. Defaults to {'Volume': ['Volume'], 'Price': ['Close']}.
        savefig (bool)
            If True, save the figure as <parent of this file's directory>/plots/{ticker}.png
            (creating the directory if needed), close it and return None.
        Returns:
        matplotlib.figure.Figure or None
            The figure, or None if it was saved and closed.
        """
        if datasets is None:
            datasets = {
                'Volume': ['Volume'],
                'Price': ['Close'],  # 'High','Low'
            }
        numax = len(datasets)
        # squeeze=False keeps `axes` a 2-D array even for a single subplot, so the
        # iteration below also works when only one dataset is requested
        fig, axes = plt.subplots(
            nrows=numax,
            ncols=1,
            figsize=(10, 5*numax),
            squeeze=False)
        for ax, (dataset, colstoplot) in zip(axes.flat, datasets.items()):
            self.plot_stock_metrics_ax(
                ax,
                dataset,
                df,
                colstoplot)
        fig.tight_layout()
        if savefig:
            rootdir = os.path.dirname(os.path.dirname(__file__))
            plotdir = os.path.join(rootdir, 'plots')
            # make sure the directory we actually save into exists
            os.makedirs(plotdir, exist_ok=True)
            fname = os.path.join(plotdir, f'{self.ticker}.png')
            fig.savefig(fname)
            self.logger.info(f'Saved figure into: {fname}')
            plt.close(fig)
            fig = None
        return fig

    def plot_stock_metrics_ax(
            self,
            ax:Axes,
            dataset:str,
            df:pd.DataFrame,
            colstoplot:list) -> None:
        """
        Plots specified stock metrics on the provided Axes object.
        Plots the columns `colstoplot` of `df` against the index. Major x-ticks
        are placed on every Monday (labelled '%b %d') and minor ticks on every
        day. Sets a title and y-label and adds a legend when more than one
        column is plotted.
        Args:
            ax (matplotlib.axes.Axes): The axes to plot on.
            dataset (str): The name of the dataset, used for the title and y-label.
            df (pd.DataFrame): The DataFrame containing the data to plot.
            colstoplot (list): A list of column names to plot from the DataFrame.
        Note:
            - If `dataset` is 'Index' or 'Indices', the y-axis is limited to [0, 100]
            and a shaded region between y=30 and y=70 is added.
            - If `dataset` is 'Price' or 'Prices' and `df` has 'Low' and 'High'
            columns, a shaded region between them is added.
        """
        self.logger.debug(f'plotting {colstoplot} in {dataset}')
        colorcycle = ['black', 'blue', 'red', 'green', 'orange']
        for i, col in enumerate(colstoplot):
            ax.plot(
                df.index,
                df[col],
                # wrap around so more than len(colorcycle) columns do not fail
                color=colorcycle[i % len(colorcycle)],
                label=col,
                linewidth=2)
        # Set major ticks (every Monday, labelled as month and day)
        ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.MO))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%b %d'))
        # Set minor ticks (every day, but without labels)
        ax.xaxis.set_minor_locator(mdates.DayLocator())

        ax.set_title(dataset)
        # ax.set_xlabel('Date')
        ax.set_ylabel(dataset)
        if len(colstoplot) > 1:
            ax.legend()
        if dataset in ('Index', 'Indices'):
            ax.set_ylim([0, 100])
            # Add a transparent shaded region between y=30 and y=70
            ax.fill_between(df.index, 30, 70, color='gray', alpha=0.3)
        if dataset in ('Price', 'Prices') and {'Low', 'High'}.issubset(df.columns):
            # Add a transparent shaded region between the daily Low and High
            ax.fill_between(df.index, df['Low'], df['High'], color='gray', alpha=0.3)
        ax.grid(True, linestyle='--', alpha=0.7)

    def get_fetcherror_fig(
            self,
            message
            ) -> plt.Figure:
        """
        Builds a figure that shows images/plot_error.png annotated with a message.
        The image path is <parent of this file's directory>/images/plot_error.png.
        If the image cannot be found, a warning is logged and the figure shows
        only the message.
        Args:
        message (str): message to be annotated on the displayed image
        Returns:
        plt.Figure: figure object containing the annotated image
        """
        fig, ax = plt.subplots(
            figsize=(5, 5)
        )

        # Load and display the image
        parentdir = os.path.dirname(os.path.dirname(__file__))
        fname = os.path.join(parentdir, 'images', 'plot_error.png')
        try:
            img = plt.imread(fname)
        except FileNotFoundError:
            # still return a usable figure carrying the message
            self.logger.warning(f'Error image not found: {fname}')
        else:
            ax.imshow(img)

        ax.axis('off')  # removes both ticks and axes lines completely

        ax.text(0.5, 0.05, message, fontsize=20, ha='center', va='center', transform=ax.transAxes)
        return fig

if __name__ == '__main__':
    # Demo run: analyse GOOG, save the chart to disk and show which columns were produced.
    analysis = TechnicalAnalysis('GOOG', plot_ta=True, savefig=True, debug=False)
    result_df, _ = analysis.run()
    print(f'columns: {result_df.columns}')