yasirapunsith committed on
Commit
51c6c3d
·
1 Parent(s): 0c285af

add files

Browse files
handler.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # handler.py
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ import math
6
+ from joblib import Parallel, delayed
7
+ from sklearn.cluster import DBSCAN
8
+ import os # For accessing model path
9
+
10
+ # Import your utility functions
11
+ # Make sure your utils directory is alongside handler.py
12
+ # and contains __init__.py, eval.py, formatAndPreprocessNewPatterns.py
13
+ from utils.eval import intersection_over_union
14
+ from utils.formatAndPreprocessNewPatterns import get_patetrn_name_by_encoding, get_pattern_encoding_by_name, get_reverse_pattern_encoding
15
+
16
+ # --- Global Model Loading (Crucial for performance) ---
17
+ # This model will be loaded ONLY ONCE when the server starts.
18
+ # Ensure the path is correct relative to where handler.py runs in the container.
19
+ # The `MODEL_DIR` env var is automatically set by Inference Endpoints.
20
+ # If you place 'Models/' directly in your repo root, it will be at /repository/Models/
21
+ # If you place it outside (not recommended), you'd need to adjust paths.
22
+ # For simplicity, assume `Models/` is in the root of your HF repo.
23
+ MODEL_PATH = os.path.join(os.environ.get("MODEL_DIR", "."), "Models", "Width Aug OHLC_mini_rocket_xgb.joblib")
24
+
25
+ # Load the model globally
26
+ try:
27
+ print(f"Loading model from: {MODEL_PATH}")
28
+ rocket_model = joblib.load(MODEL_PATH)
29
+ print("Model loaded successfully!")
30
+ except Exception as e:
31
+ print(f"Error loading model: {e}")
32
+ # In a real scenario, you might want to raise an exception to prevent the server from starting
33
+ rocket_model = None
34
+
35
+ # --- Helper functions (from your provided code) ---
36
+ # Paste your `process_window`, `parallel_process_sliding_window`,
37
+ # `prepare_dataset_for_cluster`, `cluster_windows` here.
38
+ # Make sure they are defined before `locate_patterns`
39
+ # because locate_patterns depends on them.
40
+
41
+ # Make sure these globals are outside functions if they are truly global constants
42
+ pattern_encoding_reversed = get_reverse_pattern_encoding()
43
+ # model is now `rocket_model` loaded globally
44
+ # plot_count is handled by the API input now
45
+ win_size_proportions = np.round(np.logspace(0, np.log10(20), num=10), 2).tolist()
46
+ padding_proportion = 0.6
47
+ stride = 1
48
+ probab_threshold_list = 0.5
49
+ prob_threshold_of_no_pattern_to_mark_as_no_pattern = 0.5
50
+ target_len = 30 # Not used in your current code
51
+
52
+ eps=0.04
53
+ min_samples=3
54
+ win_width_proportion=10 # Not used in your current code
55
+
56
+
57
+ def process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed,seg_start, seg_end, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1):
58
+ start_index = i - math.ceil(window_size * padding_proportion)
59
+ end_index = start_index + window_size
60
+
61
+ start_index = max(start_index, 0)
62
+ end_index = min(end_index, len(ohlc_data_segment))
63
+
64
+ ohlc_segment = ohlc_data_segment[start_index:end_index]
65
+ if len(ohlc_segment) == 0:
66
+ return None # Skip empty segments
67
+ win_start_date = ohlc_segment['Date'].iloc[0]
68
+ win_end_date = ohlc_segment['Date'].iloc[-1]
69
+
70
+ ohlc_array_for_rocket = ohlc_segment[['Open', 'High', 'Low', 'Close','Volume']].to_numpy().reshape(1, len(ohlc_segment), 5)
71
+ ohlc_array_for_rocket = np.transpose(ohlc_array_for_rocket, (0, 2, 1))
72
+ try:
73
+ pattern_probabilities = rocket_model.predict_proba(ohlc_array_for_rocket)
74
+ except Exception as e:
75
+ print(f"Error in prediction: {e}")
76
+ return None
77
+ max_probability = np.max(pattern_probabilities)
78
+ no_pattern_proba = pattern_probabilities[0][get_pattern_encoding_by_name ('No Pattern')]
79
+ pattern_index = np.argmax(pattern_probabilities)
80
+
81
+ pred_proba = max_probability
82
+ pred_pattern = get_patetrn_name_by_encoding(pattern_index)
83
+ if no_pattern_proba > prob_threshold_of_no_pattern_to_mark_as_no_pattern:
84
+ pred_proba = no_pattern_proba
85
+ pred_pattern = 'No Pattern'
86
+
87
+ new_row = {
88
+ 'Start': win_start_date, 'End': win_end_date, 'Chart Pattern': pred_pattern, 'Seg_Start': seg_start, 'Seg_End': seg_end ,
89
+ 'Probability': pred_proba
90
+ }
91
+ return new_row
92
+
93
+
94
+ def parallel_process_sliding_window(ohlc_data_segment, rocket_model, probability_threshold, stride, pattern_encoding_reversed, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1,parallel=True,num_cores=-1):
95
+ seg_start = ohlc_data_segment['Date'].iloc[0]
96
+ seg_end = ohlc_data_segment['Date'].iloc[-1]
97
+
98
+ # Render.com's worker environment for the HF endpoint will have limited cores for single instances.
99
+ # Parallel processing (`joblib.Parallel`) within the *single* HF endpoint worker
100
+ # might not yield significant benefits or might even cause issues if not configured carefully.
101
+ # It's generally better to rely on HF's scaling for multiple requests.
102
+ # Consider setting `parallel=False` or `num_cores=1` for initial deployment if you hit issues.
103
+ # For now, let's keep it as is, but be mindful of resource constraints.
104
+
105
+ if parallel:
106
+ with Parallel(n_jobs=num_cores, verbose=0) as parallel: # verbose=0 to reduce log spam
107
+ results = parallel(
108
+ delayed(process_window)(
109
+ i=i,
110
+ ohlc_data_segment=ohlc_data_segment,
111
+ rocket_model=rocket_model,
112
+ probability_threshold=probability_threshold,
113
+ pattern_encoding_reversed=pattern_encoding_reversed,
114
+ window_size=window_size,
115
+ seg_start=seg_start,
116
+ seg_end=seg_end,
117
+ padding_proportion=padding_proportion,
118
+ prob_threshold_of_no_pattern_to_mark_as_no_pattern=prob_threshold_of_no_pattern_to_mark_as_no_pattern
119
+ )
120
+ for i in range(0, len(ohlc_data_segment), stride)
121
+ )
122
+ return pd.DataFrame([res for res in results if res is not None])
123
+ else:
124
+ results = []
125
+ for i_idx, i in enumerate(range(0, len(ohlc_data_segment), stride)):
126
+ res = process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed, seg_start, seg_end, window_size, padding_proportion)
127
+ if res is not None:
128
+ results.append(res)
129
+ return pd.DataFrame(results)
130
+
131
+ def prepare_dataset_for_cluster(ohlc_data_segment, win_results_df):
132
+ predicted_patterns = win_results_df.copy()
133
+ # origin_date = ohlc_data_segment['Date'].min() # Not used
134
+ for index, row in predicted_patterns.iterrows():
135
+ pattern_start = row['Start']
136
+ pattern_end = row['End']
137
+ start_point_index = len(ohlc_data_segment[ohlc_data_segment['Date'] < pattern_start])
138
+ pattern_len = len(ohlc_data_segment[(ohlc_data_segment['Date'] >= pattern_start) & (ohlc_data_segment['Date'] <= pattern_end)])
139
+ pattern_mid_index = start_point_index + (pattern_len / 2)
140
+ predicted_patterns.at[index, 'Center'] = pattern_mid_index
141
+ predicted_patterns.at[index, 'Pattern_Start_pos'] = start_point_index
142
+ predicted_patterns.at[index, 'Pattern_End_pos'] = start_point_index + pattern_len
143
+ return predicted_patterns
144
+
145
+ def cluster_windows(predicted_patterns , probability_threshold, window_size,eps = 0.05 , min_samples = 2):
146
+ df = predicted_patterns.copy()
147
+
148
+ if isinstance(probability_threshold, list):
149
+ for i in range(len(probability_threshold)):
150
+ pattern_name = get_patetrn_name_by_encoding(i)
151
+ df.drop(df[(df['Chart Pattern'] == pattern_name) & (df['Probability'] < probability_threshold[i])].index, inplace=True)
152
+ else:
153
+ df = df[df['Probability'] > probability_threshold]
154
+
155
+ cluster_labled_windows = []
156
+ interseced_clusters = []
157
+
158
+ if df.empty: # Handle case where df might be empty after filtering
159
+ return None, None
160
+
161
+ min_center = df['Center'].min()
162
+ max_center = df['Center'].max()
163
+
164
+ for pattern, group in df.groupby('Chart Pattern'):
165
+ centers = group['Center'].values.reshape(-1, 1)
166
+
167
+ if min_center < max_center:
168
+ norm_centers = (centers - min_center) / (max_center - min_center)
169
+ else:
170
+ norm_centers = np.ones_like(centers)
171
+
172
+ db = DBSCAN(eps=eps, min_samples=min_samples).fit(norm_centers)
173
+ group['Cluster'] = db.labels_
174
+ cluster_labled_windows.append(group)
175
+
176
+ for cluster_id, cluster_group in group[group['Cluster'] != -1].groupby('Cluster'):
177
+ expanded_dates = []
178
+ for _, row in cluster_group.iterrows():
179
+ dates = pd.date_range(row["Start"], row["End"])
180
+ expanded_dates.extend(dates)
181
+
182
+ date_counts = pd.Series(expanded_dates).value_counts().sort_index()
183
+ cluster_start = date_counts[date_counts >= 2].index.min()
184
+ cluster_end = date_counts[date_counts >= 2].index.max()
185
+
186
+ interseced_clusters.append({
187
+ 'Chart Pattern': pattern,
188
+ 'Cluster': cluster_id,
189
+ 'Start': cluster_start,
190
+ 'End': cluster_end,
191
+ 'Seg_Start': cluster_group['Seg_Start'].iloc[0],
192
+ 'Seg_End': cluster_group['Seg_End'].iloc[0],
193
+ 'Avg_Probability': cluster_group['Probability'].mean(),
194
+ })
195
+
196
+ if len(cluster_labled_windows) == 0 or len(interseced_clusters) == 0:
197
+ return None, None
198
+
199
+ cluster_labled_windows_df = pd.concat(cluster_labled_windows)
200
+ interseced_clusters_df = pd.DataFrame(interseced_clusters)
201
+ cluster_labled_windows_df = cluster_labled_windows_df.sort_index()
202
+ return cluster_labled_windows_df, interseced_clusters_df
203
+
204
+
205
+ # ========================= locate_patterns function ==========================
206
+
207
+ # This will be your primary inference function called by the HF endpoint.
208
+ class InferenceHandler:
209
+ def __init__(self):
210
+ # Model is loaded globally, so it's accessible here
211
+ self.model = rocket_model
212
+ if self.model is None:
213
+ raise ValueError("ML model failed to load during initialization.")
214
+
215
+ # Initialize other global parameters here as well
216
+ self.pattern_encoding_reversed = pattern_encoding_reversed
217
+ self.win_size_proportions = win_size_proportions
218
+ self.padding_proportion = padding_proportion
219
+ self.stride = stride
220
+ self.probab_threshold_list = probab_threshold_list
221
+ self.prob_threshold_of_no_pattern_to_mark_as_no_pattern = prob_threshold_of_no_pattern_to_mark_as_no_pattern
222
+ self.eps = eps
223
+ self.min_samples = min_samples
224
+
225
+ def __call__(self, inputs):
226
+ """
227
+ Main inference method for the Hugging Face Inference Endpoint.
228
+ Args:
229
+ inputs: A dictionary or list of dictionaries representing the input data.
230
+ For your case, this will be the OHLC data sent from Django.
231
+ Expected format: [{"Date": "YYYY-MM-DD", "Open": ..., "High": ..., ...}, ...]
232
+ Returns:
233
+ A list of dictionaries representing the detected patterns.
234
+ """
235
+ if not self.model:
236
+ raise ValueError("ML model is not loaded. Cannot perform inference.")
237
+
238
+ # Ensure inputs is a list of dictionaries if not already
239
+ if isinstance(inputs, dict):
240
+ inputs = [inputs] # Handle single input dict if needed
241
+
242
+ # Convert input (list of dicts) to pandas DataFrame
243
+ try:
244
+ ohlc_data = pd.DataFrame(inputs)
245
+ # Ensure 'Date' is datetime, it might come as string from JSON
246
+ ohlc_data['Date'] = pd.to_datetime(ohlc_data['Date'])
247
+ # Ensure proper columns exist
248
+ required_cols = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
249
+ if not all(col in ohlc_data.columns for col in required_cols):
250
+ raise ValueError(f"Missing required columns in input data. Expected: {required_cols}, Got: {ohlc_data.columns.tolist()}")
251
+
252
+ except Exception as e:
253
+ print(f"Error processing input data: {e}")
254
+ raise ValueError(f"Invalid input data format: {e}")
255
+
256
+
257
+ ohlc_data_segment = ohlc_data.copy()
258
+ seg_len = len(ohlc_data_segment)
259
+
260
+ if ohlc_data_segment.empty:
261
+ raise ValueError("OHLC Data segment is empty or invalid after processing.")
262
+
263
+ win_results_for_each_size = []
264
+ located_patterns_and_other_info_for_each_size = []
265
+ cluster_labled_windows_list = []
266
+
267
+ used_win_sizes = []
268
+ win_iteration = 0
269
+
270
+ for win_size_proportion in self.win_size_proportions:
271
+ window_size = seg_len // win_size_proportion
272
+ if window_size < 10:
273
+ window_size = 10
274
+ window_size = int(window_size)
275
+ if window_size in used_win_sizes:
276
+ continue
277
+ used_win_sizes.append(window_size)
278
+
279
+ # Pass the globally loaded model `self.model`
280
+ win_results_df = parallel_process_sliding_window(
281
+ ohlc_data_segment,
282
+ self.model,
283
+ self.probab_threshold_list,
284
+ self.stride,
285
+ self.pattern_encoding_reversed,
286
+ window_size,
287
+ self.padding_proportion,
288
+ self.prob_threshold_of_no_pattern_to_mark_as_no_pattern,
289
+ parallel=True, # You might want to test with False/num_cores=1 on HF to avoid internal parallelism issues
290
+ num_cores=-1 # -1 means all available cores; on HF, this will be limited by the instance type
291
+ )
292
+
293
+ if win_results_df is None or win_results_df.empty:
294
+ print(f"Window results dataframe is empty for window size {window_size}")
295
+ continue
296
+ win_results_df['Window_Size'] = window_size
297
+ win_results_for_each_size.append(win_results_df)
298
+
299
+ predicted_patterns = prepare_dataset_for_cluster(ohlc_data_segment, win_results_df)
300
+ if predicted_patterns is None or predicted_patterns.empty:
301
+ print("Predicted patterns dataframe is empty")
302
+ continue
303
+
304
+ # Pass eps and min_samples from handler's state
305
+ cluster_labled_windows_df , interseced_clusters_df = cluster_windows(
306
+ predicted_patterns,
307
+ self.probab_threshold_list,
308
+ window_size,
309
+ eps=self.eps,
310
+ min_samples=self.min_samples
311
+ )
312
+
313
+ if cluster_labled_windows_df is None or interseced_clusters_df is None or cluster_labled_windows_df.empty or interseced_clusters_df.empty:
314
+ print("Clustered windows dataframe is empty")
315
+ continue
316
+ mask = cluster_labled_windows_df['Cluster'] != -1
317
+ cluster_labled_windows_df.loc[mask, 'Cluster'] = cluster_labled_windows_df.loc[mask, 'Cluster'].astype(int) + win_iteration
318
+ interseced_clusters_df['Cluster'] = interseced_clusters_df['Cluster'].astype(int) + win_iteration
319
+ num_of_unique_clusters = interseced_clusters_df[interseced_clusters_df['Cluster']!=-1]['Cluster'].nunique()
320
+ win_iteration += num_of_unique_clusters
321
+ cluster_labled_windows_list.append(cluster_labled_windows_df)
322
+
323
+ interseced_clusters_df['Calc_Start'] = interseced_clusters_df['Start']
324
+ interseced_clusters_df['Calc_End'] = interseced_clusters_df['End']
325
+ located_patterns_and_other_info = interseced_clusters_df.copy()
326
+
327
+ if located_patterns_and_other_info is None or located_patterns_and_other_info.empty:
328
+ print("Located patterns and other info dataframe is empty")
329
+ continue
330
+ located_patterns_and_other_info['Window_Size'] = window_size
331
+
332
+ located_patterns_and_other_info_for_each_size.append(located_patterns_and_other_info)
333
+
334
+ if located_patterns_and_other_info_for_each_size is None or not located_patterns_and_other_info_for_each_size:
335
+ print("Located patterns and other info for each size is empty")
336
+ return [] # Return empty list if no patterns found
337
+
338
+ located_patterns_and_other_info_for_each_size_df = pd.concat(located_patterns_and_other_info_for_each_size)
339
+
340
+ unique_window_sizes = located_patterns_and_other_info_for_each_size_df['Window_Size'].unique()
341
+ unique_patterns = located_patterns_and_other_info_for_each_size_df['Chart Pattern'].unique()
342
+ unique_window_sizes = np.sort(unique_window_sizes)[::-1]
343
+
344
+ filtered_loc_pat_and_info_rows_list = []
345
+
346
+ for chart_pattern in unique_patterns:
347
+ located_patterns_and_other_info_for_each_size_df_chart_pattern = located_patterns_and_other_info_for_each_size_df[located_patterns_and_other_info_for_each_size_df['Chart Pattern'] == chart_pattern]
348
+ for win_size in unique_window_sizes:
349
+ located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern = located_patterns_and_other_info_for_each_size_df_chart_pattern[located_patterns_and_other_info_for_each_size_df_chart_pattern['Window_Size'] == win_size]
350
+ for idx , row in located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern.iterrows():
351
+ start_date = row['Calc_Start']
352
+ end_date = row['Calc_End']
353
+ is_already_included = False
354
+ intersecting_rows = located_patterns_and_other_info_for_each_size_df_chart_pattern[
355
+ (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_Start'] <= end_date) &
356
+ (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_End'] >= start_date)
357
+ ]
358
+ is_already_included = False
359
+ for idx2, row2 in intersecting_rows.iterrows():
360
+ iou = intersection_over_union(start_date, end_date, row2['Calc_Start'], row2['Calc_End'])
361
+
362
+ if iou > 0.6:
363
+ if row2['Window_Size'] > row['Window_Size']:
364
+ if (row['Avg_Probability'] - row2['Avg_Probability']) > 0.1:
365
+ is_already_included = False
366
+ else:
367
+ is_already_included = True
368
+ break
369
+ elif row['Window_Size'] >= row2['Window_Size']:
370
+ if (row2['Avg_Probability'] - row['Avg_Probability']) > 0.1:
371
+ is_already_included = True
372
+ break
373
+ else:
374
+ is_already_included = False
375
+
376
+ if not is_already_included:
377
+ filtered_loc_pat_and_info_rows_list.append(row)
378
+
379
+ filtered_loc_pat_and_info_df = pd.DataFrame(filtered_loc_pat_and_info_rows_list)
380
+
381
+ # Convert datetime columns to string format for serialization before returning
382
+ datetime_columns = ['Start', 'End', 'Seg_Start', 'Seg_End', 'Calc_Start', 'Calc_End']
383
+ for col in datetime_columns:
384
+ if col in filtered_loc_pat_and_info_df.columns:
385
+ if pd.api.types.is_datetime64_any_dtype(filtered_loc_pat_and_info_df[col]):
386
+ filtered_loc_pat_and_info_df[col] = pd.to_datetime(filtered_loc_pat_and_info_df[col]).dt.strftime('%Y-%m-%d')
387
+ elif not filtered_loc_pat_and_info_df[col].empty and isinstance(filtered_loc_pat_and_info_df[col].iloc[0], str):
388
+ pass
389
+ else:
390
+ filtered_loc_pat_and_info_df[col] = filtered_loc_pat_and_info_df[col].astype(str)
391
+
392
+ # Return as a list of dictionaries (JSON serializable)
393
+ return filtered_loc_pat_and_info_df.to_dict('records')
requirements.txt ADDED
Binary file (2.71 kB). View file
 
utils/__init__.py ADDED
File without changes
utils/eval.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ # from matplotlib import pyplot as plt
3
+ # from matplotlib.gridspec import GridSpec
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+
8
+ def intersection_over_union(start1, end1, start2, end2):
9
+ """
10
+ Compute Intersection over Union (IoU) between two date ranges.
11
+ """
12
+ latest_start = max(start1, start2)
13
+ earliest_end = min(end1, end2)
14
+ overlap = max(0, (earliest_end - latest_start).days + 1)
15
+ union = (end1 - start1).days + (end2 - start2).days + 2 - overlap
16
+ return overlap / union if union > 0 else 0 # Avoid division by zero
17
+
18
+ def mean_abselute_error(start1, end1, start2, end2):
19
+ """
20
+ Compute Mean Absolute Error (MAE) between two date ranges.
21
+ """
22
+ # check if start or end are NAT
23
+ if start1 is pd.NaT or end1 is pd.NaT or start2 is pd.NaT or end2 is pd.NaT:
24
+ print("One of the dates is NaT")
25
+ print(f"start1: {start1}, end1: {end1}, start2: {start2}, end2: {end2}")
26
+ return None
27
+ return (abs(start1 - start2).days + abs(end1 - end2).days) / 2
28
+
29
+
30
+ def get_model_eval_res(located_patterns_and_other_info_updated_dict,window_results_dict,selected_models,selected_test_patterns_without_no_pattern):
31
+ model_eval_results_dict = {}
32
+ for model_name in selected_models:
33
+ print(f"\n Selected model: {model_name}")
34
+
35
+ located_patterns_and_other_info_updated_df = located_patterns_and_other_info_updated_dict[model_name]
36
+ window_results_df = window_results_dict[model_name]
37
+
38
+ # dictionary to store the count of properly located patterns , iou and mae for each properly detected pattern for each model
39
+
40
+
41
+ # Dictionary to store the count of properly located patterns
42
+ number_of_properly_located_patterns = {}
43
+ iou_for_each_properly_detected_pattern = {}
44
+ mae_for_each_properly_detected_pattern = {}
45
+
46
+ # Convert date columns to datetime (once, outside the loop for efficiency)
47
+ located_patterns_and_other_info_updated_df['Calc_Start'] = pd.to_datetime(located_patterns_and_other_info_updated_df['Calc_Start'])
48
+ located_patterns_and_other_info_updated_df['Calc_End'] = pd.to_datetime(located_patterns_and_other_info_updated_df['Calc_End'])
49
+
50
+ # Iterate over test patterns with progress bar
51
+ for index, row in selected_test_patterns_without_no_pattern.iterrows():
52
+ sys.stdout.write(f"\rProcessing row {index + 1}/{len(selected_test_patterns_without_no_pattern)}")
53
+ sys.stdout.flush()
54
+ symbol = row['Symbol']
55
+ chart_pattern = row['Chart Pattern']
56
+ start_date = pd.to_datetime(row['Start']).tz_localize(None)
57
+ end_date = pd.to_datetime(row['End']).tz_localize(None)
58
+
59
+ # Filter for matching symbol and chart pattern
60
+ located_patterns_for_this = located_patterns_and_other_info_updated_df[
61
+ (located_patterns_and_other_info_updated_df['Symbol'] == symbol) &
62
+ (located_patterns_and_other_info_updated_df['Chart Pattern'] == chart_pattern)
63
+ ].copy() # Use `.copy()` to avoid SettingWithCopyWarning
64
+
65
+ if located_patterns_for_this.empty:
66
+ continue # Skip if no matching rows
67
+
68
+ # Compute IoU for each row using .loc to avoid warnings
69
+ located_patterns_for_this.loc[:, 'IoU'] = located_patterns_for_this.apply(
70
+ lambda x: intersection_over_union(start_date, end_date, x['Calc_Start'], x['Calc_End']),
71
+ axis=1
72
+ )
73
+
74
+ # Compute MAE for each row using .loc to avoid warnings
75
+ located_patterns_for_this.loc[:, 'MAE'] = located_patterns_for_this.apply(
76
+ lambda x: mean_abselute_error(start_date, end_date, x['Calc_Start'], x['Calc_End']),
77
+ axis=1
78
+ )
79
+
80
+
81
+ # Filter based on IoU threshold (≥ 0.8)
82
+ located_patterns_for_this_proper = located_patterns_for_this[located_patterns_for_this['IoU'] >= 0.25]
83
+
84
+ if not located_patterns_for_this_proper.empty:
85
+ number_of_properly_located_patterns[chart_pattern] = number_of_properly_located_patterns.get(chart_pattern, 0) + 1
86
+ iou_for_each_properly_detected_pattern[chart_pattern] = iou_for_each_properly_detected_pattern.get(chart_pattern, 0) + max(located_patterns_for_this_proper['IoU'])
87
+ mae_for_each_properly_detected_pattern[chart_pattern] = mae_for_each_properly_detected_pattern.get(chart_pattern, 0) + min(located_patterns_for_this_proper['MAE'])
88
+
89
+ number_of_properly_located_patterns
90
+
91
+ model_eval_results_dict[model_name] = {
92
+ 'number_of_properly_located_patterns': number_of_properly_located_patterns,
93
+ 'iou_for_each_properly_detected_pattern': iou_for_each_properly_detected_pattern,
94
+ 'mae_for_each_properly_detected_pattern': mae_for_each_properly_detected_pattern
95
+ }
96
+ return model_eval_results_dict
97
+
98
+ ############################################################################################
99
+ # Evaluate multiple models and plot
100
+ ############################################################################################
101
+ # Commenting out plotting functions
102
+ """
103
+ def create_comprehensive_model_comparison(all_models_metrics):
104
+
105
+ Create a comprehensive visualization comparing all models across all metrics,
106
+ using nested concentric pie charts for Precision and Recall.
107
+
108
+ Parameters:
109
+ -----------
110
+ all_models_metrics : dict
111
+ Dictionary containing metrics for each model
112
+
113
+ models = list(all_models_metrics.keys())
114
+ n_models = len(models)
115
+
116
+ # Define the metrics to include
117
+ key_metrics = {
118
+ 'total_recall': 'Recall',
119
+ 'total_precision': 'Precision',
120
+ 'overall_f1': 'F1 Score',
121
+ 'overall_iou': 'IoU',
122
+ 'overall_mae': 'MAE'
123
+ }
124
+
125
+ # Create figure with GridSpec for flexible layout
126
+ fig = plt.figure(figsize=(20, 14))
127
+
128
+ # Add main title with enough space for legend below it
129
+ plt.suptitle('Comprehensive Model Evaluation', fontsize=16, y=0.98)
130
+
131
+ # Define a color palette for models
132
+ colors = plt.cm.tab10(np.linspace(0, 1, n_models))
133
+
134
+ # Create a master legend below the title
135
+ legend_handles = [plt.Line2D([0], [0], color=colors[i], lw=4, label=model) for i, model in enumerate(models)]
136
+ fig.legend(
137
+ handles=legend_handles,
138
+ labels=models,
139
+ loc='upper center',
140
+ bbox_to_anchor=(0.5, 0.93), # Moved down from 0.98 to 0.93
141
+ ncol=n_models,
142
+ fontsize=12
143
+ )
144
+
145
+ # Adjust GridSpec to account for the title and legend
146
+ gs = GridSpec(3, 3, figure=fig, height_ratios=[1.2, 1.2, 1], top=0.88) # Reduced top from 0.95 to 0.88
147
+
148
+
149
+
150
+ # 1. Precision Nested Pie Chart - top left
151
+ ax1 = fig.add_subplot(gs[0, 0])
152
+
153
+ # Create a multi-layer nested pie chart for precision
154
+ # Each ring represents a different model
155
+ precision_values = [metrics['total_precision'] for metrics in all_models_metrics.values()]
156
+
157
+ # Calculate radii for each ring (outermost ring is largest)
158
+ radii = np.linspace(0.5, 1.0, n_models+1)[1:] # start from second element to skip 0.5
159
+
160
+ # Plot each model as a ring, outermost = first model
161
+ for i, model in enumerate(models):
162
+ # Create data for this model's ring [precision, 1-precision]
163
+ data = [precision_values[i], 1-precision_values[i]]
164
+ colors_ring = [colors[i], 'lightgray']
165
+
166
+ # Create pie chart for this ring
167
+ wedges, texts = ax1.pie(
168
+ data,
169
+ radius=radii[i],
170
+ colors=colors_ring,
171
+ startangle=90,
172
+ counterclock=False,
173
+ wedgeprops=dict(width=0.15, edgecolor='w')
174
+ )
175
+
176
+ # Add only the value (no model name) to the pie chart wedge
177
+ angle = (wedges[0].theta1 + wedges[0].theta2) / 2
178
+ x = (radii[i] - 0.075) * np.cos(np.radians(angle))
179
+ y = (radii[i] - 0.075) * np.sin(np.radians(angle))
180
+ ax1.text(x, y, f"{precision_values[i]:.3f}",
181
+ ha='center', va='center', fontsize=10, fontweight='bold')
182
+
183
+ # Create center circle for donut effect
184
+ centre_circle = plt.Circle((0, 0), 0.25, fc='white')
185
+ ax1.add_patch(centre_circle)
186
+
187
+ ax1.set_title('Precision Comparison (Higher is Better)')
188
+ ax1.set_aspect('equal')
189
+
190
+ # 2. Recall Nested Pie Chart - top middle
191
+ ax2 = fig.add_subplot(gs[0, 1])
192
+
193
+ # Create a multi-layer nested pie chart for recall
194
+ recall_values = [metrics['total_recall'] for metrics in all_models_metrics.values()]
195
+
196
+ # Plot each model as a ring, outermost = first model
197
+ for i, model in enumerate(models):
198
+ # Create data for this model's ring [recall, 1-recall]
199
+ data = [recall_values[i], 1-recall_values[i]]
200
+ colors_ring = [colors[i], 'lightgray']
201
+
202
+ # Create pie chart for this ring
203
+ wedges, texts = ax2.pie(
204
+ data,
205
+ radius=radii[i],
206
+ colors=colors_ring,
207
+ startangle=90,
208
+ counterclock=False,
209
+ wedgeprops=dict(width=0.15, edgecolor='w')
210
+ )
211
+
212
+ # Add only the value (no model name) to the pie chart wedge
213
+ angle = (wedges[0].theta1 + wedges[0].theta2) / 2
214
+ x = (radii[i] - 0.075) * np.cos(np.radians(angle))
215
+ y = (radii[i] - 0.075) * np.sin(np.radians(angle))
216
+ ax2.text(x, y, f"{recall_values[i]:.3f}",
217
+ ha='center', va='center', fontsize=10, fontweight='bold')
218
+
219
+ # Create center circle for donut effect
220
+ centre_circle = plt.Circle((0, 0), 0.25, fc='white')
221
+ ax2.add_patch(centre_circle)
222
+
223
+ ax2.set_title('Recall Comparison (Higher is Better)')
224
+ ax2.set_aspect('equal')
225
+
226
+ # 3. F1 Score and IoU - top right
227
+ ax3 = fig.add_subplot(gs[0, 2])
228
+
229
+ # Prepare data for grouped bar chart
230
+ metrics_to_plot = ['overall_f1', 'overall_iou']
231
+ x = np.arange(len(metrics_to_plot))
232
+ width = 0.8 / n_models
233
+
234
+ # Plot grouped bars for each model
235
+ for i, (model_name, metrics) in enumerate(all_models_metrics.items()):
236
+ values = [metrics[key] for key in metrics_to_plot]
237
+ bars = ax3.bar(x + i*width - width*(n_models-1)/2, values, width, color=colors[i])
238
+
239
+ # Add value labels above each bar
240
+ for bar, value in zip(bars, values):
241
+ height = bar.get_height()
242
+ ax3.text(bar.get_x() + bar.get_width()/2., height + 0.01,
243
+ f'{value:.3f}', ha='center', va='bottom', fontsize=9, rotation=0)
244
+
245
+ # Customize the plot
246
+ ax3.set_xticks(x)
247
+ ax3.set_xticklabels([key_metrics[key] for key in metrics_to_plot])
248
+ ax3.set_ylabel('Score')
249
+ ax3.set_title('F1 Score & IoU Comparison (Higher is Better)')
250
+ ax3.set_ylim(0, 1.0)
251
+ ax3.grid(axis='y', linestyle='--', alpha=0.7)
252
+
253
+ # 4. MAE comparison (separate bar chart) - middle left
254
+ ax4 = fig.add_subplot(gs[1, 0])
255
+
256
+ mae_values = [metrics['overall_mae'] for metrics in all_models_metrics.values()]
257
+ bars = ax4.bar(models, mae_values, color=colors)
258
+
259
+ # Add value labels above MAE bars
260
+ for bar, value in zip(bars, mae_values):
261
+ height = bar.get_height()
262
+ ax4.text(bar.get_x() + bar.get_width()/2., height + 0.01,
263
+ f'{value:.3f}', ha='center', va='bottom', fontsize=9)
264
+
265
+ ax4.set_ylabel('Error')
266
+ ax4.set_title('Mean Absolute Error (Lower is Better)')
267
+ ax4.grid(axis='y', linestyle='--', alpha=0.7)
268
+
269
+ # 5. Model metrics radar chart - middle center
270
+ ax5 = fig.add_subplot(gs[1, 1], polar=True)
271
+
272
+ # Setup for radar chart
273
+ metrics_for_radar = ['total_recall', 'total_precision', 'overall_f1', 'overall_iou']
274
+ num_vars = len(metrics_for_radar)
275
+ angles = np.linspace(0, 2*np.pi, num_vars, endpoint=False).tolist()
276
+ angles += angles[:1] # Close the loop
277
+
278
+ # Plot each model on the radar chart
279
+ for i, (model_name, metrics) in enumerate(all_models_metrics.items()):
280
+ values = [metrics[metric] for metric in metrics_for_radar]
281
+ values += values[:1] # Close the loop
282
+
283
+ ax5.plot(angles, values, linewidth=2, linestyle='solid', color=colors[i])
284
+ ax5.fill(angles, values, alpha=0.1, color=colors[i])
285
+
286
+ # Set radar chart labels
287
+ ax5.set_xticks(angles[:-1])
288
+ ax5.set_xticklabels([key_metrics[metric] for metric in metrics_for_radar])
289
+ ax5.set_ylim(0, 1)
290
+ ax5.set_title('Model Performance Radar Chart')
291
+
292
+ # 6. Model comparison bar - middle right
293
+ ax6 = fig.add_subplot(gs[1, 2])
294
+
295
+ # Calculate the average of the four main metrics for an overall score
296
+ # (excluding MAE which is inverse, lower is better)
297
+ overall_scores = []
298
+ for model_name, metrics in all_models_metrics.items():
299
+ score = (metrics['total_recall'] + metrics['total_precision'] +
300
+ metrics['overall_f1'] + metrics['overall_iou']) / 4
301
+ overall_scores.append(score)
302
+
303
+ # Create horizontal bar chart
304
+ y_pos = np.arange(len(models))
305
+ ax6.barh(y_pos, overall_scores, color=colors)
306
+ ax6.set_yticks(y_pos)
307
+ ax6.set_yticklabels(models)
308
+ ax6.invert_yaxis() # labels read top-to-bottom
309
+ ax6.set_xlabel('Overall Performance Score')
310
+ ax6.set_title('Overall Model Comparison (Higher is Better)')
311
+
312
+ # Add value labels
313
+ for i, v in enumerate(overall_scores):
314
+ ax6.text(v + 0.01, i, f'{v:.3f}', va='center')
315
+
316
+ # 7. Detailed per-model metrics table - bottom span all columns
317
+ ax7 = fig.add_subplot(gs[2, :])
318
+ ax7.axis('tight')
319
+ ax7.axis('off')
320
+
321
+ # Prepare table data
322
+ table_data = []
323
+ for model_name, metrics in all_models_metrics.items():
324
+ row = [model_name]
325
+ for key in key_metrics:
326
+ row.append(f"{metrics[key]:.4f}")
327
+ table_data.append(row)
328
+
329
+ # Create table
330
+ column_labels = ['Model'] + list(key_metrics.values())
331
+ table = ax7.table(
332
+ cellText=table_data,
333
+ colLabels=column_labels,
334
+ loc='center',
335
+ cellLoc='center'
336
+ )
337
+ table.auto_set_font_size(False)
338
+ table.set_fontsize(10)
339
+ table.scale(1, 1.5)
340
+ ax7.set_title('Model Metrics Summary Table')
341
+
342
+ plt.tight_layout(rect=[0, 0.03, 1, 0.88]) # Adjusted rect to account for title and legend
343
+
344
+ plt.show()
345
+
346
+ return fig
347
+
348
+ # The evaluate_model and evaluate_all_models functions remain unchanged
349
+ # The evaluate_model and evaluate_all_models functions remain unchanged
350
+ # The evaluate_model function remains unchanged from your second code snippet
351
def evaluate_model(model_name, model_eval_results_dict, pattern_row_count, test_patterns, located_patterns_and_other_info_updated_dict):
    """
    Evaluate a model and calculate metrics without redundant plots.

    # Input:
    - model_name (str): key identifying the model in both result dicts.
    - model_eval_results_dict (dict): per-model raw results; each entry holds
      'number_of_properly_located_patterns' (pattern -> count),
      'mae_for_each_properly_detected_pattern' (pattern -> summed MAE) and
      'iou_for_each_properly_detected_pattern' (pattern -> summed IoU).
    - pattern_row_count (dict): ground-truth pattern name -> total row count.
    - test_patterns (pd.DataFrame): ground-truth patterns with a 'Chart Pattern' column.
    - located_patterns_and_other_info_updated_dict (dict): model name -> DataFrame
      of located (predicted) patterns with a 'Chart Pattern' column.

    # Returns:
    - dict: overall and per-pattern recall, precision, F1, MAE and IoU.
    """
    print(f"\n{'='*20} Model: {model_name} {'='*20}")

    # Extract model results
    number_of_properly_located_patterns = model_eval_results_dict[model_name]['number_of_properly_located_patterns']
    located_patterns_df = located_patterns_and_other_info_updated_dict[model_name]
    mae_for_each_properly_detected_pattern = model_eval_results_dict[model_name]['mae_for_each_properly_detected_pattern']
    iou_for_each_properly_detected_pattern = model_eval_results_dict[model_name]['iou_for_each_properly_detected_pattern']

    # --- Recall ---
    # Overall recall uses the global ground-truth counts; per-pattern recall
    # counts ground truth per pattern from test_patterns.
    total_number_of_all_patterns = sum(pattern_row_count.values())
    total_number_of_properly_located_patterns = sum(number_of_properly_located_patterns.values())
    total_recall = total_number_of_properly_located_patterns / total_number_of_all_patterns if total_number_of_all_patterns > 0 else 0

    per_pattern_recall = {}
    for pattern, count in number_of_properly_located_patterns.items():
        pattern_count = test_patterns[test_patterns['Chart Pattern'] == pattern].shape[0]
        if pattern_count > 0:
            per_pattern_recall[pattern] = count / pattern_count
        else:
            per_pattern_recall[pattern] = 0

    # --- Precision ---
    total_number_of_all_located_patterns = len(located_patterns_df)
    total_precision = total_number_of_properly_located_patterns / total_number_of_all_located_patterns if total_number_of_all_located_patterns > 0 else 0

    per_pattern_precision = {}
    for pattern, count in number_of_properly_located_patterns.items():
        pattern_predictions = located_patterns_df[located_patterns_df['Chart Pattern'] == pattern].shape[0]
        if pattern_predictions > 0:
            per_pattern_precision[pattern] = count / pattern_predictions
        else:
            per_pattern_precision[pattern] = 0

    # --- F1 Score ---
    per_pattern_f1 = {}
    for pattern in per_pattern_recall.keys():
        precision = per_pattern_precision.get(pattern, 0)
        recall = per_pattern_recall.get(pattern, 0)
        if precision + recall > 0:
            per_pattern_f1[pattern] = 2 * (precision * recall) / (precision + recall)
        else:
            per_pattern_f1[pattern] = 0

    # Overall F1 is the harmonic mean of the *macro-averaged* precision/recall.
    all_precisions = list(per_pattern_precision.values())
    all_recalls = list(per_pattern_recall.values())
    avg_precision = sum(all_precisions) / len(all_precisions) if all_precisions else 0
    avg_recall = sum(all_recalls) / len(all_recalls) if all_recalls else 0

    if avg_precision + avg_recall == 0:
        overall_f1 = 0
    else:
        overall_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall)

    # --- MAE ---
    # The raw dict holds summed MAE per pattern; divide by the number of
    # properly located patterns to get a mean.
    per_pattern_mae = {}
    for pattern, count in number_of_properly_located_patterns.items():
        if count > 0:
            per_pattern_mae[pattern] = mae_for_each_properly_detected_pattern.get(pattern, 0) / count
        else:
            per_pattern_mae[pattern] = 0

    total_mae_sum = sum(mae_for_each_properly_detected_pattern.values())
    total_proper_patterns = sum(number_of_properly_located_patterns.values())
    overall_mae = total_mae_sum / total_proper_patterns if total_proper_patterns > 0 else 0

    # --- IoU --- (same summed-then-averaged convention as MAE)
    per_pattern_iou = {}
    for pattern, count in number_of_properly_located_patterns.items():
        if count > 0:
            per_pattern_iou[pattern] = iou_for_each_properly_detected_pattern.get(pattern, 0) / count
        else:
            per_pattern_iou[pattern] = 0

    total_iou_sum = sum(iou_for_each_properly_detected_pattern.values())
    overall_iou = total_iou_sum / total_proper_patterns if total_proper_patterns > 0 else 0

    # Print summary of metrics
    print(f"Overall Recall: {total_recall:.4f}")
    print(f"Overall Precision: {total_precision:.4f}")
    print(f"Overall F1 Score: {overall_f1:.4f}")
    print(f"Overall Mean Absolute Error: {overall_mae:.4f}")
    print(f"Overall Mean Intersection over Union: {overall_iou:.4f}")

    # Store all metrics in one place for easy access
    metrics_summary = {
        'total_recall': total_recall,
        'per_pattern_recall': per_pattern_recall,
        'total_precision': total_precision,
        'per_pattern_precision': per_pattern_precision,
        'overall_f1': overall_f1,
        'per_pattern_f1': per_pattern_f1,
        'overall_mae': overall_mae,
        'per_pattern_mae': per_pattern_mae,
        'overall_iou': overall_iou,
        'per_pattern_iou': per_pattern_iou
    }

    return metrics_summary
452
+
453
+ # Updated evaluate_all_models function that only creates the comprehensive plot
454
def evaluate_all_models(model_eval_results_dict, pattern_row_count, test_patterns, located_patterns_and_other_info_updated_dict):
    """
    Evaluate all models and return a per-model metrics summary.

    # Input:
    - model_eval_results_dict (dict): model name -> raw evaluation results
      (see evaluate_model for the expected keys).
    - pattern_row_count (dict): ground-truth pattern name -> total row count.
    - test_patterns (pd.DataFrame): ground-truth patterns with a 'Chart Pattern' column.
    - located_patterns_and_other_info_updated_dict (dict): model name -> DataFrame
      of located (predicted) patterns.

    # Returns:
    - tuple: (all_models_metrics, None). The second slot used to carry the
      comparison figure; plotting is disabled, so None is returned to keep
      the original (metrics, figure) call signature working.
    """
    all_models_metrics = {}

    for model_name in model_eval_results_dict.keys():
        all_models_metrics[model_name] = evaluate_model(
            model_name,
            model_eval_results_dict,
            pattern_row_count,
            test_patterns,
            located_patterns_and_other_info_updated_dict
        )

    # Only create the comprehensive visualization
    if len(model_eval_results_dict) > 0:
        print("\n--- Comprehensive Model Comparison ---")
        # figure = create_comprehensive_model_comparison(all_models_metrics)

    return all_models_metrics, None  # Return None instead of figure
473
+ """
474
+ ###########################################################################################################
utils/formatAndPreprocessNewPatterns.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import the necessary libraries
2
+ from multiprocessing import Manager, Value
3
+ import os
4
+ import numpy as np
5
+ import pandas as pd
6
+ from joblib import Parallel, delayed
7
+ import math
8
+ from scipy import interpolate
9
+ from tqdm import tqdm
10
+
11
+ from utils.drawPlots import plot_ohlc_segment
12
+
13
# Pattern names exactly as they appear in the raw annotation data; only
# rows with one of these labels survive filter_to_get_selected_patterns.
original_pattern_name_list = [
    'Double Top, Adam and Adam',
    'Double Top, Adam and Eve',
    'Double Top, Eve and Eve',
    'Double Top, Eve and Adam',
    'Double Bottom, Adam and Adam',
    'Double Bottom, Eve and Adam',
    'Double Bottom, Eve and Eve',
    'Double Bottom, Adam and Eve',
    'Triangle, symmetrical',
    'Head-and-shoulders top',
    'Head-and-shoulders bottom',
    'Flag, high and tight'
]

# Updated pattern encoding.
# Integer class labels used by the models. The four Adam/Eve variants of
# Double Top / Double Bottom are each collapsed into a single class, plus
# an explicit 'No Pattern' background class.
pattern_encoding = {
    'Double Top': 0,
    'Double Bottom': 1,
    'Triangle, symmetrical': 2,
    'Head-and-shoulders top': 3,
    'Head-and-shoulders bottom': 4,
    'Flag, high and tight': 5,
    'No Pattern': 6
}
38
+
39
def get_pattern_encoding():
    """Return the shared pattern-name -> integer-label mapping."""
    return pattern_encoding
41
+
42
def get_reverse_pattern_encoding():
    """Return the inverse mapping: integer label -> pattern name."""
    return {code: name for name, code in pattern_encoding.items()}
44
+
45
def get_patetrn_name_by_encoding(encoding):
    """
    Get the pattern name by encoding.

    # Input:
    - encoding (int): integer class label of the pattern.

    # Returns:
    - str: the matching pattern name, or 'Unknown Pattern' for an
      unrecognised label.
    """
    reverse_map = get_reverse_pattern_encoding()
    return reverse_map.get(encoding, 'Unknown Pattern')
56
+
57
def get_pattern_encoding_by_name(name):
    """
    Get the pattern encoding by name.

    # Input:
    - name (str): the pattern name.

    # Returns:
    - int: the integer label, or -1 when the name is not a known pattern.
    """
    encoding_map = get_pattern_encoding()
    return encoding_map.get(name, -1)
68
+
69
def get_pattern_list():
    """Return every known pattern name (including 'No Pattern')."""
    return [name for name in pattern_encoding]
71
+
72
def filter_to_get_selected_patterns(df):
    """
    Keep only the supported chart patterns and collapse Adam/Eve variants.

    Rows whose 'Chart Pattern' is not in original_pattern_name_list are
    dropped, then every Adam/Eve variant of Double Top / Double Bottom is
    renamed to the plain 'Double Top' / 'Double Bottom' class.

    # Input:
    - df (pd.DataFrame): pattern annotations with a 'Chart Pattern' column.

    # Returns:
    - pd.DataFrame: a filtered copy; the input is left untouched.
    """
    selected = df[df['Chart Pattern'].isin(original_pattern_name_list)].copy()

    # Build the variant -> simplified-name mapping in one pass.
    variant_to_simple = {}
    for variant in ('Adam and Adam', 'Adam and Eve', 'Eve and Eve', 'Eve and Adam'):
        variant_to_simple[f'Double Top, {variant}'] = 'Double Top'
        variant_to_simple[f'Double Bottom, {variant}'] = 'Double Bottom'

    # .loc keeps pandas happy about writing into the copied frame.
    selected.loc[:, 'Chart Pattern'] = selected['Chart Pattern'].replace(variant_to_simple)

    return selected
98
+
99
def normalize_dataset(dataset):
    """
    Min-max normalize OHLC(V) columns per instance of a MultiIndexed dataset.

    Prices are scaled to [0, 1] within each 'Instance' group using that
    group's lowest Low and highest High; Volume (when present) is scaled
    independently by its own per-instance min/max.

    # Input:
    - dataset (pd.DataFrame): MultiIndex with an 'Instance' level and
      columns Open/High/Low/Close (plus optional Volume).

    # Returns:
    - pd.DataFrame: a normalized copy; the input is not modified.

    Note: an instance with a degenerate range (max == min) produces inf/NaN
    values here; dataset_format later replaces inf with NaN and forward-fills.
    """
    # Per-row broadcastable min/max, aligned to the dataset's index.
    min_low = dataset.groupby(level='Instance')['Low'].transform('min')
    max_high = dataset.groupby(level='Instance')['High'].transform('max')

    # OHLC columns to normalize
    ohlc_columns = ['Open', 'High', 'Low', 'Close']

    dataset_normalized = dataset.copy()

    # Fix: use index-aligned Series arithmetic throughout. The previous
    # version mixed raw ndarrays (.values) with aligned Series in the same
    # expression, which is fragile and inconsistent.
    price_range = max_high - min_low
    dataset_normalized[ohlc_columns] = (
        dataset_normalized[ohlc_columns].sub(min_low, axis=0).div(price_range, axis=0)
    )

    # if there is a Volume column normalize it
    if 'Volume' in dataset.columns:
        min_volume = dataset.groupby(level='Instance')['Volume'].transform('min')
        max_volume = dataset.groupby(level='Instance')['Volume'].transform('max')
        dataset_normalized['Volume'] = (dataset_normalized['Volume'] - min_volume) / (max_volume - min_volume)

    return dataset_normalized
123
+
124
def normalize_ohlc_segment(dataset):
    """
    Min-max normalize a single OHLC(V) segment.

    Prices are scaled to [0, 1] using the segment's lowest Low and highest
    High; Volume (when present) is scaled independently by its own min/max.
    Degenerate (constant) ranges are left unscaled and reported on stdout.

    # Input:
    - dataset (pd.DataFrame): one segment with Open/High/Low/Close columns
      (plus optional Volume).

    # Returns:
    - pd.DataFrame: a normalized copy; the input is not modified.
    """
    segment = dataset.copy()

    low_floor = dataset['Low'].min()
    high_ceiling = dataset['High'].max()
    price_span = high_ceiling - low_floor

    price_cols = ['Open', 'High', 'Low', 'Close']
    if price_span != 0:
        segment[price_cols] = (segment[price_cols] - low_floor) / price_span
    else:
        print("Error: Max high and min low are equal")

    if 'Volume' in dataset.columns:
        vol_min = dataset['Volume'].min()
        vol_max = dataset['Volume'].max()
        vol_span = vol_max - vol_min

        if vol_span != 0:
            segment['Volume'] = (segment['Volume'] - vol_min) / vol_span
        else:
            print("Error: Max volume and min volume are equal")

    return segment
154
+
155
def process_row_improved(idx, row, ohlc_df, instance_counter, lock, successful_instances, instance_index_mapping):
    """
    Build one normalized OHLC segment for a single annotated pattern row.

    Runs inside a joblib worker: the counter, list and dict arguments are
    multiprocessing.Manager proxies shared across workers, guarded by `lock`.

    # Input:
    - idx: index of `row` in the original pattern DataFrame.
    - row (pd.Series): one annotation with 'Start', 'End', 'Symbol' and 'Chart Pattern'.
    - ohlc_df (pd.DataFrame): pre-loaded OHLC history for the row's symbol,
      with a tz-naive 'Date' column (prepared by dataset_format).
    - instance_counter: shared integer Value handing out unique instance IDs.
    - lock: shared lock guarding the counter/mapping/list updates.
    - successful_instances: shared list of instance IDs that were processed.
    - instance_index_mapping: shared dict mapping instance ID -> original row index.

    # Returns:
    - pd.DataFrame indexed by (Instance, Time) with normalized OHLC(V)
      columns and an integer 'Pattern' column, or None on failure or when
      the date slice is empty.
    """
    try:
        # Slice the symbol's history to the annotated [Start, End] window.
        start_date = pd.to_datetime(row['Start'])
        end_date = pd.to_datetime(row['End'])

        symbol_df_filtered = ohlc_df[(ohlc_df['Date'] >= start_date) &
                                     (ohlc_df['Date'] <= end_date)]

        if symbol_df_filtered.empty:
            print(f"Empty result for {row['Symbol']} from {start_date} to {end_date}")
            return None

        # Reserve a unique instance ID and record the bookkeeping atomically,
        # so concurrent workers never hand out the same ID.
        with lock:
            unique_instance = instance_counter.value
            instance_counter.value += 1

            # Remember which source row this instance came from.
            instance_index_mapping[unique_instance] = idx

            # Track successful instances
            successful_instances.append(unique_instance)

        # Re-index as (Instance, Time): Time is simply 0..len-1 within the slice.
        symbol_df_filtered = symbol_df_filtered.reset_index(drop=True)
        multi_index = pd.MultiIndex.from_arrays(
            [[unique_instance] * len(symbol_df_filtered), range(len(symbol_df_filtered))],
            names=["Instance", "Time"]
        )
        symbol_df_filtered.index = multi_index

        # Force both index levels to integer dtypes so concatenated segments
        # sort and group consistently.
        symbol_df_filtered.index = symbol_df_filtered.index.set_levels(
            symbol_df_filtered.index.levels[0].astype('int'), level=0
        )
        symbol_df_filtered.index = symbol_df_filtered.index.set_levels(
            symbol_df_filtered.index.levels[1].astype('int64'), level=1
        )

        # Attach the integer class label and drop columns the models don't use.
        symbol_df_filtered['Pattern'] = pattern_encoding[row['Chart Pattern']]
        symbol_df_filtered.drop('Date', axis=1, inplace=True)
        if 'Adj Close' in symbol_df_filtered.columns:
            symbol_df_filtered.drop('Adj Close', axis=1, inplace=True)

        # Min-max normalize prices (and Volume) within this segment.
        symbol_df_filtered = normalize_ohlc_segment(symbol_df_filtered)

        return symbol_df_filtered

    except Exception as e:
        # Best-effort worker: report and skip rather than kill the whole
        # parallel run. NOTE(review): the instance ID may already have been
        # reserved above when a later step fails, which is why dataset_format
        # warns when counter > len(successful results).
        print(f"Error processing {row['Symbol']}: {str(e)}")
        return None
209
+
210
def dataset_format(filteredPatternDf, give_instance_index_mapping=False):
    """
    Format and preprocess the dataset with tracking of successful instances.

    Loads each symbol's OHLC CSV from 'Datasets/OHLC data/', slices out every
    annotated pattern window in parallel (process_row_improved), and stacks
    the normalized segments into one (Instance, Time)-indexed DataFrame.

    # Input:
    - filteredPatternDf (pd.DataFrame): pattern annotations with 'Symbol',
      'Start', 'End' and 'Chart Pattern' columns.
    - give_instance_index_mapping (bool): when True, also return the
      instance-ID -> original-row-index mapping.

    # Returns:
    - pd.DataFrame, or (pd.DataFrame, dict) when give_instance_index_mapping
      is True. An empty DataFrame is returned when nothing was processed.
    """
    # Get symbol list from files
    folder_path = 'Datasets/OHLC data/'
    file_list = os.listdir(folder_path)
    symbol_list = [file[:-4] for file in file_list if file.endswith('.csv')]

    # Check for missing symbols
    symbols_in_df = filteredPatternDf['Symbol'].unique()
    missing_symbols = set(symbols_in_df) - set(symbol_list)
    if missing_symbols:
        print("Missing symbols: ", missing_symbols)

    # Create a list of tasks (symbol, row pairs); each symbol's CSV is
    # loaded once and shared by all of its pattern rows.
    tasks = []
    for symbol in symbols_in_df:
        if symbol in symbol_list:  # Skip missing symbols
            filteredPatternDf_for_symbol = filteredPatternDf[filteredPatternDf['Symbol'] == symbol]
            file_path = os.path.join(folder_path, f"{symbol}.csv")

            try:
                symbol_df = pd.read_csv(file_path)
                symbol_df['Date'] = pd.to_datetime(symbol_df['Date'])
                # Strip timezones so comparisons against naive Start/End work.
                symbol_df['Date'] = symbol_df['Date'].dt.tz_localize(None)

                for idx, row in filteredPatternDf_for_symbol.iterrows():
                    tasks.append((idx, row, symbol_df))
            except Exception as e:
                print(f"Error loading {symbol}: {str(e)}")

    print(f"Processing {len(tasks)} tasks in parallel...")

    # Process all tasks with shared, lock-protected instance tracking.
    with Manager() as manager:
        instance_counter = manager.Value('i', 0)
        lock = manager.Lock()
        successful_instances = manager.list()  # Track which instances succeed
        instance_index_mapping = manager.dict()  # Mapping from instance ID to index

        results = Parallel(n_jobs=-1, verbose=1)(
            delayed(process_row_improved)(task_idx, row, df, instance_counter, lock, successful_instances, instance_index_mapping)
            for task_idx, row, df in tasks
        )

        # Filter out None results
        results = [result for result in results if result is not None]

        print(f"Total tasks: {len(tasks)}, Successful: {len(results)}")
        print(f"Instance counter final value: {instance_counter.value}")
        print(f"Number of successful instances: {len(successful_instances)}")

        if len(successful_instances) < instance_counter.value:
            print("Warning: Some instances were assigned but their tasks failed")

        # Concatenate results and renumber instances if needed
        if results:
            dataset = pd.concat(results)
            dataset = dataset.sort_index(level=0)

            # Replace inf (from degenerate normalization ranges) then
            # forward-fill the resulting gaps.
            dataset.replace([np.inf, -np.inf], np.nan, inplace=True)
            # Fix: fillna(method='ffill') is deprecated (removed in pandas 3.0);
            # DataFrame.ffill is the supported equivalent.
            dataset.ffill(inplace=True)

            if give_instance_index_mapping:
                # Convert the manager proxy to a plain dict before the
                # Manager context closes.
                instance_index_mapping_dict = dict(instance_index_mapping)

                print("Converted Mapping:", instance_index_mapping_dict)
                return dataset, instance_index_mapping_dict
            else:
                return dataset
        else:
            # NOTE(review): when nothing succeeds, a bare DataFrame is
            # returned even if give_instance_index_mapping=True — callers
            # expecting a tuple must handle this case.
            return pd.DataFrame()
291
+
292
+
293
+
294
+
295
def width_augmentation (filteredPatternDf, min_aug_len , aug_len_fraction, make_duplicates = False , keep_original = False):
    """
    Perform width augmentation on the filtered pattern DataFrame.

    Each pattern's [Start, End] window is randomly widened on both sides by
    up to ``ceil(data_size * aug_len_fraction)`` bars (at least
    ``min_aug_len``), clamped to the available OHLC history.

    # Input:
    - filteredPatternDf (pd.DataFrame): the filtered pattern DataFrame.
    - min_aug_len (int): the minimum length of the augmented data.
    - aug_len_fraction (float): the fraction of the original data size to determine the maximum length of the augmented data.
    - make_duplicates (bool): flag to indicate whether to make duplicates of patterns to reduce dataset imbalance (make this False on test data).
    - keep_original (bool): flag to indicate whether to keep the original patterns in the augmented DataFrame.

    # Returns:
    - filteredPattern_width_aug_df (pd.DataFrame): the DataFrame with width-augmented patterns.
    """

    filteredPattern_width_aug_df = pd.DataFrame(columns=filteredPatternDf.columns)

    print('Performing width augmentation...')
    # print('Number of patterns:', len(filteredPatternDf))

    # loop through the rows of filteredPatternDf
    for index, row in tqdm(filteredPatternDf.iterrows(), total=len(filteredPatternDf), desc="Processing"):

        symbol = row['Symbol']
        start_date = row['Start']
        end_date = row['End']
        pattern = row['Chart Pattern']

        # NOTE(review): the symbol's CSV is re-read for every pattern row;
        # caching per symbol would avoid repeated disk I/O.
        ohlc_df = pd.read_csv(f'Datasets/OHLC data/{symbol}.csv')
        # Ensure all datetime objects are timezone-naive
        ohlc_df['Date'] = pd.to_datetime(ohlc_df['Date']).dt.tz_localize(None)

        # Convert start_date and end_date to timezone-naive if they have a timezone
        start_date = pd.to_datetime(start_date).tz_localize(None)
        end_date = pd.to_datetime(end_date).tz_localize(None)

        ohlc_of_interest = ohlc_df[(ohlc_df['Date'] >= start_date) & (ohlc_df['Date'] <= end_date)]
        data_size = len(ohlc_of_interest)

        if data_size <= 0:
            print (f'No data for {symbol} between {start_date} and {end_date}')
            continue

        # index of ohlc data on the start date and end date
        start_index = ohlc_of_interest.index[0]
        end_index = ohlc_of_interest.index[-1]

        min_possible_index = 0
        max_possible_index = len(ohlc_df) - 1

        # NOTE(review): value_counts() is recomputed twice per row on the
        # full DataFrame — hoisting it above the loop would be cheaper.
        number_of_rows_for_pattern= filteredPatternDf['Chart Pattern'].value_counts()[pattern]
        max_num_of_rows_for_pattern = filteredPatternDf['Chart Pattern'].value_counts().max()

        # to make the number of rows for each pattern equal to reduce the imbalance in the dataset
        if make_duplicates:
            num_row_diff = (max_num_of_rows_for_pattern - number_of_rows_for_pattern)*2

            multiplier = math.ceil(num_row_diff / number_of_rows_for_pattern) +2
            # print ('Pattern :', pattern , 'Multiplier :' , multiplier , 'Number of rows for pattern :', number_of_rows_for_pattern)
            # get a random value between 1 to multiplier (exclusive upper bound)
            m = np.random.randint(1, multiplier)
        else:
            m = 1

        for i in range(m):
            max_aug_len = math.ceil(data_size * aug_len_fraction)
            if max_aug_len < min_aug_len:
                max_aug_len = min_aug_len
            # NOTE(review): np.random.randint(1, max_aug_len) raises
            # ValueError when max_aug_len <= 1 — requires min_aug_len >= 2;
            # TODO confirm intended bounds.
            aug_len_l = np.random.randint(1, max_aug_len)
            aug_len_r = np.random.randint(1, max_aug_len)

            # get the start and end index of the augmented data
            start_index_aug = start_index - aug_len_l
            end_index_aug = end_index + aug_len_r

            # clamp to the available history
            if start_index_aug < min_possible_index:
                start_index_aug = min_possible_index
            if end_index_aug > max_possible_index:
                end_index_aug = max_possible_index

            # get the date of the start and end index of the augmented data
            start_date_aug = ohlc_df.iloc[start_index_aug]['Date']
            end_date_aug = ohlc_df.iloc[end_index_aug]['Date']

            # create a new row for the augmented data
            new_row = row.copy()
            new_row['Start'] = start_date_aug
            new_row['End'] = end_date_aug
            filteredPattern_width_aug_df = pd.concat([filteredPattern_width_aug_df, pd.DataFrame([new_row])], ignore_index=True)

        if keep_original:
            # concat the original row too
            filteredPattern_width_aug_df = pd.concat([filteredPattern_width_aug_df, pd.DataFrame([row])], ignore_index=True)

    return filteredPattern_width_aug_df
391
+
392
def normalize_ohlc_len(df, target_len=30 , plot_count= 0):
    """
    Resample every instance's OHLCV segment to a fixed length via interpolation.

    # Input:
    - df (pd.DataFrame): (Instance, Time)-indexed data with Open/High/Low/
      Close/Volume columns (and optionally 'Pattern'). Assumes Volume is
      present — TODO confirm; a missing column would raise KeyError.
    - target_len (int): number of time steps each instance is resampled to.
    - plot_count (int): how many randomly chosen instances to plot
      before/after for visual inspection.

    # Returns:
    - pd.DataFrame: concatenation of all resampled instances, indexed by
      (instance, 0..target_len-1).
    """

    instances_list = df.index.get_level_values(0).unique()
    normalized_df_list = []

    # pick `plot_count` random instances from the list of instances to plot
    random_indices = np.random.choice(len(instances_list), plot_count, replace=False)

    for instance in instances_list:

        sample = df.loc[instance]

        pattern_df = sample.copy()
        new_data = {}
        orig_indices = pattern_df.index.values  # Changed this line
        # Evenly spaced sample positions over the original time axis.
        new_indices = np.linspace(0, len(orig_indices) - 1, target_len)

        # First interpolate all numerical columns
        for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
            # Determine the best interpolation method based on data length
            # (cubic needs >= 4 points, quadratic >= 3, linear >= 2).
            if len(orig_indices) >= 4:  # Enough points for cubic
                kind = 'cubic'
            elif len(orig_indices) >= 3:  # Can use quadratic
                kind = 'quadratic'
            elif len(orig_indices) >= 2:  # Can use linear
                kind = 'linear'
            else:  # Not enough points, use nearest
                kind = 'nearest'

            f = interpolate.interp1d(np.arange(len(orig_indices)), pattern_df[col].values,
                                     kind=kind, bounds_error=False, fill_value='extrapolate')
            # Apply interpolation function to get new values
            new_data[col] = f(new_indices)

        # Ensure all OHLC values are positive (cubic interpolation can
        # overshoot below zero).
        for col in ['Open', 'High', 'Low', 'Close']:
            new_data[col] = np.maximum(new_data[col], 0.001)  # Small positive value instead of zero

        # Fix OHLC relationships broken by independent per-column interpolation.
        for i in range(len(new_indices)):
            # Ensure High is the maximum
            new_data['High'][i] = max(new_data['High'][i], new_data['Open'][i], new_data['Close'][i])

            # Ensure Low is the minimum
            new_data['Low'][i] = min(new_data['Low'][i], new_data['Open'][i], new_data['Close'][i])

        # Handle categorical data separately: nearest-neighbour keeps the
        # label an exact original value instead of a blended float.
        if 'Pattern' in pattern_df.columns:
            f = interpolate.interp1d(np.arange(len(orig_indices)), pattern_df['Pattern'].values,
                                     kind='nearest', bounds_error=False, fill_value=pattern_df['Pattern'].iloc[0])
            new_data['Pattern'] = f(new_indices)

        result_df = pd.DataFrame(new_data)
        result_df.index = pd.MultiIndex.from_product([[instance], result_df.index])
        normalized_df_list.append(result_df)

        if instance in instances_list[random_indices]:  # Fixed this line
            # plot before/after for visual inspection
            plot_ohlc_segment(pattern_df)
            plot_ohlc_segment(result_df)

    combined_result_df = pd.concat(normalized_df_list, axis=0)  # Fixed this line
    return combined_result_df
455
+
456
# Define features, target, and desired series length
features = ['Open', 'High', 'Low', 'Close', 'Volume']
target = 'Pattern'
series_length = 100

def prepare_rocket_data(dataset, features = features, target = target, series_length = series_length):
    """
    Shape an (Instance, Time)-indexed DataFrame into ROCKET-style arrays.

    Every instance is truncated or zero-padded to `series_length` steps,
    then stacked into X with shape (n_instances, n_features, series_length);
    y holds the first `target` value of each instance.
    """
    def _fit_length(group):
        # Truncate long instances; zero-pad short ones.
        values = group[features].values
        if len(values) > series_length:
            return values[:series_length]
        padding = np.zeros((series_length - len(values), values.shape[1]))
        return np.vstack([values, padding])

    # One fixed-length (series_length, n_features) array per instance.
    per_instance = dataset.groupby(level=0).apply(_fit_length)

    # Stack then swap the last two axes: (n, len, feat) -> (n, feat, len).
    X = np.transpose(np.stack(per_instance.values), (0, 2, 1))

    y = dataset.groupby(level=0)[target].first().values
    return X, y