Commit ·
51c6c3d
1
Parent(s): 0c285af
add files
Browse files- handler.py +393 -0
- requirements.txt +0 -0
- utils/__init__.py +0 -0
- utils/eval.py +474 -0
- utils/formatAndPreprocessNewPatterns.py +477 -0
handler.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# handler.py
|
| 2 |
+
import joblib
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
from joblib import Parallel, delayed
|
| 7 |
+
from sklearn.cluster import DBSCAN
|
| 8 |
+
import os # For accessing model path
|
| 9 |
+
|
| 10 |
+
# Import your utility functions
|
| 11 |
+
# Make sure your utils directory is alongside handler.py
|
| 12 |
+
# and contains __init__.py, eval.py, formatAndPreprocessNewPatterns.py
|
| 13 |
+
from utils.eval import intersection_over_union
|
| 14 |
+
from utils.formatAndPreprocessNewPatterns import get_patetrn_name_by_encoding, get_pattern_encoding_by_name, get_reverse_pattern_encoding
|
| 15 |
+
|
| 16 |
+
# --- Global Model Loading (Crucial for performance) ---
|
| 17 |
+
# This model will be loaded ONLY ONCE when the server starts.
|
| 18 |
+
# Ensure the path is correct relative to where handler.py runs in the container.
|
| 19 |
+
# The `MODEL_DIR` env var is automatically set by Inference Endpoints.
|
| 20 |
+
# If you place 'Models/' directly in your repo root, it will be at /repository/Models/
|
| 21 |
+
# If you place it outside (not recommended), you'd need to adjust paths.
|
| 22 |
+
# For simplicity, assume `Models/` is in the root of your HF repo.
|
| 23 |
+
# Path of the serialized MiniRocket + XGB pipeline. MODEL_DIR is injected by
# the hosting environment (Hugging Face Inference Endpoints); falls back to
# the current directory for local runs.
MODEL_PATH = os.path.join(os.environ.get("MODEL_DIR", "."), "Models", "Width Aug OHLC_mini_rocket_xgb.joblib")

# Load the model globally (once per process, at import time).
try:
    print(f"Loading model from: {MODEL_PATH}")
    rocket_model = joblib.load(MODEL_PATH)
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # In a real scenario, you might want to raise an exception to prevent the server from starting.
    # Left as None here; InferenceHandler.__init__ raises if it stayed None.
    rocket_model = None

# --- Helper functions (from your provided code) ---
# process_window, parallel_process_sliding_window,
# prepare_dataset_for_cluster and cluster_windows are defined below,
# before locate_patterns/InferenceHandler which depend on them.

# --- Module-level tuning constants (snapshotted onto InferenceHandler) ---
# Reverse lookup {encoding -> pattern name}; built once from the utils module.
pattern_encoding_reversed = get_reverse_pattern_encoding()
# model is now `rocket_model` loaded globally
# plot_count is handled by the API input now

# Log-spaced divisors of the segment length; each yields one candidate
# sliding-window size (segment_len // proportion), from 1x down to 1/20th.
win_size_proportions = np.round(np.logspace(0, np.log10(20), num=10), 2).tolist()
# Fraction of the window placed *before* the anchor index i (see process_window).
padding_proportion = 0.6
# Step (in bars) between consecutive window anchors.
stride = 1
# NOTE(review): despite the name this is a scalar, not a list; cluster_windows
# accepts either form (per-class list or single float) — confirm intent.
probab_threshold_list = 0.5
# If P('No Pattern') exceeds this, the window is forced to 'No Pattern'.
prob_threshold_of_no_pattern_to_mark_as_no_pattern = 0.5
target_len = 30 # Not used in your current code

# DBSCAN parameters for clustering normalized window centers.
eps=0.04
min_samples=3
win_width_proportion=10 # Not used in your current code
|
| 57 |
+
def process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed,seg_start, seg_end, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1):
    """Classify one sliding window anchored at index ``i``.

    The window starts ``ceil(window_size * padding_proportion)`` rows before
    ``i`` and spans ``window_size`` rows, clipped to the segment bounds, so
    windows near the edges may be shorter than ``window_size``.

    Returns a dict with the window's date span, predicted pattern name,
    probability and the enclosing segment's date span, or ``None`` when the
    window is empty or the model call fails.

    Note: ``probability_threshold`` and ``pattern_encoding_reversed`` are
    accepted for signature compatibility but not used here.
    """
    lo = i - math.ceil(window_size * padding_proportion)
    hi = lo + window_size

    # Clip to the segment; done after computing hi, so edge windows shrink.
    lo = max(lo, 0)
    hi = min(hi, len(ohlc_data_segment))

    window = ohlc_data_segment[lo:hi]
    if len(window) == 0:
        return None  # Skip empty segments

    # Shape (1, channels=5, timesteps) as expected by the rocket pipeline.
    features = window[['Open', 'High', 'Low', 'Close','Volume']].to_numpy().reshape(1, len(window), 5)
    features = np.transpose(features, (0, 2, 1))

    try:
        probas = rocket_model.predict_proba(features)
    except Exception as e:
        print(f"Error in prediction: {e}")
        return None

    pred_proba = np.max(probas)
    pred_pattern = get_patetrn_name_by_encoding(np.argmax(probas))

    # Override with 'No Pattern' when its own probability clears the
    # dedicated threshold (default 1 => override disabled).
    no_pattern_proba = probas[0][get_pattern_encoding_by_name ('No Pattern')]
    if no_pattern_proba > prob_threshold_of_no_pattern_to_mark_as_no_pattern:
        pred_proba = no_pattern_proba
        pred_pattern = 'No Pattern'

    return {
        'Start': window['Date'].iloc[0],
        'End': window['Date'].iloc[-1],
        'Chart Pattern': pred_pattern,
        'Seg_Start': seg_start,
        'Seg_End': seg_end,
        'Probability': pred_proba,
    }
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def parallel_process_sliding_window(ohlc_data_segment, rocket_model, probability_threshold, stride, pattern_encoding_reversed, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1,parallel=True,num_cores=-1):
    """Run ``process_window`` at every stride position over the segment.

    Args:
        ohlc_data_segment: DataFrame with at least a 'Date' column plus the
            OHLCV columns consumed by ``process_window``.
        rocket_model: classifier exposing ``predict_proba``.
        probability_threshold: forwarded to ``process_window`` (unused there).
        stride: step in rows between consecutive window anchors.
        pattern_encoding_reversed: forwarded to ``process_window``.
        window_size: window length in rows.
        padding_proportion: fraction of the window placed before the anchor.
        prob_threshold_of_no_pattern_to_mark_as_no_pattern: override threshold
            for the 'No Pattern' class (default 1 disables the override).
        parallel: when True, fan out across joblib workers. On small single
            instances inner parallelism may not help; prefer instance scaling.
        num_cores: joblib ``n_jobs`` (-1 = all available cores).

    Returns:
        DataFrame of the non-None rows produced by ``process_window``.
    """
    seg_start = ohlc_data_segment['Date'].iloc[0]
    seg_end = ohlc_data_segment['Date'].iloc[-1]

    anchors = range(0, len(ohlc_data_segment), stride)

    if parallel:
        # verbose=0 to reduce log spam. Named `executor` (not `parallel`) so
        # the boolean parameter is not shadowed.
        with Parallel(n_jobs=num_cores, verbose=0) as executor:
            results = executor(
                delayed(process_window)(
                    i=i,
                    ohlc_data_segment=ohlc_data_segment,
                    rocket_model=rocket_model,
                    probability_threshold=probability_threshold,
                    pattern_encoding_reversed=pattern_encoding_reversed,
                    window_size=window_size,
                    seg_start=seg_start,
                    seg_end=seg_end,
                    padding_proportion=padding_proportion,
                    prob_threshold_of_no_pattern_to_mark_as_no_pattern=prob_threshold_of_no_pattern_to_mark_as_no_pattern
                )
                for i in anchors
            )
    else:
        # BUG FIX: the sequential path previously omitted
        # prob_threshold_of_no_pattern_to_mark_as_no_pattern, so it silently
        # used the default (1) and behaved differently from the parallel path.
        results = [
            process_window(
                i, ohlc_data_segment, rocket_model, probability_threshold,
                pattern_encoding_reversed, seg_start, seg_end, window_size,
                padding_proportion,
                prob_threshold_of_no_pattern_to_mark_as_no_pattern
            )
            for i in anchors
        ]

    return pd.DataFrame([res for res in results if res is not None])
|
| 130 |
+
|
| 131 |
+
def prepare_dataset_for_cluster(ohlc_data_segment, win_results_df):
    """Augment window results with positional columns for clustering.

    For each predicted window, converts its [Start, End] date span into row
    positions within ``ohlc_data_segment`` and records:
    'Pattern_Start_pos' (rows strictly before Start), 'Pattern_End_pos'
    (start pos + span length) and 'Center' (midpoint, possibly .5).

    Returns a copy of ``win_results_df`` with the three columns added.
    """
    enriched = win_results_df.copy()
    dates = ohlc_data_segment['Date']

    for idx, win in enriched.iterrows():
        # Rows preceding the window and rows covered by it (inclusive bounds).
        n_before = int((dates < win['Start']).sum())
        span = int(((dates >= win['Start']) & (dates <= win['End'])).sum())

        enriched.at[idx, 'Center'] = n_before + span / 2
        enriched.at[idx, 'Pattern_Start_pos'] = n_before
        enriched.at[idx, 'Pattern_End_pos'] = n_before + span

    return enriched
