QuantumLearner commited on
Commit
32fc2f7
·
verified ·
1 Parent(s): b15ee7f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -0
app.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import yfinance as yf
3
+ import pandas as pd
4
+ import numpy as np
5
+ import plotly.graph_objects as go
6
+ from fastdtw import fastdtw
7
+ from scipy.spatial.distance import euclidean
8
+ import ta
9
+ from sklearn.decomposition import PCA
10
+ from sklearn.impute import SimpleImputer
11
+
12
+ # Function to normalize the time series
13
+ def normalize(ts):
14
+ return (ts - ts.min()) / (ts.max() - ts.min())
15
+
16
+ # Function to calculate DTW distance
17
+ def dtw_distance(ts1, ts2):
18
+ ts1_normalized = normalize(ts1)
19
+ ts2_normalized = normalize(ts2)
20
+ distance, _ = fastdtw(ts1_normalized.reshape(-1, 1), ts2_normalized.reshape(-1, 1), dist=euclidean)
21
+ return distance
22
+
23
+ # Function to calculate correlation
24
+ def correlation(ts1, ts2):
25
+ return np.corrcoef(ts1, ts2)[0, 1]
26
+
27
+ # Function to find most similar patterns using DTW
28
+ def find_most_similar_pattern_dtw(price_data_pct_change, n_days, subsequent_days):
29
+ current_window = price_data_pct_change[-n_days:].values
30
+
31
+ min_distances = [(float('inf'), -1), (float('inf'), -1), (float('inf'), -1)]
32
+ for start_index in range(len(price_data_pct_change) - 2 * n_days - subsequent_days):
33
+ past_window = price_data_pct_change[start_index:start_index + n_days].values
34
+ distance = dtw_distance(current_window, past_window)
35
+
36
+ for i, (min_distance, _) in enumerate(min_distances):
37
+ if distance < min_distance:
38
+ min_distances[i] = (distance, start_index)
39
+ break
40
+
41
+ return min_distances
42
+
43
+ # Function to find most similar patterns using correlation
44
+ def find_most_similar_pattern_corr(price_data_pct_change, n_days, subsequent_days, pre_days):
45
+ current_window = price_data_pct_change[-n_days:].values
46
+
47
+ max_correlations = [(float('-inf'), -1), (float('-inf'), -1), (float('-inf'), -1)]
48
+ for start_index in range(len(price_data_pct_change) - 2 * n_days - subsequent_days):
49
+ past_window = price_data_pct_change[start_index:start_index + n_days].values
50
+ corr = correlation(current_window, past_window)
51
+
52
+ for i, (max_corr, _) in enumerate(max_correlations):
53
+ if corr > max_corr:
54
+ max_correlations[i] = (corr, start_index)
55
+ break
56
+
57
+ return max_correlations
58
+
59
+ # Add technical analysis features
60
+ def add_ta_features(data):
61
+ data['trend_ichimoku_conv'] = ta.trend.ichimoku_a(data['High'], data['Low'])
62
+ data['trend_ema_slow'] = ta.trend.ema_indicator(data['Close'], 50)
63
+ data['momentum_kama'] = ta.momentum.kama(data['Close'])
64
+ data['trend_psar_up'] = ta.trend.psar_up(data['High'], data['Low'], data['Close'])
65
+ data['volume_vwap'] = ta.volume.VolumeWeightedAveragePrice(data['High'], data['Low'], data['Close'], data['Volume']).volume_weighted_average_price()
66
+ data['trend_ichimoku_a'] = ta.trend.ichimoku_a(data['High'], data['Low'])
67
+ data['volatility_kcl'] = ta.volatility.KeltnerChannel(data['High'], data['Low'], data['Close']).keltner_channel_lband()
68
+ data['trend_ichimoku_b'] = ta.trend.ichimoku_b(data['High'], data['Low'])
69
+ data['trend_ichimoku_base'] = ta.trend.ichimoku_base_line(data['High'], data['Low'])
70
+ data['trend_sma_fast'] = ta.trend.sma_indicator(data['Close'], 20)
71
+ data['volatility_dcm'] = ta.volatility.DonchianChannel(data['High'], data['Low'], data['Close']).donchian_channel_mband()
72
+ data['volatility_bbl'] = ta.volatility.BollingerBands(data['Close']).bollinger_lband()
73
+ data['volatility_bbm'] = ta.volatility.BollingerBands(data['Close']).bollinger_mavg()
74
+ data['volatility_kcc'] = ta.volatility.KeltnerChannel(data['High'], data['Low'], data['Close']).keltner_channel_mband()
75
+ data['volatility_kch'] = ta.volatility.KeltnerChannel(data['High'], data['Low'], data['Close']).keltner_channel_hband()
76
+ data['trend_sma_slow'] = ta.trend.sma_indicator(data['Close'], 200)
77
+ data['trend_ema_fast'] = ta.trend.ema_indicator(data['Close'], 20)
78
+ data['volatility_dch'] = ta.volatility.DonchianChannel(data['High'], data['Low'], data['Close']).donchian_channel_hband()
79
+ data['others_cr'] = ta.others.cumulative_return(data['Close'])
80
+ data['Adj Close'] = data['Close']
81
+
82
+ return data
83
+
84
+ def dtw_distance_with_ta(ts1, ts2, ts1_ta, ts2_ta, weight=0.8):
85
+ ts1_normalized = normalize(ts1)
86
+ ts2_normalized = normalize(ts2)
87
+ distance_pct_change, _ = fastdtw(ts1_normalized.reshape(-1, 1), ts2_normalized.reshape(-1, 1), dist=euclidean)
88
+
89
+ distance_ta, _ = fastdtw(ts1_ta, ts2_ta, dist=euclidean)
90
+ distance = weight * distance_pct_change + (1 - weight) * distance_ta
91
+ return distance
92
+
93
+ def extract_and_reduce_features(data, n_components=3):
94
+ ta_features = data.drop(columns=['Open', 'High', 'Low', 'Close', 'Volume', 'Adj Close'])
95
+
96
+ imputer = SimpleImputer(strategy='mean')
97
+ imputed_ta_features = imputer.fit_transform(ta_features)
98
+
99
+ pca = PCA(n_components=n_components)
100
+ reduced_features = pca.fit_transform(imputed_ta_features)
101
+ return reduced_features
102
+
103
+ # Streamlit app
104
+ st.set_page_config(page_title="Pattern Recognition", layout="wide")
105
+ st.title('Pattern Recognition in Stock Prices')
106
+
107
+ # Sidebar for method selection
108
+ selected = st.sidebar.radio("Select Method", ["DTW Pattern Recognition", "Correlation Pattern Recognition", "TA-Enhanced DTW Pattern Recognition"])
109
+
110
+ # Sidebar for input parameters
111
+ st.sidebar.header("Input Parameters")
112
+ ticker = st.sidebar.text_input('Enter Stock Ticker', 'ASML.AS')
113
+ start_date = st.sidebar.date_input('Start Date', pd.to_datetime('2000-01-01'))
114
+ end_date = st.sidebar.date_input('End Date', pd.to_datetime('2023-12-30'))
115
+ subsequent_days = st.sidebar.slider('Subsequent Days to Forecast', min_value=5, max_value=60, value=20, step=5)
116
+ n_days_options = st.sidebar.multiselect('Days to Compare', options=[15, 20, 30, 40, 50], default=[15, 20, 40])
117
+ pre_days = st.sidebar.slider('Days Prior to Similar Series', min_value=10, max_value=100, value=60, step=10)
118
+
119
+ if 'data' not in st.session_state:
120
+ st.session_state.data = None
121
+ if 'price_data_pct_change' not in st.session_state:
122
+ st.session_state.price_data_pct_change = None
123
+ if 'data_with_ta' not in st.session_state:
124
+ st.session_state.data_with_ta = None
125
+ if 'reduced_features' not in st.session_state:
126
+ st.session_state.reduced_features = None
127
+ if 'results_dtw' not in st.session_state:
128
+ st.session_state.results_dtw = None
129
+ if 'results_corr' not in st.session_state:
130
+ st.session_state.results_corr = None
131
+ if 'results_ta_dtw' not in st.session_state:
132
+ st.session_state.results_ta_dtw = None
133
+
134
+ def run_dtw():
135
+ min_distances = []
136
+ figs = []
137
+ for n_days in n_days_options:
138
+ min_distances.append(find_most_similar_pattern_dtw(st.session_state.price_data_pct_change, n_days, subsequent_days))
139
+
140
+ fig1 = go.Figure()
141
+ # Plot the entire stock price data
142
+ fig1.add_trace(go.Scatter(x=st.session_state.price_data.index, y=st.session_state.price_data, mode='lines', name='Overall stock price', line=dict(color='blue')))
143
+ colors = ['red', 'green', 'orange']
144
+ for i, (_, start_index) in enumerate(min_distances[-1]):
145
+ # Plot the pattern period
146
+ past_window_start_date = st.session_state.price_data.index[start_index]
147
+ past_window_end_date = st.session_state.price_data.index[start_index + n_days + subsequent_days]
148
+ fig1.add_trace(go.Scatter(x=st.session_state.price_data[past_window_start_date:past_window_end_date].index,
149
+ y=st.session_state.price_data[past_window_start_date:past_window_end_date],
150
+ mode='lines', name=f"Pattern {i + 1}", line=dict(color=colors[i % len(colors)])))
151
+
152
+ # Add labels and legend
153
+ fig1.update_layout(title=f'{ticker} Stock Price Data',
154
+ xaxis_title='Date',
155
+ yaxis_title='Stock Price',
156
+ legend_title='Legend')
157
+
158
+ reindexed_current_window = (st.session_state.price_data_pct_change[-n_days:] + 1).cumprod() * 100
159
+ fig2 = go.Figure()
160
+ for i, (_, start_index) in enumerate(min_distances[-1]):
161
+ past_window = st.session_state.price_data_pct_change[start_index:start_index + n_days + subsequent_days]
162
+ reindexed_past_window = (past_window + 1).cumprod() * 100
163
+ fig2.add_trace(go.Scatter(x=list(range(n_days + subsequent_days)), y=reindexed_past_window,
164
+ mode='lines', name=f"Past window {i + 1} (with subsequent {subsequent_days} days)",
165
+ line=dict(color=colors[i % len(colors)], width=3 if i == 0 else 1)))
166
+
167
+ fig2.add_trace(go.Scatter(x=list(range(n_days)), y=reindexed_current_window, mode='lines',
168
+ name="Current window", line=dict(color='black', width=3)))
169
+
170
+ fig2.update_layout(title=f"Most similar {n_days}-day patterns in {ticker} stock price history (aligned, reindexed)",
171
+ xaxis_title="Days",
172
+ yaxis_title="Reindexed Price",
173
+ legend_title="Legend")
174
+
175
+ figs.append((fig1, fig2))
176
+
177
+ st.session_state.results_dtw = figs
178
+
179
+ def run_corr():
180
+ max_correlations = []
181
+ figs = []
182
+ for n_days in n_days_options:
183
+ max_correlations.append(find_most_similar_pattern_corr(st.session_state.price_data_pct_change, n_days, subsequent_days, pre_days))
184
+
185
+ fig1 = go.Figure()
186
+ # Plot the entire stock price data
187
+ fig1.add_trace(go.Scatter(x=st.session_state.price_data.index, y=st.session_state.price_data, mode='lines', name='Overall stock price', line=dict(color='blue')))
188
+ colors = ['red', 'green', 'orange']
189
+ for i, (_, start_index) in enumerate(max_correlations[-1]):
190
+ # Plot the previous period
191
+ past_window_start_date = st.session_state.price_data.index[start_index - pre_days]
192
+ past_window_end_date = st.session_state.price_data.index[start_index + n_days + subsequent_days]
193
+ fig1.add_trace(go.Scatter(x=st.session_state.price_data[past_window_start_date:past_window_end_date].index,
194
+ y=st.session_state.price_data[past_window_start_date:past_window_end_date],
195
+ mode='lines', name=f"Pattern {i + 1}", line=dict(color=colors[i % len(colors)])))
196
+
197
+ # Add labels and legend
198
+ fig1.update_layout(title=f'{ticker} Stock Price Data',
199
+ xaxis_title='Date',
200
+ yaxis_title='Stock Price',
201
+ legend_title='Legend')
202
+
203
+ reindexed_current_window = (st.session_state.price_data_pct_change[-n_days:] + 1).cumprod() * 100
204
+ fig2 = go.Figure()
205
+ for i, (_, start_index) in enumerate(max_correlations[-1]):
206
+ past_window = st.session_state.price_data_pct_change[start_index:start_index + n_days + subsequent_days]
207
+ reindexed_past_window = (past_window + 1).cumprod() * 100
208
+ fig2.add_trace(go.Scatter(x=list(range(pre_days, pre_days + n_days + subsequent_days)), y=reindexed_past_window,
209
+ mode='lines', name=f"Past window {i + 1} (with subsequent {subsequent_days} days)",
210
+ line=dict(color=colors[i % len(colors)], width=3 if i == 0 else 1)))
211
+
212
+ fig2.add_trace(go.Scatter(x=list(range(pre_days, pre_days + n_days)), y=reindexed_current_window, mode='lines',
213
+ name="Current window", line=dict(color='black', width=3)))
214
+
215
+ fig2.update_layout(title=f"Most similar {n_days}-day patterns in {ticker} stock price history (aligned, reindexed with correlation)",
216
+ xaxis_title="Days",
217
+ yaxis_title="Reindexed Price",
218
+ legend_title="Legend")
219
+
220
+ figs.append((fig1, fig2))
221
+
222
+ st.session_state.results_corr = figs
223
+
224
+ def run_ta_dtw():
225
+ min_distance_indices = []
226
+ figs = []
227
+ for n_days in n_days_options:
228
+ current_window = st.session_state.price_data_pct_change[-n_days:].values
229
+ current_ta_window = st.session_state.reduced_features[-n_days:]
230
+
231
+ distances = [dtw_distance_with_ta(current_window, st.session_state.price_data_pct_change[start_index:start_index + n_days].values,
232
+ current_ta_window, st.session_state.reduced_features[start_index:start_index + n_days])
233
+ for start_index in range(len(st.session_state.price_data_pct_change) - 2 * n_days - subsequent_days)]
234
+
235
+ min_distance_indices.append(np.argsort(distances)[:3]) # find indices of 3 smallest distances
236
+
237
+ fig1 = go.Figure()
238
+ # Plot the entire stock price data
239
+ fig1.add_trace(go.Scatter(x=st.session_state.data.index, y=st.session_state.data['Close'], mode='lines', name='Overall stock price', line=dict(color='blue')))
240
+ colors = ['red', 'green', 'orange']
241
+ for i, start_index in enumerate(min_distance_indices[-1]):
242
+ # Plot the pattern period
243
+ past_window_start_date = st.session_state.data.index[start_index]
244
+ past_window_end_date = st.session_state.data.index[start_index + n_days + subsequent_days]
245
+ fig1.add_trace(go.Scatter(x=st.session_state.data['Close'][past_window_start_date:past_window_end_date].index,
246
+ y=st.session_state.data['Close'][past_window_start_date:past_window_end_date],
247
+ mode='lines', name=f"Pattern {i + 1}", line=dict(color=colors[i % len(colors)])))
248
+
249
+ # Add labels and legend
250
+ fig1.update_layout(title=f'{ticker} Stock Price Data',
251
+ xaxis_title='Date',
252
+ yaxis_title='Stock Price',
253
+ legend_title='Legend')
254
+
255
+ reindexed_current_window = (st.session_state.price_data_pct_change[-n_days:] + 1).cumprod() * 100
256
+ fig2 = go.Figure()
257
+ for i, start_index in enumerate(min_distance_indices[-1]):
258
+ past_window = st.session_state.price_data_pct_change[start_index:start_index + n_days + subsequent_days]
259
+ reindexed_past_window = (past_window + 1).cumprod() * 100
260
+ fig2.add_trace(go.Scatter(x=list(range(n_days + subsequent_days)), y=reindexed_past_window,
261
+ mode='lines', name=f"Past window {i + 1} (with subsequent {subsequent_days} days)",
262
+ line=dict(color=colors[i % len(colors)], width=3 if i == 0 else 1)))
263
+
264
+ fig2.add_trace(go.Scatter(x=list(range(n_days)), y=reindexed_current_window, mode='lines',
265
+ name="Current window", line=dict(color='black', width=3)))
266
+
267
+ fig2.update_layout(title=f"Most similar {n_days}-day patterns in {ticker} stock price history (aligned, reindexed)",
268
+ xaxis_title="Days",
269
+ yaxis_title="Reindexed Price",
270
+ legend_title="Legend")
271
+
272
+ figs.append((fig1, fig2))
273
+
274
+ st.session_state.results_ta_dtw = figs
275
+
276
+ if st.sidebar.button('Run'):
277
+ st.session_state.data = yf.download(ticker, start=start_date, end=end_date)
278
+ if not st.session_state.data.empty:
279
+ st.session_state.price_data = st.session_state.data['Close']
280
+ st.session_state.price_data_pct_change = st.session_state.price_data.pct_change().dropna()
281
+ st.session_state.data_with_ta = add_ta_features(st.session_state.data)
282
+ st.session_state.reduced_features = extract_and_reduce_features(st.session_state.data_with_ta, n_components=2)
283
+
284
+ if selected == "DTW Pattern Recognition":
285
+ run_dtw()
286
+ elif selected == "Correlation Pattern Recognition":
287
+ run_corr()
288
+ elif selected == "TA-Enhanced DTW Pattern Recognition":
289
+ run_ta_dtw()
290
+
291
+ # Display results and descriptions based on the selected method
292
+ if selected == "DTW Pattern Recognition":
293
+ st.markdown("""
294
+ ### DTW Pattern Recognition
295
+
296
+ This method uses Dynamic Time Warping (DTW) to find patterns in stock prices. DTW is an algorithm that measures similarity between two temporal sequences, which may vary in time or speed. By comparing the current pattern with historical data, it identifies periods in the past with the most similar patterns.
297
+
298
+ **How to use:**
299
+ 1. Enter the stock ticker, start date, and end date.
300
+ 2. Select the number of subsequent days to forecast.
301
+ 3. Select the number of days to compare.
302
+ 4. Click the 'Run' button.
303
+
304
+ **Results:**
305
+ The left chart shows the entire stock price data with the identified patterns highlighted. The right chart shows the reindexed price patterns for comparison.
306
+ """)
307
+ if st.session_state.results_dtw:
308
+ for fig1, fig2 in st.session_state.results_dtw:
309
+ col1, col2 = st.columns(2)
310
+ with col1:
311
+ st.plotly_chart(fig1)
312
+ with col2:
313
+ st.plotly_chart(fig2)
314
+
315
+ elif selected == "Correlation Pattern Recognition":
316
+ st.markdown("""
317
+ ### Correlation Pattern Recognition
318
+
319
+ This method calculates the correlation between the current stock price pattern and historical patterns. Correlation measures how closely two time series move together. Higher correlation values indicate more similar patterns.
320
+
321
+ **How to use:**
322
+ 1. Enter the stock ticker, start date, and end date.
323
+ 2. Select the number of subsequent days to forecast.
324
+ 3. Select the number of days to compare.
325
+ 4. Select the number of days prior to the similar series.
326
+ 5. Click the 'Run' button.
327
+
328
+ **Results:**
329
+ The left chart shows the entire stock price data with the identified patterns highlighted. The right chart shows the reindexed price patterns for comparison.
330
+ """)
331
+ if st.session_state.results_corr:
332
+ for fig1, fig2 in st.session_state.results_corr:
333
+ col1, col2 = st.columns(2)
334
+ with col1:
335
+ st.plotly_chart(fig1)
336
+ with col2:
337
+ st.plotly_chart(fig2)
338
+
339
+ elif selected == "TA-Enhanced DTW Pattern Recognition":
340
+ st.markdown("""
341
+ ### TA-Enhanced DTW Pattern Recognition
342
+
343
+ This method combines technical analysis (TA) features with DTW to enhance pattern recognition. It integrates various TA indicators into the time series data and uses DTW to find the most similar historical patterns, providing a more comprehensive analysis.
344
+
345
+ **How to use:**
346
+ 1. Enter the stock ticker, start date, and end date.
347
+ 2. Select the number of subsequent days to forecast.
348
+ 3. Select the number of days to compare.
349
+ 4. Click the 'Run' button.
350
+
351
+ **Results:**
352
+ The left chart shows the entire stock price data with the identified patterns highlighted. The right chart shows the reindexed price patterns for comparison.
353
+ """)
354
+ if st.session_state.results_ta_dtw:
355
+ for fig1, fig2 in st.session_state.results_ta_dtw:
356
+ col1, col2 = st.columns(2)
357
+ with col1:
358
+ st.plotly_chart(fig1)
359
+ with col2:
360
+ st.plotly_chart(fig2)