|
| 144 |
+
|
| 145 |
+
def cluster_windows(predicted_patterns , probability_threshold, window_size,eps = 0.05 , min_samples = 2):
    """Filter low-probability windows, then cluster the survivors per pattern.

    Windows are grouped by 'Chart Pattern'; within each group their 'Center'
    positions (min-max normalized across the whole filtered frame) are
    clustered with DBSCAN. For every non-noise cluster, the reported span is
    the range of dates covered by at least two member windows.

    Args:
        predicted_patterns: output of prepare_dataset_for_cluster (needs
            'Chart Pattern', 'Probability', 'Center', 'Start', 'End',
            'Seg_Start', 'Seg_End').
        probability_threshold: either a per-class list indexed by pattern
            encoding, or a single float applied to every window.
        window_size: accepted for signature compatibility; not used here.
        eps, min_samples: DBSCAN parameters (eps is in normalized-center units).

    Returns:
        (cluster_labled_windows_df, interseced_clusters_df) or (None, None)
        when filtering/clustering leaves nothing.
    """
    df = predicted_patterns.copy()

    # Per-class thresholds: drop rows of class i whose probability is below
    # probability_threshold[i]; scalar: keep rows strictly above it.
    if isinstance(probability_threshold, list):
        for i in range(len(probability_threshold)):
            pattern_name = get_patetrn_name_by_encoding(i)
            df.drop(df[(df['Chart Pattern'] == pattern_name) & (df['Probability'] < probability_threshold[i])].index, inplace=True)
    else:
        df = df[df['Probability'] > probability_threshold]

    cluster_labled_windows = []
    interseced_clusters = []

    if df.empty: # Handle case where df might be empty after filtering
        return None, None

    # Normalization bounds come from the whole filtered frame, not per group,
    # so eps means the same distance for every pattern class.
    min_center = df['Center'].min()
    max_center = df['Center'].max()

    for pattern, group in df.groupby('Chart Pattern'):
        centers = group['Center'].values.reshape(-1, 1)

        if min_center < max_center:
            norm_centers = (centers - min_center) / (max_center - min_center)
        else:
            # All centers identical: any constant works; DBSCAN sees one point.
            norm_centers = np.ones_like(centers)

        db = DBSCAN(eps=eps, min_samples=min_samples).fit(norm_centers)
        # NOTE(review): assigning into a groupby slice can raise
        # SettingWithCopyWarning; behavior kept as-is.
        group['Cluster'] = db.labels_
        cluster_labled_windows.append(group)

        # Label -1 is DBSCAN noise; summarize each real cluster.
        for cluster_id, cluster_group in group[group['Cluster'] != -1].groupby('Cluster'):
            expanded_dates = []
            for _, row in cluster_group.iterrows():
                dates = pd.date_range(row["Start"], row["End"])
                expanded_dates.extend(dates)

            # Cluster span = dates covered by >= 2 member windows (their
            # pairwise intersection region).
            date_counts = pd.Series(expanded_dates).value_counts().sort_index()
            cluster_start = date_counts[date_counts >= 2].index.min()
            cluster_end = date_counts[date_counts >= 2].index.max()

            interseced_clusters.append({
                'Chart Pattern': pattern,
                'Cluster': cluster_id,
                'Start': cluster_start,
                'End': cluster_end,
                'Seg_Start': cluster_group['Seg_Start'].iloc[0],
                'Seg_End': cluster_group['Seg_End'].iloc[0],
                'Avg_Probability': cluster_group['Probability'].mean(),
            })

    if len(cluster_labled_windows) == 0 or len(interseced_clusters) == 0:
        return None, None

    cluster_labled_windows_df = pd.concat(cluster_labled_windows)
    interseced_clusters_df = pd.DataFrame(interseced_clusters)
    # Restore the original row order disturbed by the groupby.
    cluster_labled_windows_df = cluster_labled_windows_df.sort_index()
    return cluster_labled_windows_df, interseced_clusters_df
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ========================= locate_patterns function ==========================
|
| 206 |
+
|
| 207 |
+
# This will be your primary inference function called by the HF endpoint.
|
| 208 |
+
class InferenceHandler:
    """Inference entry point that locates chart patterns in OHLCV data.

    Per request: run multi-scale sliding-window classification, cluster each
    scale's detections with DBSCAN, then de-duplicate overlapping detections
    across window sizes by IoU, preferring larger windows unless a smaller one
    is clearly more confident.
    """

    def __init__(self):
        # Model is loaded globally, so it's accessible here
        self.model = rocket_model
        if self.model is None:
            raise ValueError("ML model failed to load during initialization.")

        # Initialize other global parameters here as well
        # (snapshot of the module-level tuning constants).
        self.pattern_encoding_reversed = pattern_encoding_reversed
        self.win_size_proportions = win_size_proportions
        self.padding_proportion = padding_proportion
        self.stride = stride
        self.probab_threshold_list = probab_threshold_list
        self.prob_threshold_of_no_pattern_to_mark_as_no_pattern = prob_threshold_of_no_pattern_to_mark_as_no_pattern
        self.eps = eps
        self.min_samples = min_samples

    def __call__(self, inputs):
        """
        Main inference method for the Hugging Face Inference Endpoint.
        Args:
            inputs: A dictionary or list of dictionaries representing the input data.
                    For your case, this will be the OHLC data sent from Django.
                    Expected format: [{"Date": "YYYY-MM-DD", "Open": ..., "High": ..., ...}, ...]
        Returns:
            A list of dictionaries representing the detected patterns.
        Raises:
            ValueError: when the model is missing, the payload lacks required
                columns, or the resulting frame is empty.
        """
        if not self.model:
            raise ValueError("ML model is not loaded. Cannot perform inference.")

        # Ensure inputs is a list of dictionaries if not already
        if isinstance(inputs, dict):
            inputs = [inputs] # Handle single input dict if needed

        # Convert input (list of dicts) to pandas DataFrame
        try:
            ohlc_data = pd.DataFrame(inputs)
            # Ensure 'Date' is datetime, it might come as string from JSON
            ohlc_data['Date'] = pd.to_datetime(ohlc_data['Date'])
            # Ensure proper columns exist
            required_cols = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
            if not all(col in ohlc_data.columns for col in required_cols):
                raise ValueError(f"Missing required columns in input data. Expected: {required_cols}, Got: {ohlc_data.columns.tolist()}")

        except Exception as e:
            print(f"Error processing input data: {e}")
            raise ValueError(f"Invalid input data format: {e}")

        ohlc_data_segment = ohlc_data.copy()
        seg_len = len(ohlc_data_segment)

        if ohlc_data_segment.empty:
            raise ValueError("OHLC Data segment is empty or invalid after processing.")

        win_results_for_each_size = []
        located_patterns_and_other_info_for_each_size = []
        cluster_labled_windows_list = []

        # Deduplicate window sizes derived from the proportions, and keep a
        # running offset so cluster ids stay unique across window sizes.
        used_win_sizes = []
        win_iteration = 0

        for win_size_proportion in self.win_size_proportions:
            # Window size is the segment length divided by the proportion,
            # floored at 10 bars.
            window_size = seg_len // win_size_proportion
            if window_size < 10:
                window_size = 10
            window_size = int(window_size)
            if window_size in used_win_sizes:
                continue
            used_win_sizes.append(window_size)

            # Pass the globally loaded model `self.model`
            win_results_df = parallel_process_sliding_window(
                ohlc_data_segment,
                self.model,
                self.probab_threshold_list,
                self.stride,
                self.pattern_encoding_reversed,
                window_size,
                self.padding_proportion,
                self.prob_threshold_of_no_pattern_to_mark_as_no_pattern,
                parallel=True, # You might want to test with False/num_cores=1 on HF to avoid internal parallelism issues
                num_cores=-1 # -1 means all available cores; on HF, this will be limited by the instance type
            )

            if win_results_df is None or win_results_df.empty:
                print(f"Window results dataframe is empty for window size {window_size}")
                continue
            win_results_df['Window_Size'] = window_size
            win_results_for_each_size.append(win_results_df)

            predicted_patterns = prepare_dataset_for_cluster(ohlc_data_segment, win_results_df)
            if predicted_patterns is None or predicted_patterns.empty:
                print("Predicted patterns dataframe is empty")
                continue

            # Pass eps and min_samples from handler's state
            cluster_labled_windows_df , interseced_clusters_df = cluster_windows(
                predicted_patterns,
                self.probab_threshold_list,
                window_size,
                eps=self.eps,
                min_samples=self.min_samples
            )

            if cluster_labled_windows_df is None or interseced_clusters_df is None or cluster_labled_windows_df.empty or interseced_clusters_df.empty:
                print("Clustered windows dataframe is empty")
                continue
            # Offset cluster ids (except noise, -1) so ids from different
            # window sizes never collide.
            mask = cluster_labled_windows_df['Cluster'] != -1
            cluster_labled_windows_df.loc[mask, 'Cluster'] = cluster_labled_windows_df.loc[mask, 'Cluster'].astype(int) + win_iteration
            interseced_clusters_df['Cluster'] = interseced_clusters_df['Cluster'].astype(int) + win_iteration
            num_of_unique_clusters = interseced_clusters_df[interseced_clusters_df['Cluster']!=-1]['Cluster'].nunique()
            win_iteration += num_of_unique_clusters
            cluster_labled_windows_list.append(cluster_labled_windows_df)

            # Calc_Start/Calc_End mirror Start/End; downstream IoU filtering
            # operates on these columns.
            interseced_clusters_df['Calc_Start'] = interseced_clusters_df['Start']
            interseced_clusters_df['Calc_End'] = interseced_clusters_df['End']
            located_patterns_and_other_info = interseced_clusters_df.copy()

            if located_patterns_and_other_info is None or located_patterns_and_other_info.empty:
                print("Located patterns and other info dataframe is empty")
                continue
            located_patterns_and_other_info['Window_Size'] = window_size

            located_patterns_and_other_info_for_each_size.append(located_patterns_and_other_info)

        if located_patterns_and_other_info_for_each_size is None or not located_patterns_and_other_info_for_each_size:
            print("Located patterns and other info for each size is empty")
            return [] # Return empty list if no patterns found

        located_patterns_and_other_info_for_each_size_df = pd.concat(located_patterns_and_other_info_for_each_size)

        unique_window_sizes = located_patterns_and_other_info_for_each_size_df['Window_Size'].unique()
        unique_patterns = located_patterns_and_other_info_for_each_size_df['Chart Pattern'].unique()
        # Largest windows first: bigger detections get first claim on a span.
        unique_window_sizes = np.sort(unique_window_sizes)[::-1]

        filtered_loc_pat_and_info_rows_list = []

        # Cross-scale de-duplication: for each detection, look at same-pattern
        # detections whose spans overlap; drop this one when a detection at a
        # larger window size covers it (IoU > 0.6) unless this one is clearly
        # more confident (+0.1 Avg_Probability), and symmetrically defer to a
        # clearly more confident smaller/equal-size detection.
        for chart_pattern in unique_patterns:
            located_patterns_and_other_info_for_each_size_df_chart_pattern = located_patterns_and_other_info_for_each_size_df[located_patterns_and_other_info_for_each_size_df['Chart Pattern'] == chart_pattern]
            for win_size in unique_window_sizes:
                located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern = located_patterns_and_other_info_for_each_size_df_chart_pattern[located_patterns_and_other_info_for_each_size_df_chart_pattern['Window_Size'] == win_size]
                for idx , row in located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern.iterrows():
                    start_date = row['Calc_Start']
                    end_date = row['Calc_End']
                    is_already_included = False
                    # Candidate rows whose span overlaps [start_date, end_date]
                    # (includes the row itself; self-comparison is harmless).
                    intersecting_rows = located_patterns_and_other_info_for_each_size_df_chart_pattern[
                        (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_Start'] <= end_date) &
                        (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_End'] >= start_date)
                    ]
                    # (redundant re-initialization kept as-is)
                    is_already_included = False
                    for idx2, row2 in intersecting_rows.iterrows():
                        iou = intersection_over_union(start_date, end_date, row2['Calc_Start'], row2['Calc_End'])

                        if iou > 0.6:
                            if row2['Window_Size'] > row['Window_Size']:
                                if (row['Avg_Probability'] - row2['Avg_Probability']) > 0.1:
                                    is_already_included = False
                                else:
                                    is_already_included = True
                                    break
                            elif row['Window_Size'] >= row2['Window_Size']:
                                if (row2['Avg_Probability'] - row['Avg_Probability']) > 0.1:
                                    is_already_included = True
                                    break
                                else:
                                    is_already_included = False

                    if not is_already_included:
                        filtered_loc_pat_and_info_rows_list.append(row)

        filtered_loc_pat_and_info_df = pd.DataFrame(filtered_loc_pat_and_info_rows_list)

        # Convert datetime columns to string format for serialization before returning
        datetime_columns = ['Start', 'End', 'Seg_Start', 'Seg_End', 'Calc_Start', 'Calc_End']
        for col in datetime_columns:
            if col in filtered_loc_pat_and_info_df.columns:
                if pd.api.types.is_datetime64_any_dtype(filtered_loc_pat_and_info_df[col]):
                    filtered_loc_pat_and_info_df[col] = pd.to_datetime(filtered_loc_pat_and_info_df[col]).dt.strftime('%Y-%m-%d')
                elif not filtered_loc_pat_and_info_df[col].empty and isinstance(filtered_loc_pat_and_info_df[col].iloc[0], str):
                    pass
                else:
                    filtered_loc_pat_and_info_df[col] = filtered_loc_pat_and_info_df[col].astype(str)

        # Return as a list of dictionaries (JSON serializable)
        return filtered_loc_pat_and_info_df.to_dict('records')
|
requirements.txt
ADDED
|
Binary file (2.71 kB). View file
|
|
|
utils/__init__.py
ADDED
|
File without changes
|
utils/eval.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
# from matplotlib import pyplot as plt
|
| 3 |
+
# from matplotlib.gridspec import GridSpec
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def intersection_over_union(start1, end1, start2, end2):
    """
    Compute Intersection over Union (IoU) between two date ranges.

    Ranges are inclusive on both ends and measured in whole days; disjoint
    ranges yield 0, and a zero-length union also yields 0.
    """
    # Inclusive overlap in days; negative means the ranges are disjoint.
    inter_days = (min(end1, end2) - max(start1, start2)).days + 1
    if inter_days < 0:
        inter_days = 0

    # |A u B| = |A| + |B| - |A n B|, with inclusive lengths.
    len1 = (end1 - start1).days + 1
    len2 = (end2 - start2).days + 1
    union_days = len1 + len2 - inter_days

    if union_days > 0:
        return inter_days / union_days
    return 0  # Avoid division by zero
|
| 17 |
+
|
| 18 |
+
def mean_abselute_error(start1, end1, start2, end2):
    """
    Compute Mean Absolute Error (MAE) between two date ranges.

    The MAE is the average, in days, of the start-date error and the
    end-date error. Returns None (after logging) when any date is missing.

    Note: the function name's spelling is kept for caller compatibility.
    """
    # BUG FIX: `x is pd.NaT` only catches the pandas NaT singleton; pd.isna
    # also catches np.datetime64('NaT'), None and float NaN.
    if pd.isna(start1) or pd.isna(end1) or pd.isna(start2) or pd.isna(end2):
        print("One of the dates is NaT")
        print(f"start1: {start1}, end1: {end1}, start2: {start2}, end2: {end2}")
        return None
    return (abs(start1 - start2).days + abs(end1 - end2).days) / 2
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_model_eval_res(located_patterns_and_other_info_updated_dict,window_results_dict,selected_models,selected_test_patterns_without_no_pattern):
    """Score each model's located patterns against the ground-truth test set.

    For every model and every ground-truth pattern occurrence, a detection
    counts as "properly located" when at least one same-symbol, same-pattern
    detection reaches IoU >= 0.25. For each hit, the best IoU and the best
    (lowest) MAE among the qualifying detections are accumulated per class.

    Args:
        located_patterns_and_other_info_updated_dict: {model_name: DataFrame}
            of detections ('Symbol', 'Chart Pattern', 'Calc_Start', 'Calc_End').
        window_results_dict: {model_name: DataFrame}; looked up per model but
            not otherwise used here (kept for the KeyError side effect /
            interface compatibility).
        selected_models: model names to evaluate.
        selected_test_patterns_without_no_pattern: ground-truth DataFrame with
            'Symbol', 'Chart Pattern', 'Start', 'End'.

    Returns:
        {model_name: {'number_of_properly_located_patterns': {pattern: int},
                      'iou_for_each_properly_detected_pattern': {pattern: float},
                      'mae_for_each_properly_detected_pattern': {pattern: float}}}
        where IoU/MAE values are per-class sums (divide by the count to get means).
    """
    model_eval_results_dict = {}
    for model_name in selected_models:
        print(f"\n Selected model: {model_name}")

        located_patterns_and_other_info_updated_df = located_patterns_and_other_info_updated_dict[model_name]
        window_results_df = window_results_dict[model_name]

        # Per-class accumulators for this model.
        number_of_properly_located_patterns = {}
        iou_for_each_properly_detected_pattern = {}
        mae_for_each_properly_detected_pattern = {}

        # Convert date columns to datetime (once, outside the loop for efficiency)
        located_patterns_and_other_info_updated_df['Calc_Start'] = pd.to_datetime(located_patterns_and_other_info_updated_df['Calc_Start'])
        located_patterns_and_other_info_updated_df['Calc_End'] = pd.to_datetime(located_patterns_and_other_info_updated_df['Calc_End'])

        # Iterate over test patterns with a lightweight progress line
        for index, row in selected_test_patterns_without_no_pattern.iterrows():
            sys.stdout.write(f"\rProcessing row {index + 1}/{len(selected_test_patterns_without_no_pattern)}")
            sys.stdout.flush()
            symbol = row['Symbol']
            chart_pattern = row['Chart Pattern']
            start_date = pd.to_datetime(row['Start']).tz_localize(None)
            end_date = pd.to_datetime(row['End']).tz_localize(None)

            # Filter for matching symbol and chart pattern
            located_patterns_for_this = located_patterns_and_other_info_updated_df[
                (located_patterns_and_other_info_updated_df['Symbol'] == symbol) &
                (located_patterns_and_other_info_updated_df['Chart Pattern'] == chart_pattern)
            ].copy() # Use `.copy()` to avoid SettingWithCopyWarning

            if located_patterns_for_this.empty:
                continue # Skip if no matching rows

            # Compute IoU for each candidate detection
            located_patterns_for_this.loc[:, 'IoU'] = located_patterns_for_this.apply(
                lambda x: intersection_over_union(start_date, end_date, x['Calc_Start'], x['Calc_End']),
                axis=1
            )

            # Compute MAE for each candidate detection
            located_patterns_for_this.loc[:, 'MAE'] = located_patterns_for_this.apply(
                lambda x: mean_abselute_error(start_date, end_date, x['Calc_Start'], x['Calc_End']),
                axis=1
            )

            # Keep detections that pass the IoU acceptance threshold (>= 0.25)
            located_patterns_for_this_proper = located_patterns_for_this[located_patterns_for_this['IoU'] >= 0.25]

            if not located_patterns_for_this_proper.empty:
                number_of_properly_located_patterns[chart_pattern] = number_of_properly_located_patterns.get(chart_pattern, 0) + 1
                iou_for_each_properly_detected_pattern[chart_pattern] = iou_for_each_properly_detected_pattern.get(chart_pattern, 0) + max(located_patterns_for_this_proper['IoU'])
                mae_for_each_properly_detected_pattern[chart_pattern] = mae_for_each_properly_detected_pattern.get(chart_pattern, 0) + min(located_patterns_for_this_proper['MAE'])

        model_eval_results_dict[model_name] = {
            'number_of_properly_located_patterns': number_of_properly_located_patterns,
            'iou_for_each_properly_detected_pattern': iou_for_each_properly_detected_pattern,
            'mae_for_each_properly_detected_pattern': mae_for_each_properly_detected_pattern
        }
    return model_eval_results_dict
|
| 97 |
+
|
| 98 |
+
############################################################################################
|
| 99 |
+
# Evaluate multiple models and plot
|
| 100 |
+
############################################################################################
|
| 101 |
+
# Commenting out plotting functions
|
| 102 |
+
"""
|
| 103 |
+
def create_comprehensive_model_comparison(all_models_metrics):
|
| 104 |
+
|
| 105 |
+
Create a comprehensive visualization comparing all models across all metrics,
|
| 106 |
+
using nested concentric pie charts for Precision and Recall.
|
| 107 |
+
|
| 108 |
+
Parameters:
|
| 109 |
+
-----------
|
| 110 |
+
all_models_metrics : dict
|
| 111 |
+
Dictionary containing metrics for each model
|
| 112 |
+
|
| 113 |
+
models = list(all_models_metrics.keys())
|
| 114 |
+
n_models = len(models)
|
| 115 |
+
|
| 116 |
+
# Define the metrics to include
|
| 117 |
+
key_metrics = {
|
| 118 |
+
'total_recall': 'Recall',
|
| 119 |
+
'total_precision': 'Precision',
|
| 120 |
+
'overall_f1': 'F1 Score',
|
| 121 |
+
'overall_iou': 'IoU',
|
| 122 |
+
'overall_mae': 'MAE'
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
# Create figure with GridSpec for flexible layout
|
| 126 |
+
fig = plt.figure(figsize=(20, 14))
|
| 127 |
+
|
| 128 |
+
# Add main title with enough space for legend below it
|
| 129 |
+
plt.suptitle('Comprehensive Model Evaluation', fontsize=16, y=0.98)
|
| 130 |
+
|
| 131 |
+
# Define a color palette for models
|
| 132 |
+
colors = plt.cm.tab10(np.linspace(0, 1, n_models))
|
| 133 |
+
|
| 134 |
+
# Create a master legend below the title
|
| 135 |
+
legend_handles = [plt.Line2D([0], [0], color=colors[i], lw=4, label=model) for i, model in enumerate(models)]
|
| 136 |
+
fig.legend(
|
| 137 |
+
handles=legend_handles,
|
| 138 |
+
labels=models,
|
| 139 |
+
loc='upper center',
|
| 140 |
+
bbox_to_anchor=(0.5, 0.93), # Moved down from 0.98 to 0.93
|
| 141 |
+
ncol=n_models,
|
| 142 |
+
fontsize=12
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Adjust GridSpec to account for the title and legend
|
| 146 |
+
gs = GridSpec(3, 3, figure=fig, height_ratios=[1.2, 1.2, 1], top=0.88) # Reduced top from 0.95 to 0.88
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# 1. Precision Nested Pie Chart - top left
|
| 151 |
+
ax1 = fig.add_subplot(gs[0, 0])
|
| 152 |
+
|
| 153 |
+
# Create a multi-layer nested pie chart for precision
|
| 154 |
+
# Each ring represents a different model
|
| 155 |
+
precision_values = [metrics['total_precision'] for metrics in all_models_metrics.values()]
|
| 156 |
+
|
| 157 |
+
# Calculate radii for each ring (outermost ring is largest)
|
| 158 |
+
radii = np.linspace(0.5, 1.0, n_models+1)[1:] # start from second element to skip 0.5
|
| 159 |
+
|
| 160 |
+
# Plot each model as a ring, outermost = first model
|
| 161 |
+
for i, model in enumerate(models):
|
| 162 |
+
# Create data for this model's ring [precision, 1-precision]
|
| 163 |
+
data = [precision_values[i], 1-precision_values[i]]
|
| 164 |
+
colors_ring = [colors[i], 'lightgray']
|
| 165 |
+
|
| 166 |
+
# Create pie chart for this ring
|
| 167 |
+
wedges, texts = ax1.pie(
|
| 168 |
+
data,
|
| 169 |
+
radius=radii[i],
|
| 170 |
+
colors=colors_ring,
|
| 171 |
+
startangle=90,
|
| 172 |
+
counterclock=False,
|
| 173 |
+
wedgeprops=dict(width=0.15, edgecolor='w')
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# Add only the value (no model name) to the pie chart wedge
|
| 177 |
+
angle = (wedges[0].theta1 + wedges[0].theta2) / 2
|
| 178 |
+
x = (radii[i] - 0.075) * np.cos(np.radians(angle))
|
| 179 |
+
y = (radii[i] - 0.075) * np.sin(np.radians(angle))
|
| 180 |
+
ax1.text(x, y, f"{precision_values[i]:.3f}",
|
| 181 |
+
ha='center', va='center', fontsize=10, fontweight='bold')
|
| 182 |
+
|
| 183 |
+
# Create center circle for donut effect
|
| 184 |
+
centre_circle = plt.Circle((0, 0), 0.25, fc='white')
|
| 185 |
+
ax1.add_patch(centre_circle)
|
| 186 |
+
|
| 187 |
+
ax1.set_title('Precision Comparison (Higher is Better)')
|
| 188 |
+
ax1.set_aspect('equal')
|
| 189 |
+
|
| 190 |
+
# 2. Recall Nested Pie Chart - top middle
|
| 191 |
+
ax2 = fig.add_subplot(gs[0, 1])
|
| 192 |
+
|
| 193 |
+
# Create a multi-layer nested pie chart for recall
|
| 194 |
+
recall_values = [metrics['total_recall'] for metrics in all_models_metrics.values()]
|
| 195 |
+
|
| 196 |
+
# Plot each model as a ring, outermost = first model
|
| 197 |
+
for i, model in enumerate(models):
|
| 198 |
+
# Create data for this model's ring [recall, 1-recall]
|
| 199 |
+
data = [recall_values[i], 1-recall_values[i]]
|
| 200 |
+
colors_ring = [colors[i], 'lightgray']
|
| 201 |
+
|
| 202 |
+
# Create pie chart for this ring
|
| 203 |
+
wedges, texts = ax2.pie(
|
| 204 |
+
data,
|
| 205 |
+
radius=radii[i],
|
| 206 |
+
colors=colors_ring,
|
| 207 |
+
startangle=90,
|
| 208 |
+
counterclock=False,
|
| 209 |
+
wedgeprops=dict(width=0.15, edgecolor='w')
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Add only the value (no model name) to the pie chart wedge
|
| 213 |
+
angle = (wedges[0].theta1 + wedges[0].theta2) / 2
|
| 214 |
+
x = (radii[i] - 0.075) * np.cos(np.radians(angle))
|
| 215 |
+
y = (radii[i] - 0.075) * np.sin(np.radians(angle))
|
| 216 |
+
ax2.text(x, y, f"{recall_values[i]:.3f}",
|
| 217 |
+
ha='center', va='center', fontsize=10, fontweight='bold')
|
| 218 |
+
|
| 219 |
+
# Create center circle for donut effect
|
| 220 |
+
centre_circle = plt.Circle((0, 0), 0.25, fc='white')
|
| 221 |
+
ax2.add_patch(centre_circle)
|
| 222 |
+
|
| 223 |
+
ax2.set_title('Recall Comparison (Higher is Better)')
|
| 224 |
+
ax2.set_aspect('equal')
|
| 225 |
+
|
| 226 |
+
# 3. F1 Score and IoU - top right
|
| 227 |
+
ax3 = fig.add_subplot(gs[0, 2])
|
| 228 |
+
|
| 229 |
+
# Prepare data for grouped bar chart
|
| 230 |
+
metrics_to_plot = ['overall_f1', 'overall_iou']
|
| 231 |
+
x = np.arange(len(metrics_to_plot))
|
| 232 |
+
width = 0.8 / n_models
|
| 233 |
+
|
| 234 |
+
# Plot grouped bars for each model
|
| 235 |
+
for i, (model_name, metrics) in enumerate(all_models_metrics.items()):
|
| 236 |
+
values = [metrics[key] for key in metrics_to_plot]
|
| 237 |
+
bars = ax3.bar(x + i*width - width*(n_models-1)/2, values, width, color=colors[i])
|
| 238 |
+
|
| 239 |
+
# Add value labels above each bar
|
| 240 |
+
for bar, value in zip(bars, values):
|
| 241 |
+
height = bar.get_height()
|
| 242 |
+
ax3.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 243 |
+
f'{value:.3f}', ha='center', va='bottom', fontsize=9, rotation=0)
|
| 244 |
+
|
| 245 |
+
# Customize the plot
|
| 246 |
+
ax3.set_xticks(x)
|
| 247 |
+
ax3.set_xticklabels([key_metrics[key] for key in metrics_to_plot])
|
| 248 |
+
ax3.set_ylabel('Score')
|
| 249 |
+
ax3.set_title('F1 Score & IoU Comparison (Higher is Better)')
|
| 250 |
+
ax3.set_ylim(0, 1.0)
|
| 251 |
+
ax3.grid(axis='y', linestyle='--', alpha=0.7)
|
| 252 |
+
|
| 253 |
+
# 4. MAE comparison (separate bar chart) - middle left
|
| 254 |
+
ax4 = fig.add_subplot(gs[1, 0])
|
| 255 |
+
|
| 256 |
+
mae_values = [metrics['overall_mae'] for metrics in all_models_metrics.values()]
|
| 257 |
+
bars = ax4.bar(models, mae_values, color=colors)
|
| 258 |
+
|
| 259 |
+
# Add value labels above MAE bars
|
| 260 |
+
for bar, value in zip(bars, mae_values):
|
| 261 |
+
height = bar.get_height()
|
| 262 |
+
ax4.text(bar.get_x() + bar.get_width()/2., height + 0.01,
|
| 263 |
+
f'{value:.3f}', ha='center', va='bottom', fontsize=9)
|
| 264 |
+
|
| 265 |
+
ax4.set_ylabel('Error')
|
| 266 |
+
ax4.set_title('Mean Absolute Error (Lower is Better)')
|
| 267 |
+
ax4.grid(axis='y', linestyle='--', alpha=0.7)
|
| 268 |
+
|
| 269 |
+
# 5. Model metrics radar chart - middle center
|
| 270 |
+
ax5 = fig.add_subplot(gs[1, 1], polar=True)
|
| 271 |
+
|
| 272 |
+
# Setup for radar chart
|
| 273 |
+
metrics_for_radar = ['total_recall', 'total_precision', 'overall_f1', 'overall_iou']
|
| 274 |
+
num_vars = len(metrics_for_radar)
|
| 275 |
+
angles = np.linspace(0, 2*np.pi, num_vars, endpoint=False).tolist()
|
| 276 |
+
angles += angles[:1] # Close the loop
|
| 277 |
+
|
| 278 |
+
# Plot each model on the radar chart
|
| 279 |
+
for i, (model_name, metrics) in enumerate(all_models_metrics.items()):
|
| 280 |
+
values = [metrics[metric] for metric in metrics_for_radar]
|
| 281 |
+
values += values[:1] # Close the loop
|
| 282 |
+
|
| 283 |
+
ax5.plot(angles, values, linewidth=2, linestyle='solid', color=colors[i])
|
| 284 |
+
ax5.fill(angles, values, alpha=0.1, color=colors[i])
|
| 285 |
+
|
| 286 |
+
# Set radar chart labels
|
| 287 |
+
ax5.set_xticks(angles[:-1])
|
| 288 |
+
ax5.set_xticklabels([key_metrics[metric] for metric in metrics_for_radar])
|
| 289 |
+
ax5.set_ylim(0, 1)
|
| 290 |
+
ax5.set_title('Model Performance Radar Chart')
|
| 291 |
+
|
| 292 |
+
# 6. Model comparison bar - middle right
|
| 293 |
+
ax6 = fig.add_subplot(gs[1, 2])
|
| 294 |
+
|
| 295 |
+
# Calculate the average of the four main metrics for an overall score
|
| 296 |
+
# (excluding MAE which is inverse, lower is better)
|
| 297 |
+
overall_scores = []
|
| 298 |
+
for model_name, metrics in all_models_metrics.items():
|
| 299 |
+
score = (metrics['total_recall'] + metrics['total_precision'] +
|
| 300 |
+
metrics['overall_f1'] + metrics['overall_iou']) / 4
|
| 301 |
+
overall_scores.append(score)
|
| 302 |
+
|
| 303 |
+
# Create horizontal bar chart
|
| 304 |
+
y_pos = np.arange(len(models))
|
| 305 |
+
ax6.barh(y_pos, overall_scores, color=colors)
|
| 306 |
+
ax6.set_yticks(y_pos)
|
| 307 |
+
ax6.set_yticklabels(models)
|
| 308 |
+
ax6.invert_yaxis() # labels read top-to-bottom
|
| 309 |
+
ax6.set_xlabel('Overall Performance Score')
|
| 310 |
+
ax6.set_title('Overall Model Comparison (Higher is Better)')
|
| 311 |
+
|
| 312 |
+
# Add value labels
|
| 313 |
+
for i, v in enumerate(overall_scores):
|
| 314 |
+
ax6.text(v + 0.01, i, f'{v:.3f}', va='center')
|
| 315 |
+
|
| 316 |
+
# 7. Detailed per-model metrics table - bottom span all columns
|
| 317 |
+
ax7 = fig.add_subplot(gs[2, :])
|
| 318 |
+
ax7.axis('tight')
|
| 319 |
+
ax7.axis('off')
|
| 320 |
+
|
| 321 |
+
# Prepare table data
|
| 322 |
+
table_data = []
|
| 323 |
+
for model_name, metrics in all_models_metrics.items():
|
| 324 |
+
row = [model_name]
|
| 325 |
+
for key in key_metrics:
|
| 326 |
+
row.append(f"{metrics[key]:.4f}")
|
| 327 |
+
table_data.append(row)
|
| 328 |
+
|
| 329 |
+
# Create table
|
| 330 |
+
column_labels = ['Model'] + list(key_metrics.values())
|
| 331 |
+
table = ax7.table(
|
| 332 |
+
cellText=table_data,
|
| 333 |
+
colLabels=column_labels,
|
| 334 |
+
loc='center',
|
| 335 |
+
cellLoc='center'
|
| 336 |
+
)
|
| 337 |
+
table.auto_set_font_size(False)
|
| 338 |
+
table.set_fontsize(10)
|
| 339 |
+
table.scale(1, 1.5)
|
| 340 |
+
ax7.set_title('Model Metrics Summary Table')
|
| 341 |
+
|
| 342 |
+
plt.tight_layout(rect=[0, 0.03, 1, 0.88]) # Adjusted rect to account for title and legend
|
| 343 |
+
|
| 344 |
+
plt.show()
|
| 345 |
+
|
| 346 |
+
return fig
|
| 347 |
+
|
| 348 |
+
# The evaluate_model and evaluate_all_models functions remain unchanged
|
| 349 |
+
# The evaluate_model and evaluate_all_models functions remain unchanged
|
| 350 |
+
# The evaluate_model function remains unchanged from your second code snippet
|
| 351 |
+
def evaluate_model(model_name, model_eval_results_dict, pattern_row_count, test_patterns, located_patterns_and_other_info_updated_dict):
|
| 352 |
+
Evaluate a model and calculate metrics without redundant plots
|
| 353 |
+
print(f"\n{'='*20} Model: {model_name} {'='*20}")
|
| 354 |
+
|
| 355 |
+
# Extract model results
|
| 356 |
+
number_of_properly_located_patterns = model_eval_results_dict[model_name]['number_of_properly_located_patterns']
|
| 357 |
+
located_patterns_df = located_patterns_and_other_info_updated_dict[model_name]
|
| 358 |
+
mae_for_each_properly_detected_pattern = model_eval_results_dict[model_name]['mae_for_each_properly_detected_pattern']
|
| 359 |
+
iou_for_each_properly_detected_pattern = model_eval_results_dict[model_name]['iou_for_each_properly_detected_pattern']
|
| 360 |
+
|
| 361 |
+
# Calculate metrics without plotting
|
| 362 |
+
# Recall
|
| 363 |
+
total_number_of_all_patterns = sum(pattern_row_count.values())
|
| 364 |
+
total_number_of_properly_located_patterns = sum(number_of_properly_located_patterns.values())
|
| 365 |
+
total_recall = total_number_of_properly_located_patterns / total_number_of_all_patterns if total_number_of_all_patterns > 0 else 0
|
| 366 |
+
|
| 367 |
+
per_pattern_recall = {}
|
| 368 |
+
for pattern, count in number_of_properly_located_patterns.items():
|
| 369 |
+
pattern_count = test_patterns[test_patterns['Chart Pattern'] == pattern].shape[0]
|
| 370 |
+
if pattern_count > 0:
|
| 371 |
+
per_pattern_recall[pattern] = count / pattern_count
|
| 372 |
+
else:
|
| 373 |
+
per_pattern_recall[pattern] = 0
|
| 374 |
+
|
| 375 |
+
# Precision
|
| 376 |
+
total_number_of_all_located_patterns = len(located_patterns_df)
|
| 377 |
+
total_precision = total_number_of_properly_located_patterns / total_number_of_all_located_patterns if total_number_of_all_located_patterns > 0 else 0
|
| 378 |
+
|
| 379 |
+
per_pattern_precision = {}
|
| 380 |
+
for pattern, count in number_of_properly_located_patterns.items():
|
| 381 |
+
pattern_predictions = located_patterns_df[located_patterns_df['Chart Pattern'] == pattern].shape[0]
|
| 382 |
+
if pattern_predictions > 0:
|
| 383 |
+
per_pattern_precision[pattern] = count / pattern_predictions
|
| 384 |
+
else:
|
| 385 |
+
per_pattern_precision[pattern] = 0
|
| 386 |
+
|
| 387 |
+
# F1 Score
|
| 388 |
+
per_pattern_f1 = {}
|
| 389 |
+
for pattern in per_pattern_recall.keys():
|
| 390 |
+
precision = per_pattern_precision.get(pattern, 0)
|
| 391 |
+
recall = per_pattern_recall.get(pattern, 0)
|
| 392 |
+
if precision + recall > 0:
|
| 393 |
+
per_pattern_f1[pattern] = 2 * (precision * recall) / (precision + recall)
|
| 394 |
+
else:
|
| 395 |
+
per_pattern_f1[pattern] = 0
|
| 396 |
+
|
| 397 |
+
all_precisions = list(per_pattern_precision.values())
|
| 398 |
+
all_recalls = list(per_pattern_recall.values())
|
| 399 |
+
avg_precision = sum(all_precisions) / len(all_precisions) if all_precisions else 0
|
| 400 |
+
avg_recall = sum(all_recalls) / len(all_recalls) if all_recalls else 0
|
| 401 |
+
|
| 402 |
+
if avg_precision + avg_recall == 0:
|
| 403 |
+
overall_f1 = 0
|
| 404 |
+
else:
|
| 405 |
+
overall_f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall)
|
| 406 |
+
|
| 407 |
+
# MAE
|
| 408 |
+
per_pattern_mae = {}
|
| 409 |
+
for pattern, count in number_of_properly_located_patterns.items():
|
| 410 |
+
if count > 0:
|
| 411 |
+
per_pattern_mae[pattern] = mae_for_each_properly_detected_pattern.get(pattern, 0) / count
|
| 412 |
+
else:
|
| 413 |
+
per_pattern_mae[pattern] = 0
|
| 414 |
+
|
| 415 |
+
total_mae_sum = sum(mae_for_each_properly_detected_pattern.values())
|
| 416 |
+
total_proper_patterns = sum(number_of_properly_located_patterns.values())
|
| 417 |
+
overall_mae = total_mae_sum / total_proper_patterns if total_proper_patterns > 0 else 0
|
| 418 |
+
|
| 419 |
+
# IoU
|
| 420 |
+
per_pattern_iou = {}
|
| 421 |
+
for pattern, count in number_of_properly_located_patterns.items():
|
| 422 |
+
if count > 0:
|
| 423 |
+
per_pattern_iou[pattern] = iou_for_each_properly_detected_pattern.get(pattern, 0) / count
|
| 424 |
+
else:
|
| 425 |
+
per_pattern_iou[pattern] = 0
|
| 426 |
+
|
| 427 |
+
total_iou_sum = sum(iou_for_each_properly_detected_pattern.values())
|
| 428 |
+
overall_iou = total_iou_sum / total_proper_patterns if total_proper_patterns > 0 else 0
|
| 429 |
+
|
| 430 |
+
# Print summary of metrics
|
| 431 |
+
print(f"Overall Recall: {total_recall:.4f}")
|
| 432 |
+
print(f"Overall Precision: {total_precision:.4f}")
|
| 433 |
+
print(f"Overall F1 Score: {overall_f1:.4f}")
|
| 434 |
+
print(f"Overall Mean Absolute Error: {overall_mae:.4f}")
|
| 435 |
+
print(f"Overall Mean Intersection over Union: {overall_iou:.4f}")
|
| 436 |
+
|
| 437 |
+
# Store all metrics in one place for easy access
|
| 438 |
+
metrics_summary = {
|
| 439 |
+
'total_recall': total_recall,
|
| 440 |
+
'per_pattern_recall': per_pattern_recall,
|
| 441 |
+
'total_precision': total_precision,
|
| 442 |
+
'per_pattern_precision': per_pattern_precision,
|
| 443 |
+
'overall_f1': overall_f1,
|
| 444 |
+
'per_pattern_f1': per_pattern_f1,
|
| 445 |
+
'overall_mae': overall_mae,
|
| 446 |
+
'per_pattern_mae': per_pattern_mae,
|
| 447 |
+
'overall_iou': overall_iou,
|
| 448 |
+
'per_pattern_iou': per_pattern_iou
|
| 449 |
+
}
|
| 450 |
+
|
| 451 |
+
return metrics_summary
|
| 452 |
+
|
| 453 |
+
# Updated evaluate_all_models function that only creates the comprehensive plot
|
| 454 |
+
def evaluate_all_models(model_eval_results_dict, pattern_row_count, test_patterns, located_patterns_and_other_info_updated_dict):
|
| 455 |
+
Evaluate all models and return metrics summary with comprehensive plot only
|
| 456 |
+
all_models_metrics = {}
|
| 457 |
+
|
| 458 |
+
for model_name in model_eval_results_dict.keys():
|
| 459 |
+
all_models_metrics[model_name] = evaluate_model(
|
| 460 |
+
model_name,
|
| 461 |
+
model_eval_results_dict,
|
| 462 |
+
pattern_row_count,
|
| 463 |
+
test_patterns,
|
| 464 |
+
located_patterns_and_other_info_updated_dict
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# Only create the comprehensive visualization
|
| 468 |
+
if len(model_eval_results_dict) > 0:
|
| 469 |
+
print("\n--- Comprehensive Model Comparison ---")
|
| 470 |
+
# figure = create_comprehensive_model_comparison(all_models_metrics)
|
| 471 |
+
|
| 472 |
+
return all_models_metrics, None # Return None instead of figure
|
| 473 |
+
"""
|
| 474 |
+
###########################################################################################################
|
utils/formatAndPreprocessNewPatterns.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import the necessary libraries
|
| 2 |
+
from multiprocessing import Manager, Value
|
| 3 |
+
import os
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from joblib import Parallel, delayed
|
| 7 |
+
import math
|
| 8 |
+
from scipy import interpolate
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from utils.drawPlots import plot_ohlc_segment
|
| 12 |
+
|
| 13 |
+
original_pattern_name_list = [
|
| 14 |
+
'Double Top, Adam and Adam',
|
| 15 |
+
'Double Top, Adam and Eve',
|
| 16 |
+
'Double Top, Eve and Eve',
|
| 17 |
+
'Double Top, Eve and Adam',
|
| 18 |
+
'Double Bottom, Adam and Adam',
|
| 19 |
+
'Double Bottom, Eve and Adam',
|
| 20 |
+
'Double Bottom, Eve and Eve',
|
| 21 |
+
'Double Bottom, Adam and Eve',
|
| 22 |
+
'Triangle, symmetrical',
|
| 23 |
+
'Head-and-shoulders top',
|
| 24 |
+
'Head-and-shoulders bottom',
|
| 25 |
+
'Flag, high and tight'
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
# Updated pattern encoding
|
| 29 |
+
pattern_encoding = {
|
| 30 |
+
'Double Top': 0,
|
| 31 |
+
'Double Bottom': 1,
|
| 32 |
+
'Triangle, symmetrical': 2,
|
| 33 |
+
'Head-and-shoulders top': 3,
|
| 34 |
+
'Head-and-shoulders bottom': 4,
|
| 35 |
+
'Flag, high and tight': 5,
|
| 36 |
+
'No Pattern': 6
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
def get_pattern_encoding():
|
| 40 |
+
return pattern_encoding
|
| 41 |
+
|
| 42 |
+
def get_reverse_pattern_encoding():
|
| 43 |
+
return {v: k for k, v in pattern_encoding.items()}
|
| 44 |
+
|
| 45 |
+
def get_patetrn_name_by_encoding(encoding):
|
| 46 |
+
"""
|
| 47 |
+
Get the pattern name by encoding.
|
| 48 |
+
|
| 49 |
+
# Input:
|
| 50 |
+
- encoding (int): The encoding of the pattern.
|
| 51 |
+
|
| 52 |
+
# Returns:
|
| 53 |
+
- str: The name of the pattern.
|
| 54 |
+
"""
|
| 55 |
+
return get_reverse_pattern_encoding().get(encoding, 'Unknown Pattern')
|
| 56 |
+
|
| 57 |
+
def get_pattern_encoding_by_name(name):
|
| 58 |
+
"""
|
| 59 |
+
Get the pattern encoding by name.
|
| 60 |
+
|
| 61 |
+
# Input:
|
| 62 |
+
- name (str): The name of the pattern.
|
| 63 |
+
|
| 64 |
+
# Returns:
|
| 65 |
+
- int: The encoding of the pattern.
|
| 66 |
+
"""
|
| 67 |
+
return get_pattern_encoding().get(name, -1)
|
| 68 |
+
|
| 69 |
+
def get_pattern_list():
|
| 70 |
+
return list(pattern_encoding.keys())
|
| 71 |
+
|
| 72 |
+
def filter_to_get_selected_patterns(df):
|
| 73 |
+
# Filter dataframe to only include selected patterns
|
| 74 |
+
df = df[df['Chart Pattern'].isin(original_pattern_name_list)].copy() # Explicit copy to avoid warning
|
| 75 |
+
|
| 76 |
+
# Replace all variations of Double Top and Double Bottom with simplified names
|
| 77 |
+
double_top_variations = {
|
| 78 |
+
'Double Top, Adam and Adam': 'Double Top',
|
| 79 |
+
'Double Top, Adam and Eve': 'Double Top',
|
| 80 |
+
'Double Top, Eve and Eve': 'Double Top',
|
| 81 |
+
'Double Top, Eve and Adam': 'Double Top'
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
double_bottom_variations = {
|
| 85 |
+
'Double Bottom, Adam and Adam': 'Double Bottom',
|
| 86 |
+
'Double Bottom, Eve and Adam': 'Double Bottom',
|
| 87 |
+
'Double Bottom, Eve and Eve': 'Double Bottom',
|
| 88 |
+
'Double Bottom, Adam and Eve': 'Double Bottom'
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
# Combine all variations into a single mapping
|
| 92 |
+
pattern_mapping = {**double_top_variations, **double_bottom_variations}
|
| 93 |
+
|
| 94 |
+
# Use .loc[] to modify the dataframe safely
|
| 95 |
+
df.loc[:, 'Chart Pattern'] = df['Chart Pattern'].replace(pattern_mapping)
|
| 96 |
+
|
| 97 |
+
return df
|
| 98 |
+
|
| 99 |
+
def normalize_dataset(dataset):
|
| 100 |
+
# calculate the min values from Low column and max values from High column for each instance
|
| 101 |
+
min_low = dataset.groupby(level='Instance')['Low'].transform('min')
|
| 102 |
+
max_high = dataset.groupby(level='Instance')['High'].transform('max')
|
| 103 |
+
|
| 104 |
+
# OHLC columns to normalize
|
| 105 |
+
ohlc_columns = ['Open', 'High', 'Low', 'Close']
|
| 106 |
+
|
| 107 |
+
dataset_normalized = dataset.copy()
|
| 108 |
+
|
| 109 |
+
# Apply the normalization formula to all columns in one go
|
| 110 |
+
dataset_normalized[ohlc_columns] = (dataset_normalized[ohlc_columns] - min_low.values[:, None]) / (max_high.values[:, None] - min_low.values[:, None])
|
| 111 |
+
|
| 112 |
+
# if there is a Volume column normalize it
|
| 113 |
+
if 'Volume' in dataset.columns:
|
| 114 |
+
# calculate the min values from Volume column and max values from Volume column for each instance
|
| 115 |
+
min_volume = dataset.groupby(level='Instance')['Volume'].transform('min')
|
| 116 |
+
max_volume = dataset.groupby(level='Instance')['Volume'].transform('max')
|
| 117 |
+
|
| 118 |
+
# Normalize the Volume column
|
| 119 |
+
dataset_normalized['Volume'] = (dataset_normalized['Volume'] - min_volume.values) / (max_volume.values - min_volume)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
return dataset_normalized
|
| 123 |
+
|
| 124 |
+
def normalize_ohlc_segment(dataset):
|
| 125 |
+
# calculate the min values from Low column and max values from High column for each instance
|
| 126 |
+
min_low = dataset['Low'].min()
|
| 127 |
+
max_high = dataset['High'].max()
|
| 128 |
+
|
| 129 |
+
# OHLC columns to normalize
|
| 130 |
+
ohlc_columns = ['Open', 'High', 'Low', 'Close']
|
| 131 |
+
|
| 132 |
+
dataset_normalized = dataset.copy()
|
| 133 |
+
|
| 134 |
+
if (max_high - min_low) != 0:
|
| 135 |
+
# Apply the normalization formula to all columns in one go
|
| 136 |
+
dataset_normalized[ohlc_columns] = (dataset_normalized[ohlc_columns] - min_low) / (max_high - min_low)
|
| 137 |
+
else :
|
| 138 |
+
print("Error: Max high and min low are equal")
|
| 139 |
+
|
| 140 |
+
# if there is a Volume column normalize it
|
| 141 |
+
if 'Volume' in dataset.columns:
|
| 142 |
+
# calculate the min values from Volume column and max values from Volume column for each instance
|
| 143 |
+
min_volume = dataset['Volume'].min()
|
| 144 |
+
max_volume = dataset['Volume'].max()
|
| 145 |
+
|
| 146 |
+
if (max_volume - min_volume) != 0:
|
| 147 |
+
# Normalize the Volume column
|
| 148 |
+
dataset_normalized['Volume'] = (dataset_normalized['Volume'] - min_volume) / (max_volume - min_volume)
|
| 149 |
+
else:
|
| 150 |
+
print("Error: Max volume and min volume are equal")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
return dataset_normalized
|
| 154 |
+
|
| 155 |
+
def process_row_improved(idx, row, ohlc_df, instance_counter, lock, successful_instances, instance_index_mapping):
|
| 156 |
+
try:
|
| 157 |
+
# Extract info and filter data
|
| 158 |
+
start_date = pd.to_datetime(row['Start'])
|
| 159 |
+
end_date = pd.to_datetime(row['End'])
|
| 160 |
+
|
| 161 |
+
symbol_df_filtered = ohlc_df[(ohlc_df['Date'] >= start_date) &
|
| 162 |
+
(ohlc_df['Date'] <= end_date)]
|
| 163 |
+
|
| 164 |
+
if symbol_df_filtered.empty:
|
| 165 |
+
print(f"Empty result for {row['Symbol']} from {start_date} to {end_date}")
|
| 166 |
+
return None
|
| 167 |
+
|
| 168 |
+
# Get unique instance ID
|
| 169 |
+
with lock:
|
| 170 |
+
unique_instance = instance_counter.value
|
| 171 |
+
instance_counter.value += 1
|
| 172 |
+
|
| 173 |
+
# Explicitly add to instance_index_mapping using string key conversion
|
| 174 |
+
instance_index_mapping[unique_instance] = idx
|
| 175 |
+
|
| 176 |
+
# Track successful instances
|
| 177 |
+
successful_instances.append(unique_instance)
|
| 178 |
+
|
| 179 |
+
# Setup MultiIndex
|
| 180 |
+
symbol_df_filtered = symbol_df_filtered.reset_index(drop=True)
|
| 181 |
+
multi_index = pd.MultiIndex.from_arrays(
|
| 182 |
+
[[unique_instance] * len(symbol_df_filtered), range(len(symbol_df_filtered))],
|
| 183 |
+
names=["Instance", "Time"]
|
| 184 |
+
)
|
| 185 |
+
symbol_df_filtered.index = multi_index
|
| 186 |
+
|
| 187 |
+
# Set index levels to proper types
|
| 188 |
+
symbol_df_filtered.index = symbol_df_filtered.index.set_levels(
|
| 189 |
+
symbol_df_filtered.index.levels[0].astype('int'), level=0
|
| 190 |
+
)
|
| 191 |
+
symbol_df_filtered.index = symbol_df_filtered.index.set_levels(
|
| 192 |
+
symbol_df_filtered.index.levels[1].astype('int64'), level=1
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Add pattern and clean up
|
| 196 |
+
symbol_df_filtered['Pattern'] = pattern_encoding[row['Chart Pattern']]
|
| 197 |
+
symbol_df_filtered.drop('Date', axis=1, inplace=True)
|
| 198 |
+
if 'Adj Close' in symbol_df_filtered.columns:
|
| 199 |
+
symbol_df_filtered.drop('Adj Close', axis=1, inplace=True)
|
| 200 |
+
|
| 201 |
+
# Normalize
|
| 202 |
+
symbol_df_filtered = normalize_ohlc_segment(symbol_df_filtered)
|
| 203 |
+
|
| 204 |
+
return symbol_df_filtered
|
| 205 |
+
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print(f"Error processing {row['Symbol']}: {str(e)}")
|
| 208 |
+
return None
|
| 209 |
+
|
| 210 |
+
def dataset_format(filteredPatternDf, give_instance_index_mapping=False):
|
| 211 |
+
"""
|
| 212 |
+
Formats and preprocesses the dataset with better tracking of successful instances.
|
| 213 |
+
"""
|
| 214 |
+
# Get symbol list from files
|
| 215 |
+
folder_path = 'Datasets/OHLC data/'
|
| 216 |
+
file_list = os.listdir(folder_path)
|
| 217 |
+
symbol_list = [file[:-4] for file in file_list if file.endswith('.csv')]
|
| 218 |
+
|
| 219 |
+
# Check for missing symbols
|
| 220 |
+
symbols_in_df = filteredPatternDf['Symbol'].unique()
|
| 221 |
+
missing_symbols = set(symbols_in_df) - set(symbol_list)
|
| 222 |
+
if missing_symbols:
|
| 223 |
+
print("Missing symbols: ", missing_symbols)
|
| 224 |
+
|
| 225 |
+
# Create a list of tasks (symbol, row pairs)
|
| 226 |
+
tasks = []
|
| 227 |
+
for symbol in symbols_in_df:
|
| 228 |
+
if symbol in symbol_list: # Skip missing symbols
|
| 229 |
+
filteredPatternDf_for_symbol = filteredPatternDf[filteredPatternDf['Symbol'] == symbol]
|
| 230 |
+
file_path = os.path.join(folder_path, f"{symbol}.csv")
|
| 231 |
+
|
| 232 |
+
# Pre-load symbol data
|
| 233 |
+
try:
|
| 234 |
+
symbol_df = pd.read_csv(file_path)
|
| 235 |
+
symbol_df['Date'] = pd.to_datetime(symbol_df['Date'])
|
| 236 |
+
symbol_df['Date'] = symbol_df['Date'].dt.tz_localize(None)
|
| 237 |
+
|
| 238 |
+
for idx, row in filteredPatternDf_for_symbol.iterrows():
|
| 239 |
+
tasks.append((idx, row, symbol_df))
|
| 240 |
+
except Exception as e:
|
| 241 |
+
print(f"Error loading {symbol}: {str(e)}")
|
| 242 |
+
|
| 243 |
+
print(f"Processing {len(tasks)} tasks in parallel...")
|
| 244 |
+
|
| 245 |
+
# Process all tasks with instance tracking
|
| 246 |
+
with Manager() as manager:
|
| 247 |
+
instance_counter = manager.Value('i', 0)
|
| 248 |
+
lock = manager.Lock()
|
| 249 |
+
successful_instances = manager.list() # Track which instances succeed
|
| 250 |
+
instance_index_mapping = manager.dict() # Mapping from instance ID to index
|
| 251 |
+
|
| 252 |
+
results = Parallel(n_jobs=-1, verbose=1)(
|
| 253 |
+
delayed(process_row_improved)(task_idx, row, df, instance_counter, lock, successful_instances, instance_index_mapping)
|
| 254 |
+
for task_idx, row, df in tasks
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Filter out None results
|
| 258 |
+
results = [result for result in results if result is not None]
|
| 259 |
+
|
| 260 |
+
print(f"Total tasks: {len(tasks)}, Successful: {len(results)}")
|
| 261 |
+
print(f"Instance counter final value: {instance_counter.value}")
|
| 262 |
+
print(f"Number of successful instances: {len(successful_instances)}")
|
| 263 |
+
|
| 264 |
+
# # Debug print for mapping
|
| 265 |
+
# print("Debug - Instance Index Mapping:")
|
| 266 |
+
# for k, v in instance_index_mapping.items():
|
| 267 |
+
# print(f"Key: {k}, Value: {v}")
|
| 268 |
+
|
| 269 |
+
if len(successful_instances) < instance_counter.value:
|
| 270 |
+
print("Warning: Some instances were assigned but their tasks failed")
|
| 271 |
+
|
| 272 |
+
# Concatenate results and renumber instances if needed
|
| 273 |
+
if results:
|
| 274 |
+
dataset = pd.concat(results)
|
| 275 |
+
dataset = dataset.sort_index(level=0)
|
| 276 |
+
|
| 277 |
+
# Replace inf/nan values
|
| 278 |
+
dataset.replace([np.inf, -np.inf], np.nan, inplace=True)
|
| 279 |
+
dataset.fillna(method='ffill', inplace=True)
|
| 280 |
+
|
| 281 |
+
if give_instance_index_mapping:
|
| 282 |
+
# Convert manager.dict to a regular dictionary
|
| 283 |
+
instance_index_mapping_dict = dict(instance_index_mapping)
|
| 284 |
+
|
| 285 |
+
print("Converted Mapping:", instance_index_mapping_dict)
|
| 286 |
+
return dataset, instance_index_mapping_dict
|
| 287 |
+
else:
|
| 288 |
+
return dataset
|
| 289 |
+
else:
|
| 290 |
+
return pd.DataFrame()
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def width_augmentation (filteredPatternDf, min_aug_len , aug_len_fraction, make_duplicates = False , keep_original = False):
    """
    Perform width augmentation on the filtered pattern DataFrame by randomly
    widening each pattern's [Start, End] date window on both sides.

    # Input:
    - filteredPatternDf (pd.DataFrame): The filtered pattern DataFrame; must
      contain 'Symbol', 'Start', 'End' and 'Chart Pattern' columns.
    - min_aug_len (int): The minimum length of the augmented data.
    - aug_len_fraction (float): The fraction of the original data size used to
      determine the maximum length of the augmented data.
    - make_duplicates (bool): Whether to emit multiple augmented copies of
      under-represented patterns to reduce dataset imbalance (keep this False
      on test data).
    - keep_original (bool): Whether to also keep the original patterns in the
      augmented DataFrame.

    # Returns:
    - filteredPattern_width_aug_df (pd.DataFrame): The DataFrame with
      width-augmented patterns.
    """
    print('Performing width augmentation...')

    # Per-pattern row counts are loop-invariant -> compute once, not per row.
    pattern_counts = filteredPatternDf['Chart Pattern'].value_counts()
    max_num_of_rows_for_pattern = pattern_counts.max()

    # Collect augmented rows in a list and build the frame once at the end;
    # pd.concat inside the loop is quadratic in the number of rows.
    augmented_rows = []

    for index, row in tqdm(filteredPatternDf.iterrows(), total=len(filteredPatternDf), desc="Processing"):

        symbol = row['Symbol']
        pattern = row['Chart Pattern']

        ohlc_df = pd.read_csv(f'Datasets/OHLC data/{symbol}.csv')
        # Ensure all datetime objects are timezone-naive before comparing.
        ohlc_df['Date'] = pd.to_datetime(ohlc_df['Date']).dt.tz_localize(None)
        start_date = pd.to_datetime(row['Start']).tz_localize(None)
        end_date = pd.to_datetime(row['End']).tz_localize(None)

        ohlc_of_interest = ohlc_df[(ohlc_df['Date'] >= start_date) & (ohlc_df['Date'] <= end_date)]
        data_size = len(ohlc_of_interest)

        if data_size <= 0:
            print(f'No data for {symbol} between {start_date} and {end_date}')
            continue

        # Positional bounds of the pattern inside the full OHLC series
        # (read_csv yields a RangeIndex, so labels double as positions).
        start_index = ohlc_of_interest.index[0]
        end_index = ohlc_of_interest.index[-1]
        min_possible_index = 0
        max_possible_index = len(ohlc_df) - 1

        if make_duplicates:
            # Oversample under-represented patterns towards the majority
            # class size to reduce dataset imbalance.
            number_of_rows_for_pattern = pattern_counts[pattern]
            num_row_diff = (max_num_of_rows_for_pattern - number_of_rows_for_pattern) * 2
            multiplier = math.ceil(num_row_diff / number_of_rows_for_pattern) + 2
            # Random copy count in [1, multiplier); multiplier >= 2 always.
            m = np.random.randint(1, multiplier)
        else:
            m = 1

        for _ in range(m):
            max_aug_len = math.ceil(data_size * aug_len_fraction)
            if max_aug_len < min_aug_len:
                max_aug_len = min_aug_len
            # Guard: np.random.randint(low, high) needs high > low; with
            # max_aug_len == 1 the unguarded call raises ValueError.
            high = max(max_aug_len, 2)
            aug_len_l = np.random.randint(1, high)
            aug_len_r = np.random.randint(1, high)

            # Widen the window on both sides, clipped to the available data.
            start_index_aug = max(start_index - aug_len_l, min_possible_index)
            end_index_aug = min(end_index + aug_len_r, max_possible_index)

            # Create a new row covering the widened date range.
            new_row = row.copy()
            new_row['Start'] = ohlc_df.iloc[start_index_aug]['Date']
            new_row['End'] = ohlc_df.iloc[end_index_aug]['Date']
            augmented_rows.append(new_row)

        if keep_original:
            # Keep the original (unaugmented) row too.
            augmented_rows.append(row)

    if augmented_rows:
        filteredPattern_width_aug_df = pd.DataFrame(augmented_rows).reset_index(drop=True)
    else:
        filteredPattern_width_aug_df = pd.DataFrame(columns=filteredPatternDf.columns)

    return filteredPattern_width_aug_df
|
| 391 |
+
|
| 392 |
+
def normalize_ohlc_len(df, target_len=30 , plot_count= 0):
    """
    Resample every instance of a MultiIndexed OHLCV frame to a fixed length.

    # Input:
    - df (pd.DataFrame): MultiIndex (instance, step) frame with 'Open',
      'High', 'Low', 'Close', 'Volume' and optionally 'Pattern' columns.
    - target_len (int): Number of rows every instance is interpolated to.
    - plot_count (int): How many randomly chosen instances to plot
      (original vs normalized) for visual inspection.

    # Returns:
    - pd.DataFrame: MultiIndex frame where each instance has exactly
      target_len rows and OHLC invariants (High max, Low min, positive) hold.
    """
    instances_list = df.index.get_level_values(0).unique()
    normalized_df_list = []

    # Pick the instances to plot once, up front; resolving membership from a
    # set avoids re-indexing instances_list on every loop iteration.
    random_indices = np.random.choice(len(instances_list), plot_count, replace=False)
    instances_to_plot = set(instances_list[random_indices])

    for instance in instances_list:

        pattern_df = df.loc[instance].copy()
        new_data = {}
        orig_indices = pattern_df.index.values
        new_indices = np.linspace(0, len(orig_indices) - 1, target_len)

        # The interpolation order depends only on the series length, so
        # choose it once per instance instead of once per column.
        n_points = len(orig_indices)
        if n_points >= 4:        # enough points for cubic
            kind = 'cubic'
        elif n_points == 3:      # can use quadratic
            kind = 'quadratic'
        elif n_points == 2:      # can use linear
            kind = 'linear'
        else:                    # not enough points, use nearest
            kind = 'nearest'

        # Interpolate all numerical columns onto the target grid.
        for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
            f = interpolate.interp1d(np.arange(n_points), pattern_df[col].values,
                                     kind=kind, bounds_error=False, fill_value='extrapolate')
            new_data[col] = f(new_indices)

        # Keep all OHLC values positive; cubic extrapolation can overshoot
        # below zero, so clamp to a small positive value instead.
        for col in ['Open', 'High', 'Low', 'Close']:
            new_data[col] = np.maximum(new_data[col], 0.001)

        # Restore OHLC invariants in one vectorized pass:
        # High is the bar maximum, Low the bar minimum.
        new_data['High'] = np.maximum.reduce([new_data['High'], new_data['Open'], new_data['Close']])
        new_data['Low'] = np.minimum.reduce([new_data['Low'], new_data['Open'], new_data['Close']])

        # Categorical data: nearest-neighbour resampling keeps labels intact.
        if 'Pattern' in pattern_df.columns:
            f = interpolate.interp1d(np.arange(n_points), pattern_df['Pattern'].values,
                                     kind='nearest', bounds_error=False, fill_value=pattern_df['Pattern'].iloc[0])
            new_data['Pattern'] = f(new_indices)

        result_df = pd.DataFrame(new_data)
        result_df.index = pd.MultiIndex.from_product([[instance], result_df.index])
        normalized_df_list.append(result_df)

        if instance in instances_to_plot:
            # Plot original vs normalized segment for visual sanity checking.
            plot_ohlc_segment(pattern_df)
            plot_ohlc_segment(result_df)

    # Guard the empty case: pd.concat([]) raises ValueError.
    if not normalized_df_list:
        return pd.DataFrame()
    return pd.concat(normalized_df_list, axis=0)
|
| 455 |
+
|
| 456 |
+
# Feature columns, label column, and the fixed per-instance series length
# expected by ROCKET-style classifiers.
features = ['Open', 'High', 'Low', 'Close', 'Volume']
target = 'Pattern'
series_length = 100

def prepare_rocket_data(dataset, features = features, target = target, series_length = series_length):
    """
    Build (X, y) arrays for ROCKET-style classifiers from a MultiIndexed frame.

    Each instance (level-0 index group) is truncated or zero-padded to
    series_length rows, then stacked into X of shape
    (n_instances, n_features, series_length). y holds the first target value
    of each instance, in the same (sorted) instance order as X.
    """
    def _fixed_length(group):
        # Truncate long series; zero-pad short ones to series_length rows.
        values = group[features].values
        if len(values) >= series_length:
            return values[:series_length]
        pad_rows = np.zeros((series_length - len(values), values.shape[1]))
        return np.vstack([values, pad_rows])

    grouped = dataset.groupby(level=0)
    per_instance = grouped.apply(_fixed_length)

    # (n_instances, series_length, n_features) -> (n_instances, n_features, series_length)
    X = np.stack(per_instance.values).transpose(0, 2, 1)

    y = grouped[target].first().values
    return X, y
|