Commit
·
c9b7b21
1
Parent(s):
fbc408c
add files
Browse files- utils/FixedLengthTransformer.py +152 -0
- utils/drawPlots.py +24 -0
- utils/functionalPatternLocateAndPlot.py +653 -0
- utils/patternLocating.py +380 -0
- utils/patternLocatingGemni.py +452 -0
utils/FixedLengthTransformer.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Fixed length transformer, pad or truncate panel to fixed length."""
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
from sktime.transformations.base import BaseTransformer
|
| 7 |
+
from sktime.utils.pandas import df_map
|
| 8 |
+
|
| 9 |
+
__all__ = ["FixedLengthTransformer"]
|
| 10 |
+
__author__ = ["user"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FixedLengthTransformer(BaseTransformer):
    """Transform a panel of variable length time series to a fixed length.

    Transforms the input dataset to a fixed length by either:

    - Padding shorter series with a fill value (default: 0)
    - Truncating longer series to the specified length

    Unlike PaddingTransformer, this transformer requires a fixed_length
    parameter and will both pad and truncate as needed.

    Parameters
    ----------
    fixed_length : int
        The exact length that all series will be transformed to.
    fill_value : any, optional (default=0)
        The value used to pad shorter series.

    Examples
    --------
    >>> import pandas as pd
    >>> data = {
    ...     'feature1': [pd.Series([1, 2, 3]), pd.Series([4, 5])],
    ...     'feature2': [pd.Series([10, 11]), pd.Series([12, 13, 14])],
    ... }
    >>> X = pd.DataFrame(data)
    >>> transformer = FixedLengthTransformer(fixed_length=3)
    >>> Xt = transformer.fit_transform(X)  # doctest: +SKIP
    """

    _tags = {
        "authors": ["user"],
        "maintainers": ["user"],
        "scitype:transform-input": "Series",
        "scitype:transform-output": "Series",
        "scitype:instancewise": False,
        "X_inner_mtype": "nested_univ",
        "y_inner_mtype": "None",
        "fit_is_empty": True,  # No need to compute anything during fit
        "capability:unequal_length:removes": True,
    }

    def __init__(self, fixed_length, fill_value=0):
        # Validate eagerly so misconfiguration fails at construction time.
        if fixed_length is None or fixed_length <= 0:
            raise ValueError("fixed_length must be a positive integer")

        self.fixed_length = fixed_length
        self.fill_value = fill_value
        super().__init__()

    def _fit(self, X, y=None):
        """Fit transformer to X and y.

        This is a no-op since we only need the fixed_length parameter
        (the "fit_is_empty" tag is set accordingly).

        Parameters
        ----------
        X : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of X must contain a pandas.Series
        y : ignored argument for interface compatibility

        Returns
        -------
        self : reference to self
        """
        return self

    def _transform_series(self, series):
        """Transform a single series to fixed length by padding or truncating.

        Parameters
        ----------
        series : pandas.Series
            The input series to transform.

        Returns
        -------
        numpy.ndarray
            Array of length ``self.fixed_length``.
        """
        n_timepoints = len(series)

        if n_timepoints == self.fixed_length:
            # Already the correct length.
            return series.values
        if n_timepoints < self.fixed_length:
            # np.full infers its dtype from fill_value; np.concatenate then
            # promotes to a common dtype with the series values, so integer
            # series padded with an integer fill stay integer.  (The previous
            # implementation always coerced padded series to float, making
            # dtypes inconsistent with the truncation branch, and broke for
            # non-numeric fill values.)
            padding = np.full(self.fixed_length - n_timepoints, self.fill_value)
            return np.concatenate([np.asarray(series.values), padding])
        # Truncate to the first fixed_length observations.
        return series.iloc[: self.fixed_length].values

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        Parameters
        ----------
        X : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of X must contain a pandas.Series
        y : ignored argument for interface compatibility

        Returns
        -------
        Xt : nested pandas DataFrame of shape [n_instances, n_features]
            each cell of Xt contains a pandas.Series of length fixed_length
        """
        # df_map applies the callable cell-wise and preserves both the row
        # index and the column labels of X.  (The previous implementation
        # rebuilt the frame row by row, which silently replaced the original
        # index with a fresh RangeIndex.)
        return df_map(X)(lambda cell: pd.Series(self._transform_series(cell)))
|
utils/drawPlots.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
# import mplfinance as mpf
|
| 3 |
+
|
| 4 |
+
def plot_ohlc_segment(data_segment):
    """Plot a segment of OHLC data as a candlestick chart.

    Plotting is currently disabled (the mplfinance dependency was removed);
    this function is a no-op kept so existing callers do not break.

    Parameters
    ----------
    data_segment : pd.DataFrame
        DataFrame containing columns ['Open', 'High', 'Low', 'Close', 'Volume'].

    Returns
    -------
    None
    """
    # Previous implementation kept for reference as comments.  (It was
    # previously kept as a bare triple-quoted string after `pass`, i.e. an
    # expression that was evaluated and discarded on every call.)
    #
    # # Ensure the DataFrame index is datetime for mplfinance
    # data_segment = data_segment.copy()
    # data_segment.index = pd.date_range(start='2024-01-01', periods=len(data_segment), freq='D')
    #
    # # Plot the candlestick chart
    # mpf.plot(data_segment, type='candle', style='charles',
    #          volume=True, ylabel='Price', ylabel_lower='Volume',
    #          title="OHLC Segment", figsize=(10, 6))
    return None
|
| 23 |
+
|
| 24 |
+
|
utils/functionalPatternLocateAndPlot.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import matplotlib
# matplotlib.use('Agg')
import numpy as np
from scipy.signal import find_peaks

from utils.formatAndPreprocessNewPatterns import get_pattern_encoding
|
| 6 |
+
|
| 7 |
+
# Root folder containing the OHLC datasets used by this module.
path = 'Datasets/OHLC data'
# Loaded once at import time; presumably maps chart-pattern names to codes —
# TODO confirm against utils.formatAndPreprocessNewPatterns.
pattern_encoding = get_pattern_encoding()
|
| 9 |
+
|
| 10 |
+
def calc_head_and_sholder_top(row, ohlc_data_pattern_segment):
    """Locate a Head & Shoulders top in an OHLC segment and annotate *row*.

    The head is the highest detected peak; the shoulders are the highest
    peaks on either side of it; the neckline points are the deepest valleys
    between each shoulder and the head.

    Parameters
    ----------
    row : dict-like (e.g. pandas Series)
        Record to annotate.  Keys written: 'HS_Left_Shoulder', 'HS_Head',
        'HS_Right_Shoulder', 'HS_Neckline_1', 'HS_Neckline_2', 'Peak_Dates',
        'Valley_Dates', 'Calc_Start', 'Calc_End'.
    ohlc_data_pattern_segment : pandas DataFrame
        Segment with at least 'Date', 'High' and 'Low' columns.

    Returns
    -------
    row with the pattern dates filled in, or None when no valid pattern
    is found.
    """
    high_prices = ohlc_data_pattern_segment['High'].values
    low_prices = ohlc_data_pattern_segment['Low'].values

    # Adjust this parameter to suit your data – lower values detect smaller features.
    prominence_value = 0.1

    # Find peaks (local maxima)
    peak_indices, _ = find_peaks(high_prices, prominence=prominence_value)
    # Find valleys (local minima) by inverting the low prices
    valley_indices, _ = find_peaks(-low_prices, prominence=prominence_value)

    # Dates of every detected peak/valley (kept on the row for later plotting)
    peak_dates = ohlc_data_pattern_segment['Date'].iloc[peak_indices]
    valley_dates = ohlc_data_pattern_segment['Date'].iloc[valley_indices]

    if len(peak_indices) < 3 or len(valley_indices) < 2:
        print("Not enough peaks and valleys to form a Head & Shoulders pattern.")
        return None

    try:
        # Head: the overall highest peak.
        H_index = np.argmax(high_prices[peak_indices])
        H = peak_indices[H_index]
        # Shoulders: the highest peak strictly before / after the head.
        LS_index = np.argmax(high_prices[peak_indices[0:H_index]])
        LS = peak_indices[LS_index]
        RS_index = np.argmax(high_prices[peak_indices[H_index + 1:]]) + H_index + 1
        RS = peak_indices[RS_index]

        # Neckline points: the deepest valley on each side of the head.
        vally_left = valley_indices[(valley_indices > LS) & (valley_indices < H)]
        vally_right = valley_indices[(valley_indices > H) & (valley_indices < RS)]
        NL1 = vally_left[np.argmin(low_prices[vally_left])]
        NL2 = vally_right[np.argmin(low_prices[vally_right])]
    except (ValueError, IndexError):
        # Raised when the head is the first/last peak or a side has no valley
        # (argmax/argmin of an empty slice).  Narrowed from a bare `except:`
        # so genuine programming errors are no longer silently swallowed.
        print("Error in finding the peaks or valleys in the Head and Shoulders pattern")
        return None

    # Ensure the middle peak is the highest
    if high_prices[H] <= max(high_prices[LS], high_prices[RS]):
        print("Not a valid Head & Shoulders pattern.")
        return None

    LS_date = ohlc_data_pattern_segment['Date'].iloc[LS]
    H_date = ohlc_data_pattern_segment['Date'].iloc[H]
    RS_date = ohlc_data_pattern_segment['Date'].iloc[RS]
    NL1_date = ohlc_data_pattern_segment['Date'].iloc[NL1]
    NL2_date = ohlc_data_pattern_segment['Date'].iloc[NL2]

    # add the dates to the row
    row['HS_Left_Shoulder'] = LS_date
    row['HS_Head'] = H_date
    row['HS_Right_Shoulder'] = RS_date
    row['HS_Neckline_1'] = NL1_date
    row['HS_Neckline_2'] = NL2_date
    row['Peak_Dates'] = peak_dates
    row['Valley_Dates'] = valley_dates
    row['Calc_Start'] = LS_date
    row['Calc_End'] = RS_date

    return row
|
| 69 |
+
|
| 70 |
+
def calc_head_and_shoulder_bottom(row, ohlc_data_pattern_segment):
    """Locate an inverse Head & Shoulders in an OHLC segment and annotate *row*.

    Mirror image of ``calc_head_and_sholder_top``: the head is the lowest
    valley, the shoulders are the lowest valleys on either side, and the
    neckline points are the highest peaks between each shoulder and the head.

    Parameters
    ----------
    row : dict-like (e.g. pandas Series)
        Record to annotate.  Keys written: 'HS_Left_Shoulder', 'HS_Head',
        'HS_Right_Shoulder', 'HS_Neckline_1', 'HS_Neckline_2', 'Valley_Dates',
        'Peak_Dates', 'Calc_Start', 'Calc_End'.
    ohlc_data_pattern_segment : pandas DataFrame
        Segment with at least 'Date', 'High' and 'Low' columns.

    Returns
    -------
    row with the pattern dates filled in, or None when no valid pattern
    is found.
    """
    high_prices = ohlc_data_pattern_segment['High'].values
    low_prices = ohlc_data_pattern_segment['Low'].values

    # Adjust this parameter to suit your data – lower values detect smaller features.
    prominence_value = 0.1

    # Find valleys (local minima)
    valley_indices, _ = find_peaks(-low_prices, prominence=prominence_value)
    # Find peaks (local maxima)
    peak_indices, _ = find_peaks(high_prices, prominence=prominence_value)

    # Dates of every detected valley/peak (kept on the row for later plotting)
    valley_dates = ohlc_data_pattern_segment['Date'].iloc[valley_indices]
    peak_dates = ohlc_data_pattern_segment['Date'].iloc[peak_indices]

    if len(valley_indices) < 3 or len(peak_indices) < 2:
        print("Not enough valleys and peaks to form a Head & Shoulders Bottom pattern.")
        return None

    try:
        # Head: the overall lowest valley.
        H_index = np.argmin(low_prices[valley_indices])
        H = valley_indices[H_index]
        # Shoulders: the lowest valley strictly before / after the head.
        LS_index = np.argmin(low_prices[valley_indices[0:H_index]])
        LS = valley_indices[LS_index]
        RS_index = np.argmin(low_prices[valley_indices[H_index + 1:]]) + H_index + 1
        RS = valley_indices[RS_index]

        # Neckline points: the highest peak on each side of the head.
        peak_left = peak_indices[(peak_indices > LS) & (peak_indices < H)]
        peak_right = peak_indices[(peak_indices > H) & (peak_indices < RS)]
        NL1 = peak_left[np.argmax(high_prices[peak_left])]
        NL2 = peak_right[np.argmax(high_prices[peak_right])]
    except (ValueError, IndexError):
        # Raised when the head is the first/last valley or a side has no peak
        # (argmax/argmin of an empty slice).  Narrowed from a bare `except:`
        # so genuine programming errors are no longer silently swallowed.
        print("Error in finding the valleys or peaks in the Head and Shoulders Bottom pattern")
        return None

    # Ensure the middle valley is the lowest
    if low_prices[H] >= min(low_prices[LS], low_prices[RS]):
        print("Not a valid Head & Shoulders Bottom pattern.")
        return None

    LS_date = ohlc_data_pattern_segment['Date'].iloc[LS]
    H_date = ohlc_data_pattern_segment['Date'].iloc[H]
    RS_date = ohlc_data_pattern_segment['Date'].iloc[RS]
    NL1_date = ohlc_data_pattern_segment['Date'].iloc[NL1]
    NL2_date = ohlc_data_pattern_segment['Date'].iloc[NL2]

    # Add the detected pattern data to the row
    row['HS_Left_Shoulder'] = LS_date
    row['HS_Head'] = H_date
    row['HS_Right_Shoulder'] = RS_date
    row['HS_Neckline_1'] = NL1_date
    row['HS_Neckline_2'] = NL2_date
    row['Valley_Dates'] = valley_dates
    row['Peak_Dates'] = peak_dates
    row['Calc_Start'] = LS_date
    row['Calc_End'] = RS_date

    return row
|
| 129 |
+
|
| 130 |
+
def calc_double_top_aa(row, ohlc_data_pattern_segment):
    """Locate a Double Top in an OHLC segment and annotate *row*.

    The first top is the highest peak, the second is the highest peak after
    it, and the valley is the deepest low between the two tops.

    Parameters
    ----------
    row : dict-like (e.g. pandas Series)
        Record to annotate.  Keys written: 'DT_Peak_1', 'DT_Peak_2',
        'DT_Valley', 'Peak_Dates', 'Valley_Dates', 'Calc_Start', 'Calc_End'.
    ohlc_data_pattern_segment : pandas DataFrame
        Segment with at least 'Date', 'High' and 'Low' columns.

    Returns
    -------
    row with the pattern dates filled in, or None when no valid pattern
    is found.
    """
    high_prices = ohlc_data_pattern_segment['High'].values
    low_prices = ohlc_data_pattern_segment['Low'].values

    # Adjust this parameter to suit your data – lower values detect smaller features.
    prominence_value = 0.1

    # Find peaks (local maxima)
    peak_indices, _ = find_peaks(high_prices, prominence=prominence_value)
    # Find valleys (local minima) by inverting the low prices
    valley_indices, _ = find_peaks(-low_prices, prominence=prominence_value)

    # Dates of every detected peak/valley (kept on the row for later plotting)
    peak_dates = ohlc_data_pattern_segment['Date'].iloc[peak_indices]
    valley_dates = ohlc_data_pattern_segment['Date'].iloc[valley_indices]

    if len(peak_indices) < 2 or len(valley_indices) < 1:
        print("Not enough peaks and valleys to form a Double Top pattern.")
        return None

    try:
        # First top: the highest peak; second top: the highest peak after it.
        H1_index = np.argmax(high_prices[peak_indices])
        H1 = peak_indices[H1_index]
        H2_index = np.argmax(high_prices[peak_indices[H1_index + 1:]]) + H1_index + 1
        H2 = peak_indices[H2_index]

        # Valley: the deepest low between the two tops.
        # BUGFIX: the original used np.argmax here, which selected the
        # *shallowest* valley; the symmetric calc_double_bottom_aa correctly
        # uses argmax(high) for the peak, so the mirror operation is
        # argmin(low).
        valley_indices_between_H1_H2 = valley_indices[(valley_indices > H1) & (valley_indices < H2)]
        V = valley_indices_between_H1_H2[np.argmin(low_prices[valley_indices_between_H1_H2])]
    except (ValueError, IndexError):
        # Raised when H1 is the last peak or there is no valley between the
        # tops (argmax/argmin of an empty slice).  Narrowed from a bare
        # `except:` so genuine programming errors are not silently swallowed.
        print("Error in finding the peaks or valleys in the Double Top pattern")
        return None

    # # Ensure the middle peak is the highest
    # if high_prices[H1] <= high_prices[H2]:
    #     print("Not a valid Double Top pattern.")
    #     return

    H1_date = ohlc_data_pattern_segment['Date'].iloc[H1]
    H2_date = ohlc_data_pattern_segment['Date'].iloc[H2]
    V_date = ohlc_data_pattern_segment['Date'].iloc[V]

    # add the dates to the row
    row['DT_Peak_1'] = H1_date
    row['DT_Peak_2'] = H2_date
    row['DT_Valley'] = V_date
    row['Peak_Dates'] = peak_dates
    row['Valley_Dates'] = valley_dates
    row['Calc_Start'] = H1_date
    row['Calc_End'] = H2_date

    return row
|
| 182 |
+
|
| 183 |
+
def calc_double_bottom_aa(row, ohlc_data_pattern_segment):
    """Locate a Double Bottom in an OHLC segment and annotate *row*.

    The first bottom is the lowest valley, the second is the lowest valley
    after it, and the peak is the highest high between the two bottoms.

    Parameters
    ----------
    row : dict-like (e.g. pandas Series)
        Record to annotate.  Keys written: 'DB_Valley_1', 'DB_Valley_2',
        'DB_Peak', 'Valley_Dates', 'Peak_Dates', 'Calc_Start', 'Calc_End'.
    ohlc_data_pattern_segment : pandas DataFrame
        Segment with at least 'Date', 'High' and 'Low' columns.

    Returns
    -------
    row with the pattern dates filled in, or None when no valid pattern
    is found.
    """
    high_prices = ohlc_data_pattern_segment['High'].values
    low_prices = ohlc_data_pattern_segment['Low'].values

    # Adjust this parameter to suit your data – lower values detect smaller features.
    prominence_value = 0.05

    # Find valleys (local minima)
    valley_indices, _ = find_peaks(-low_prices, prominence=prominence_value)
    # Find peaks (local maxima)
    peak_indices, _ = find_peaks(high_prices, prominence=prominence_value)

    # Dates of every detected valley/peak (kept on the row for later plotting)
    valley_dates = ohlc_data_pattern_segment['Date'].iloc[valley_indices]
    peak_dates = ohlc_data_pattern_segment['Date'].iloc[peak_indices]

    if len(valley_indices) < 2 or len(peak_indices) < 1:
        print("Not enough valleys and peaks to form a Double Bottom pattern.")
        return None

    try:
        # First bottom: the lowest valley; second: the lowest valley after it.
        H1_index = np.argmin(low_prices[valley_indices])
        H1 = valley_indices[H1_index]
        H2_index = np.argmin(low_prices[valley_indices[H1_index + 1:]]) + H1_index + 1
        H2 = valley_indices[H2_index]

        # Peak: the highest high between the two bottoms.
        peak_indices_between_H1_H2 = peak_indices[(peak_indices > H1) & (peak_indices < H2)]
        P = peak_indices_between_H1_H2[np.argmax(high_prices[peak_indices_between_H1_H2])]
    except (ValueError, IndexError):
        # Raised when H1 is the last valley or there is no peak between the
        # bottoms (argmax/argmin of an empty slice).  Narrowed from a bare
        # `except:` so genuine programming errors are not silently swallowed.
        print("Error in finding the valleys or peaks in the Double Bottom pattern")
        return None

    # # Ensure the middle valley is the lowest
    # if low_prices[H1] >= low_prices[H2]:
    #     print("Not a valid Double Bottom pattern.")
    #     return

    H1_date = ohlc_data_pattern_segment['Date'].iloc[H1]
    H2_date = ohlc_data_pattern_segment['Date'].iloc[H2]
    P_date = ohlc_data_pattern_segment['Date'].iloc[P]

    # Add the detected pattern data to the row
    row['DB_Valley_1'] = H1_date
    row['DB_Valley_2'] = H2_date
    row['DB_Peak'] = P_date
    row['Valley_Dates'] = valley_dates
    row['Peak_Dates'] = peak_dates
    row['Calc_Start'] = H1_date
    row['Calc_End'] = H2_date

    return row
|
| 234 |
+
|
| 235 |
+
def calc_double_bottom_ea(row,ohlc_data_pattern_segment):
    """Locate a Double Bottom variant (rounded first bottom) and annotate *row*.

    Like calc_double_bottom_aa, but the first bottom is chosen from the
    wider/rounded valleys detected with width and threshold constraints.
    Writes 'DB_Valley_1', 'DB_Valley_2', 'DB_Peak', 'Valley_Dates',
    'Peak_Dates', 'Calc_Start' and 'Calc_End' on *row*; returns the
    annotated row, or None on failure.
    """
    high_prices = ohlc_data_pattern_segment['High'].values
    low_prices = ohlc_data_pattern_segment['Low'].values

    # Adjust this parameter to suit your data – lower values detect smaller features.
    prominence_value = 0.1

    # Find valleys (local minima)
    valley_indices, _ = find_peaks(-low_prices, prominence=prominence_value)
    # Find peaks (local maxima)
    peak_indices, _ = find_peaks(high_prices, prominence=prominence_value)

    # "Rounded" valleys: wider, shallow minima (width >= 3 bars).
    round_vallies,_ = find_peaks(-low_prices, prominence=0.01,width=3,threshold=0.01)

    # Create lists of dates for valleys and peaks
    valley_dates = ohlc_data_pattern_segment['Date'].iloc[valley_indices]
    peak_dates = ohlc_data_pattern_segment['Date'].iloc[peak_indices]

    # NOTE(review): this guard does not check that round_vallies is non-empty;
    # an empty round_vallies makes the argmin below raise, which is then
    # swallowed by the bare except.
    if len(valley_indices) < 2 or len(peak_indices) < 1:
        print("Not enough valleys and peaks to form a Double Bottom pattern.")
        return

    try:
        # NOTE(review): suspected bug — H1_index indexes into round_vallies,
        # but is then used to index valley_indices (and to slice it for H2).
        # The two arrays are generally different, so H1 may not be the rounded
        # valley that was selected.  Presumably `H1 = round_vallies[H1_index]`
        # was intended — TODO confirm intended behavior before changing.
        H1_index = np.argmin(low_prices[round_vallies])
        H1 = valley_indices[H1_index]
        H2_index = np.argmin(low_prices[valley_indices[H1_index+1:]]) + H1_index + 1
        H2 = valley_indices[H2_index]
        # get v index that is between H1 and H2
        peak_indices_between_H1_H2 = peak_indices[(peak_indices > H1) & (peak_indices < H2)]
        P = peak_indices_between_H1_H2[np.argmax(high_prices[ peak_indices_between_H1_H2])]

        # # Ensure the middle valley is the lowest
        # if low_prices[H1] >= low_prices[H2]:
        #     print("Not a valid Double Bottom pattern.")
        #     return

        H1_date = ohlc_data_pattern_segment['Date'].iloc[H1]
        H2_date = ohlc_data_pattern_segment['Date'].iloc[H2]
        P_date = ohlc_data_pattern_segment['Date'].iloc[P]

        # Add the detected pattern data to the row
        row['DB_Valley_1'] = H1_date
        row['DB_Valley_2'] = H2_date
        row['DB_Peak'] = P_date
        row['Valley_Dates'] = valley_dates
        row['Peak_Dates'] = peak_dates
        row['Calc_Start'] = H1_date
        row['Calc_End'] = H2_date

        return row
    # NOTE(review): bare except hides all errors, including the index-space
    # issue above and any failure in the row assignments — narrow it once the
    # intended H1 selection is confirmed.
    except:
        print("Error in finding the valleys or peaks in the Double Bottom pattern")
        return
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# Commenting out all plotting functions
|
| 292 |
+
"""
|
| 293 |
+
import matplotlib.pyplot as plt
|
| 294 |
+
import mplfinance as mpf
|
| 295 |
+
import pandas as pd
|
| 296 |
+
import numpy as np
|
| 297 |
+
import pandas as pd
|
| 298 |
+
import matplotlib.pyplot as plt
|
| 299 |
+
import mplfinance as mpf
|
| 300 |
+
from scipy.signal import argrelextrema
|
| 301 |
+
from scipy.signal import find_peaks
|
| 302 |
+
|
| 303 |
+
def draw_head_and_shoulders_top(ax, ohlc_data, pat_start_idx,row):
|
| 304 |
+
|
| 305 |
+
Draws a Head and Shoulders pattern on an existing mplfinance plot and visualizes detected peaks and valleys.
|
| 306 |
+
|
| 307 |
+
Parameters:
|
| 308 |
+
ax (matplotlib.axes.Axes): The candlestick chart's axis.
|
| 309 |
+
ohlc_data (pd.DataFrame): Data containing 'High' and 'Low' columns.
|
| 310 |
+
|
| 311 |
+
# reset the index of the ohlc_data
|
| 312 |
+
ohlc_data.reset_index(drop=True, inplace=True)
|
| 313 |
+
high_prices = ohlc_data['High'].values
|
| 314 |
+
low_prices = ohlc_data['Low'].values
|
| 315 |
+
|
| 316 |
+
# check if 'Peak_Dates' and 'Valley_Dates' columns are present in the row
|
| 317 |
+
if 'Peak_Dates' in row and 'Valley_Dates' in row:
|
| 318 |
+
|
| 319 |
+
peak_days = row['Peak_Dates']
|
| 320 |
+
valley_days = row['Valley_Dates']
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
peak_indices = ohlc_data[ohlc_data['Date'].isin(peak_days)].index
|
| 324 |
+
# add the pat_start_idx to the peak_indices
|
| 325 |
+
peak_indices = peak_indices
|
| 326 |
+
|
| 327 |
+
valley_indices = ohlc_data[ohlc_data['Date'].isin(valley_days)].index
|
| 328 |
+
# add the pat_start_idx to the valley_indices
|
| 329 |
+
valley_indices = valley_indices
|
| 330 |
+
|
| 331 |
+
# Debugging visualization: Plot detected peaks and valleys
|
| 332 |
+
ax.scatter(peak_indices , high_prices[peak_indices], color='green', marker='^', label='Peaks', zorder=3)
|
| 333 |
+
ax.scatter(valley_indices, low_prices[valley_indices], color='red', marker='v', label='Valleys', zorder=3)
|
| 334 |
+
|
| 335 |
+
calc_start_date = row['Calc_Start']
|
| 336 |
+
calc_end_date = row['Calc_End']
|
| 337 |
+
|
| 338 |
+
calc_start_idx = ohlc_data[ohlc_data['Date']== calc_start_date].index
|
| 339 |
+
calc_end_idx = ohlc_data[ohlc_data['Date']== calc_end_date].index
|
| 340 |
+
|
| 341 |
+
# drow a pink dotted vertical line at calc_start_idx and calc_end_idx
|
| 342 |
+
ax.axvline(x=calc_start_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 343 |
+
ax.axvline(x=calc_end_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 344 |
+
|
| 345 |
+
LS_idx = ohlc_data[ohlc_data['Date']== row['HS_Left_Shoulder']].index
|
| 346 |
+
H_idx = ohlc_data[ohlc_data['Date']== row['HS_Head']].index
|
| 347 |
+
RS_idx = ohlc_data[ohlc_data['Date']== row['HS_Right_Shoulder']].index
|
| 348 |
+
NL1_idx = ohlc_data[ohlc_data['Date']== row['HS_Neckline_1']].index
|
| 349 |
+
NL2_idx = ohlc_data[ohlc_data['Date']== row['HS_Neckline_2']].index
|
| 350 |
+
|
| 351 |
+
# Draw the head and shoulders
|
| 352 |
+
ax.plot([LS_idx, H_idx, RS_idx], [high_prices[LS_idx], high_prices[H_idx], high_prices[RS_idx]],
|
| 353 |
+
linestyle="solid", marker="o", color="blue", linewidth=1, label="H&S Pattern")
|
| 354 |
+
|
| 355 |
+
# Use NL1_idx and NL2_idx as the x-range to keep the line within bounds
|
| 356 |
+
x_min, x_max = min(NL1_idx, NL2_idx), max(NL1_idx, NL2_idx)
|
| 357 |
+
|
| 358 |
+
# Compute the y-values using the line equation (y = mx + c)
|
| 359 |
+
slope = (low_prices[NL2_idx] - low_prices[NL1_idx]) / (NL2_idx - NL1_idx)
|
| 360 |
+
y_min = low_prices[NL1_idx] + slope * (x_min - NL1_idx)
|
| 361 |
+
y_max = low_prices[NL1_idx] + slope * (x_max - NL1_idx)
|
| 362 |
+
|
| 363 |
+
# Plot the line within the original graph size
|
| 364 |
+
ax.plot([x_min, x_max], [y_min, y_max],
|
| 365 |
+
linestyle="dashed", color="red", linewidth=1, label="Neckline")
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def draw_head_and_shoulders_bottom(ax, ohlc_data, pat_start_idx,row):
|
| 374 |
+
|
| 375 |
+
Draws a Head and Shoulders pattern on an existing mplfinance plot and visualizes detected peaks and valleys.
|
| 376 |
+
|
| 377 |
+
Parameters:
|
| 378 |
+
ax (matplotlib.axes.Axes): The candlestick chart's axis.
|
| 379 |
+
ohlc_data (pd.DataFrame): Data containing 'High' and 'Low' columns.
|
| 380 |
+
|
| 381 |
+
# reset the index of the ohlc_data
|
| 382 |
+
ohlc_data.reset_index(drop=True, inplace=True)
|
| 383 |
+
high_prices = ohlc_data['High'].values
|
| 384 |
+
low_prices = ohlc_data['Low'].values
|
| 385 |
+
|
| 386 |
+
# check if 'Peak_Dates' and 'Valley_Dates' columns are present in the row
|
| 387 |
+
if 'Peak_Dates' in row and 'Valley_Dates' in row:
|
| 388 |
+
peak_days = row['Peak_Dates']
|
| 389 |
+
valley_days = row['Valley_Dates']
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
peak_indices = ohlc_data[ohlc_data['Date'].isin(peak_days)].index
|
| 393 |
+
# add the pat_start_idx to the peak_indices
|
| 394 |
+
peak_indices = peak_indices
|
| 395 |
+
|
| 396 |
+
valley_indices = ohlc_data[ohlc_data['Date'].isin(valley_days)].index
|
| 397 |
+
# add the pat_start_idx to the valley_indices
|
| 398 |
+
valley_indices = valley_indices
|
| 399 |
+
|
| 400 |
+
# Debugging visualization: Plot detected peaks and valleys
|
| 401 |
+
ax.scatter(peak_indices , high_prices[peak_indices], color='green', marker='^', label='Peaks', zorder=3)
|
| 402 |
+
ax.scatter(valley_indices, low_prices[valley_indices], color='red', marker='v', label='Valleys', zorder=3)
|
| 403 |
+
|
| 404 |
+
calc_start_date = row['Calc_Start']
|
| 405 |
+
calc_end_date = row['Calc_End']
|
| 406 |
+
|
| 407 |
+
calc_start_idx = ohlc_data[ohlc_data['Date']== calc_start_date].index
|
| 408 |
+
calc_end_idx = ohlc_data[ohlc_data['Date']== calc_end_date].index
|
| 409 |
+
|
| 410 |
+
# drow a pink dotted vertical line at calc_start_idx and calc_end_idx
|
| 411 |
+
ax.axvline(x=calc_start_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 412 |
+
ax.axvline(x=calc_end_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
LS_idx = ohlc_data[ohlc_data['Date']== row['HS_Left_Shoulder']].index
|
| 416 |
+
H_idx = ohlc_data[ohlc_data['Date']== row['HS_Head']].index
|
| 417 |
+
RS_idx = ohlc_data[ohlc_data['Date']== row['HS_Right_Shoulder']].index
|
| 418 |
+
NL1_idx = ohlc_data[ohlc_data['Date']== row['HS_Neckline_1']].index
|
| 419 |
+
NL2_idx = ohlc_data[ohlc_data['Date']== row['HS_Neckline_2']].index
|
| 420 |
+
|
| 421 |
+
# Draw the head and shoulders
|
| 422 |
+
ax.plot([LS_idx, H_idx, RS_idx], [low_prices[LS_idx], low_prices[H_idx], low_prices[RS_idx]],
|
| 423 |
+
linestyle="solid", marker="o", color="blue", linewidth=1, label="H&S Pattern")
|
| 424 |
+
|
| 425 |
+
# Use NL1_idx and NL2_idx as the x-range to keep the line within bounds
|
| 426 |
+
x_min, x_max = min(NL1_idx, NL2_idx), max(NL1_idx, NL2_idx)
|
| 427 |
+
|
| 428 |
+
# Compute the y-values using the line equation (y = mx + c)
|
| 429 |
+
slope = (high_prices[NL2_idx] - high_prices[NL1_idx]) / (NL2_idx - NL1_idx)
|
| 430 |
+
y_min = high_prices[NL1_idx] + slope * (x_min - NL1_idx)
|
| 431 |
+
y_max = high_prices[NL1_idx] + slope * (x_max - NL1_idx)
|
| 432 |
+
|
| 433 |
+
# Plot the line within the original graph size
|
| 434 |
+
ax.plot([x_min, x_max], [y_min, y_max],
|
| 435 |
+
linestyle="dashed", color="red", linewidth=1, label="Neckline")
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def draw_double_top_aa(ax, ohlc_data, pat_start_idx,row):
|
| 440 |
+
|
| 441 |
+
Draws a Double Top pattern on an existing mplfinance plot and visualizes detected peaks and valleys.
|
| 442 |
+
|
| 443 |
+
Parameters:
|
| 444 |
+
ax (matplotlib.axes.Axes): The candlestick chart's axis.
|
| 445 |
+
ohlc_data (pd.DataFrame): Data containing 'High' and 'Low' columns.
|
| 446 |
+
|
| 447 |
+
# reset the index of the ohlc_data
|
| 448 |
+
ohlc_data.reset_index(drop=True, inplace=True)
|
| 449 |
+
high_prices = ohlc_data['High'].values
|
| 450 |
+
low_prices = ohlc_data['Low'].values
|
| 451 |
+
|
| 452 |
+
# check if 'Peak_Dates' and 'Valley_Dates' columns are present in the row
|
| 453 |
+
if 'Peak_Dates' in row and 'Valley_Dates' in row:
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
peak_days = row['Peak_Dates']
|
| 457 |
+
valley_days = row['Valley_Dates']
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
peak_indices = ohlc_data[ohlc_data['Date'].isin(peak_days)].index
|
| 461 |
+
# add the pat_start_idx to the peak_indices
|
| 462 |
+
peak_indices = peak_indices
|
| 463 |
+
|
| 464 |
+
valley_indices = ohlc_data[ohlc_data['Date'].isin(valley_days)].index
|
| 465 |
+
# add the pat_start_idx to the valley_indices
|
| 466 |
+
valley_indices = valley_indices
|
| 467 |
+
|
| 468 |
+
# Debugging visualization: Plot detected peaks and valleys
|
| 469 |
+
ax.scatter(peak_indices , high_prices[peak_indices], color='green', marker='^', label='Peaks', zorder=3)
|
| 470 |
+
ax.scatter(valley_indices, low_prices[valley_indices], color='red', marker='v', label='Valleys', zorder=3)
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
DT_Peak_1_idx = ohlc_data[ohlc_data['Date']== row['DT_Peak_1']].index
|
| 475 |
+
DT_Peak_2_idx = ohlc_data[ohlc_data['Date']== row['DT_Peak_2']].index
|
| 476 |
+
DT_Valley_idx = ohlc_data[ohlc_data['Date']== row['DT_Valley']].index
|
| 477 |
+
|
| 478 |
+
# draw the double peaks
|
| 479 |
+
ax.plot([DT_Peak_1_idx,DT_Valley_idx, DT_Peak_2_idx], [high_prices[DT_Peak_1_idx],high_prices[DT_Valley_idx], high_prices[DT_Peak_2_idx]],
|
| 480 |
+
linestyle="solid", marker="o", color="blue", linewidth=1, label="Double Top Pattern")
|
| 481 |
+
# Draw the neckline
|
| 482 |
+
ax.hlines(y=low_prices[DT_Valley_idx], xmin=ax.get_xlim()[0], xmax=ax.get_xlim()[1], color='red', linestyle='dotted', linewidth=1)
|
| 483 |
+
|
| 484 |
+
def draw_double_bottom_aa(ax, ohlc_data, pat_start_idx,row):
|
| 485 |
+
|
| 486 |
+
Draws a Double Bottom pattern on an existing mplfinance plot and visualizes detected peaks and valleys.
|
| 487 |
+
|
| 488 |
+
Parameters:
|
| 489 |
+
ax (matplotlib.axes.Axes): The candlestick chart's axis.
|
| 490 |
+
ohlc_data (pd.DataFrame): Data containing 'High' and 'Low' columns.
|
| 491 |
+
|
| 492 |
+
# reset the index of the ohlc_data
|
| 493 |
+
ohlc_data.reset_index(drop=True, inplace=True)
|
| 494 |
+
high_prices = ohlc_data['High'].values
|
| 495 |
+
low_prices = ohlc_data['Low'].values
|
| 496 |
+
|
| 497 |
+
# check if 'Peak_Dates' and 'Valley_Dates' columns are present in the row
|
| 498 |
+
if 'Peak_Dates' in row and 'Valley_Dates' in row:
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
peak_days = row['Peak_Dates']
|
| 502 |
+
valley_days = row['Valley_Dates']
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
peak_indices = ohlc_data[ohlc_data['Date'].isin(peak_days)].index
|
| 506 |
+
# add the pat_start_idx to the peak_indices
|
| 507 |
+
peak_indices = peak_indices
|
| 508 |
+
|
| 509 |
+
valley_indices = ohlc_data[ohlc_data['Date'].isin(valley_days)].index
|
| 510 |
+
# add the pat_start_idx to the valley_indices
|
| 511 |
+
valley_indices = valley_indices
|
| 512 |
+
|
| 513 |
+
# Debugging visualization: Plot detected peaks and valleys
|
| 514 |
+
ax.scatter(peak_indices , high_prices[peak_indices], color='green', marker='^', label='Peaks', zorder=3)
|
| 515 |
+
ax.scatter(valley_indices, low_prices[valley_indices], color='red', marker='v', label='Valleys', zorder=3)
|
| 516 |
+
|
| 517 |
+
DB_Valley_1_idx = ohlc_data[ohlc_data['Date']== row['DB_Valley_1']].index
|
| 518 |
+
DB_Valley_2_idx = ohlc_data[ohlc_data['Date']== row['DB_Valley_2']].index
|
| 519 |
+
DB_Peak_idx = ohlc_data[ohlc_data['Date']== row['DB_Peak']].index
|
| 520 |
+
|
| 521 |
+
# draw the double peaks
|
| 522 |
+
ax.plot([DB_Valley_1_idx,DB_Peak_idx, DB_Valley_2_idx], [low_prices[DB_Valley_1_idx],low_prices[DB_Peak_idx], low_prices[DB_Valley_2_idx]],
|
| 523 |
+
linestyle="solid", marker="o", color="blue", linewidth=1, label="Double Bottom Pattern")
|
| 524 |
+
# Draw the neckline
|
| 525 |
+
ax.hlines(y=high_prices[DB_Peak_idx], xmin=ax.get_xlim()[0], xmax=ax.get_xlim()[1], color='red', linestyle='dotted', linewidth=1)
|
| 526 |
+
|
| 527 |
+
def plot_pattern_clusters( test_pattern_segment_wise, ohcl_data_given=None, padding_days=0,draw_lines = False):
|
| 528 |
+
colors = ["blue", "green", "red", "cyan", "magenta", "yellow", "purple", "orange", "brown", "pink", "lime", "teal"]
|
| 529 |
+
|
| 530 |
+
group = test_pattern_segment_wise
|
| 531 |
+
|
| 532 |
+
if ohcl_data_given is None:
|
| 533 |
+
symbol = group['Symbol'].iloc[0]
|
| 534 |
+
ohcl_data = pd.read_csv(path + '/' + symbol + '.csv')
|
| 535 |
+
else:
|
| 536 |
+
ohcl_data = ohcl_data_given
|
| 537 |
+
|
| 538 |
+
ohcl_data['Date'] = pd.to_datetime(ohcl_data['Date'])
|
| 539 |
+
ohcl_data['Date'] = ohcl_data['Date'].dt.tz_localize(None)
|
| 540 |
+
|
| 541 |
+
seg_start = group['Seg_Start'].iloc[0] - pd.to_timedelta(padding_days, unit='D')
|
| 542 |
+
seg_end = group['Seg_End'].iloc[0] + pd.to_timedelta(padding_days, unit='D')
|
| 543 |
+
|
| 544 |
+
ohcl_data = ohcl_data[(ohcl_data['Date'] >= seg_start) & (ohcl_data['Date'] <= seg_end)]
|
| 545 |
+
if ohcl_data.empty:
|
| 546 |
+
print("OHLC Data set is empty")
|
| 547 |
+
return
|
| 548 |
+
|
| 549 |
+
ohlc_for_mpf = ohcl_data[['Open', 'High', 'Low', 'Close']].copy()
|
| 550 |
+
ohlc_for_mpf.index = pd.to_datetime(ohcl_data['Date'])
|
| 551 |
+
|
| 552 |
+
fig, axes = mpf.plot(ohlc_for_mpf, type='candle', style='charles', datetime_format='%Y-%m-%d', returnfig=True)
|
| 553 |
+
ax = axes[0]
|
| 554 |
+
|
| 555 |
+
for _, row in group.iterrows():
|
| 556 |
+
pattern_name = row['Chart Pattern']
|
| 557 |
+
cluster = row['Cluster']
|
| 558 |
+
color = "gray" if cluster == -1 else colors[cluster % len(colors)]
|
| 559 |
+
|
| 560 |
+
pattern_start_date = pd.to_datetime(row['Start']).tz_localize(None)
|
| 561 |
+
pattern_end_date = pd.to_datetime(row['End']).tz_localize(None)
|
| 562 |
+
|
| 563 |
+
num_start = len(ohcl_data[ohcl_data['Date'] < pattern_start_date])
|
| 564 |
+
num_end = num_start + len(ohcl_data[(ohcl_data['Date'] >= pattern_start_date) & (ohcl_data['Date'] <= pattern_end_date)])
|
| 565 |
+
|
| 566 |
+
ax.axvspan(num_start, num_end, color=color, alpha=0.1, label=pattern_name)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
if draw_lines:
|
| 570 |
+
# error = row['Error'] check only if the column is present
|
| 571 |
+
error = False
|
| 572 |
+
if 'Error' in row and row['Error'] != np.nan:
|
| 573 |
+
error = row['Error']
|
| 574 |
+
if error != True:
|
| 575 |
+
calc_start_date = row['Calc_Start']
|
| 576 |
+
calc_end_date = row['Calc_End']
|
| 577 |
+
|
| 578 |
+
# reset the index of the ohlc_data
|
| 579 |
+
ohcl_data.reset_index(drop=True, inplace=True)
|
| 580 |
+
|
| 581 |
+
calc_start_idx = ohcl_data[ohcl_data['Date']== calc_start_date].index
|
| 582 |
+
calc_end_idx = ohcl_data[ohcl_data['Date']== calc_end_date].index
|
| 583 |
+
|
| 584 |
+
# drow a pink dotted vertical line at calc_start_idx and calc_end_idx
|
| 585 |
+
ax.axvline(x=calc_start_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 586 |
+
ax.axvline(x=calc_end_idx, color='blue', linestyle='dotted', linewidth=1)
|
| 587 |
+
|
| 588 |
+
# # If detected pattern is Head and Shoulders, plot indicator lines
|
| 589 |
+
# if pattern_name == "Head-and-shoulders top":
|
| 590 |
+
# # get the ohlc segment of where the date is between the pattern start and end from ohlc_for_mpf data set where the index is the date
|
| 591 |
+
# ohlc_segment_head_and_sholder = ohlc_for_mpf.loc[pattern_start_date:pattern_end_date]
|
| 592 |
+
# draw_head_and_shoulders_top(ax, ohcl_data, num_start,row)
|
| 593 |
+
# elif pattern_name == "Head-and-shoulders bottom":
|
| 594 |
+
# # get the ohlc segment of where the date is between the pattern start and end from ohlc_for_mpf data set where the index is the date
|
| 595 |
+
# ohlc_segment_head_and_sholder = ohlc_for_mpf.loc[pattern_start_date:pattern_end_date]
|
| 596 |
+
# draw_head_and_shoulders_bottom(ax, ohcl_data, num_start,row)
|
| 597 |
+
# elif pattern_name == "Double Top, Adam and Adam":
|
| 598 |
+
# # get the ohlc segment of where the date is between the pattern start and end from ohlc_for_mpf data set where the index is the date
|
| 599 |
+
# ohlc_segment_double_top = ohlc_for_mpf.loc[pattern_start_date:pattern_end_date]
|
| 600 |
+
# draw_double_top_aa(ax, ohcl_data, num_start,row)
|
| 601 |
+
# elif pattern_name == "Double Bottom, Adam and Adam":
|
| 602 |
+
# ohlc_segment_double_top = ohlc_for_mpf.loc[pattern_start_date:pattern_end_date]
|
| 603 |
+
# draw_double_bottom_aa(ax, ohcl_data, num_start,row)
|
| 604 |
+
# elif pattern_name == "Double Bottom, Eve and Adam":
|
| 605 |
+
# ohlc_segment_double_top = ohlc_for_mpf.loc[pattern_start_date:pattern_end_date]
|
| 606 |
+
# draw_double_bottom_aa(ax, ohcl_data, num_start,row)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
if draw_lines:
|
| 610 |
+
# Get unique legend handles and labels
|
| 611 |
+
handles, labels = ax.get_legend_handles_labels()
|
| 612 |
+
unique_labels = {}
|
| 613 |
+
unique_handles = []
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
# Initialize storage for unique handles/labels
|
| 618 |
+
unique_labels = {}
|
| 619 |
+
unique_handles = []
|
| 620 |
+
i= 1
|
| 621 |
+
|
| 622 |
+
for handle, label in zip(handles, labels):
|
| 623 |
+
# print(label)
|
| 624 |
+
|
| 625 |
+
# Allow duplication if the label is in pattern_encoding
|
| 626 |
+
if label in pattern_encoding or label not in unique_labels:
|
| 627 |
+
if label not in unique_labels:
|
| 628 |
+
unique_labels[label] = handle
|
| 629 |
+
unique_handles.append(handle)
|
| 630 |
+
else:
|
| 631 |
+
unique_labels[label + f"_{i}"] = handle
|
| 632 |
+
unique_handles.append(handle)
|
| 633 |
+
i += 1
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
ax.legend(unique_handles, unique_labels.keys())
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
ax.grid(True)
|
| 641 |
+
plt.show()
|
| 642 |
+
|
| 643 |
+
def plot_pattern_groups_and_finalized_sections(located_patterns_and_other_info, cluster_labled_windows_df ,ohcl_data_given=None):
|
| 644 |
+
# for each unique Chart Pattern in located_patterns_and_other_info plot the patterns
|
| 645 |
+
for pattern, group in located_patterns_and_other_info.groupby('Chart Pattern'):
|
| 646 |
+
# pattern = 'Head-and-shoulders top'
|
| 647 |
+
print (pattern ," :")
|
| 648 |
+
print(" Clustered Windows :")
|
| 649 |
+
plot_pattern_clusters( cluster_labled_windows_df[cluster_labled_windows_df['Chart Pattern'] == pattern],ohcl_data_given=ohcl_data_given)
|
| 650 |
+
print(" Finalized Section :")
|
| 651 |
+
plot_pattern_clusters( located_patterns_and_other_info[located_patterns_and_other_info['Chart Pattern'] == pattern],draw_lines=True,ohcl_data_given=ohcl_data_given)
|
| 652 |
+
"""
|
| 653 |
+
|
utils/patternLocating.py
ADDED
|
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from tqdm import tqdm
|
| 3 |
+
from utils.eval import intersection_over_union
|
| 4 |
+
from utils.formatAndPreprocessNewPatterns import get_patetrn_name_by_encoding, get_pattern_encoding_by_name, get_reverse_pattern_encoding
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
from joblib import Parallel, delayed
|
| 8 |
+
import math
|
| 9 |
+
from sklearn.cluster import DBSCAN
|
| 10 |
+
|
| 11 |
+
path = 'Datasets/OHLC data'
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed,seg_start, seg_end, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1):
    """Classify a single sliding-window slice of an OHLC segment.

    The window is anchored at row ``i``, shifted left by
    ``ceil(window_size * padding_proportion)`` rows, then clipped to the
    segment bounds (note: the end bound is computed from the *unclipped*
    start, so windows near the segment head are truncated, not shifted).
    The slice is reshaped to ``(1, 5 channels, timepoints)`` and scored
    with ``rocket_model.predict_proba``.

    Returns a result-row dict with keys Start, End, Chart Pattern,
    Seg_Start, Seg_End, Probability — or ``None`` when the slice is empty
    or prediction raises. If the 'No Pattern' class probability exceeds
    ``prob_threshold_of_no_pattern_to_mark_as_no_pattern`` the window is
    labelled 'No Pattern' regardless of the argmax class.
    (``probability_threshold`` and ``pattern_encoding_reversed`` are kept
    for interface compatibility but are not used here.)
    """
    raw_start = i - math.ceil(window_size * padding_proportion)
    lo = max(raw_start, 0)
    hi = min(raw_start + window_size, len(ohlc_data_segment))

    window = ohlc_data_segment[lo:hi]
    if len(window) == 0:
        return None  # nothing to classify
    first_date = window['Date'].iloc[0]
    last_date = window['Date'].iloc[-1]

    # Reshape to (1 sample, 5 channels, n timepoints) as the rocket
    # pipeline expects channel-major input.
    features = window[['Open', 'High', 'Low', 'Close', 'Volume']].to_numpy().reshape(1, len(window), 5)
    features = np.transpose(features, (0, 2, 1))

    try:
        probabilities = rocket_model.predict_proba(features)
    except Exception as e:
        print(f"Error in prediction: {e}")
        return None

    best_proba = np.max(probabilities)
    no_pattern_proba = probabilities[0][get_pattern_encoding_by_name ('No Pattern')]
    best_label = get_patetrn_name_by_encoding(np.argmax(probabilities))

    # 'No Pattern' wins outright once its probability clears the threshold.
    if no_pattern_proba > prob_threshold_of_no_pattern_to_mark_as_no_pattern:
        best_proba = no_pattern_proba
        best_label = 'No Pattern'

    return {
        'Start': first_date, 'End': last_date, 'Chart Pattern': best_label,
        'Seg_Start': seg_start, 'Seg_End': seg_end,
        'Probability': best_proba
    }
def parallel_process_sliding_window(ohlc_data_segment, rocket_model, probability_threshold, stride, pattern_encoding_reversed, window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern=1,parallel=True,num_cores=-1):
    """Slide a window over ``ohlc_data_segment`` and classify every step.

    Calls :func:`process_window` for each start index ``0, stride, 2*stride,
    ...`` either through a joblib worker pool (``parallel=True``) or
    sequentially with progress prints.

    Returns a DataFrame of the non-``None`` window result rows (may be
    empty when every window was skipped).
    """
    # Segment boundaries are attached to every result row downstream.
    seg_start = ohlc_data_segment['Date'].iloc[0]
    seg_end = ohlc_data_segment['Date'].iloc[-1]

    window_starts = range(0, len(ohlc_data_segment), stride)

    if parallel:
        # Context manager guarantees worker cleanup. NOTE: aliased to
        # `pool` so we no longer shadow the `parallel` flag argument.
        with Parallel(n_jobs=num_cores, verbose=1) as pool:
            results = pool(
                delayed(process_window)(
                    i=i,
                    ohlc_data_segment=ohlc_data_segment,
                    rocket_model=rocket_model,
                    probability_threshold=probability_threshold,
                    pattern_encoding_reversed=pattern_encoding_reversed,
                    window_size=window_size,
                    seg_start=seg_start,
                    seg_end=seg_end,
                    padding_proportion=padding_proportion,
                    prob_threshold_of_no_pattern_to_mark_as_no_pattern=prob_threshold_of_no_pattern_to_mark_as_no_pattern
                )
                for i in window_starts
            )
        # Drop skipped windows (None) before building the frame.
        return pd.DataFrame([res for res in results if res is not None])
    else:
        # Sequential fallback, same computation without worker processes.
        results = []
        total_iterations = len(window_starts)
        for i_idx, i in enumerate(window_starts):
            # BUGFIX: previously the no-pattern threshold was not forwarded
            # here, so the sequential path always used the default (1) and
            # diverged from the parallel path.
            res = process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed, seg_start, seg_end, window_size, padding_proportion, prob_threshold_of_no_pattern_to_mark_as_no_pattern)
            if res is not None:
                results.append(res)
            # Progress print statement
            print(f"Processing window {i_idx + 1} of {total_iterations}...")
        return pd.DataFrame(results)
def prepare_dataset_for_cluster(ohlc_data_segment, win_results_df):
    """Attach positional features to each predicted window for clustering.

    For every row of ``win_results_df`` (columns 'Start'/'End' holding
    dates), computes positions relative to ``ohlc_data_segment`` by
    counting trading rows rather than calendar days:

    - ``Pattern_Start_pos``: number of segment rows strictly before 'Start'
    - ``Pattern_End_pos``:   start position + number of rows inside
      ['Start', 'End'] (both inclusive)
    - ``Center``:            start position + half the pattern length
      (may be fractional)

    Returns an annotated copy; neither input frame is mutated.
    (Fix: removed the unused ``origin_date`` local.)
    """
    predicted_patterns = win_results_df.copy()
    for index, row in predicted_patterns.iterrows():
        pattern_start = row['Start']
        pattern_end = row['End']

        # Rows strictly before the pattern start == 0-based start position.
        start_point_index = len(ohlc_data_segment[ohlc_data_segment['Date'] < pattern_start])
        # Rows covered by the pattern, endpoints inclusive.
        pattern_len = len(ohlc_data_segment[(ohlc_data_segment['Date'] >= pattern_start) & (ohlc_data_segment['Date'] <= pattern_end)])

        predicted_patterns.at[index, 'Center'] = start_point_index + (pattern_len / 2)
        predicted_patterns.at[index, 'Pattern_Start_pos'] = start_point_index
        predicted_patterns.at[index, 'Pattern_End_pos'] = start_point_index + pattern_len

    return predicted_patterns
def cluster_windows(predicted_patterns , probability_threshold, window_size,eps = 0.05 , min_samples = 2):
    """Group overlapping window predictions into pattern clusters.

    Filters low-probability windows, then runs DBSCAN per chart pattern on
    the normalized window 'Center' positions, and finally reduces each
    cluster to one (Start, End) interval where at least two member windows
    overlap in calendar dates.

    Parameters:
        predicted_patterns: output of prepare_dataset_for_cluster
            (needs 'Chart Pattern', 'Probability', 'Center', 'Start',
            'End', 'Seg_Start', 'Seg_End').
        probability_threshold: float, or a list of per-pattern thresholds
            indexed by the pattern encoding.
        window_size: unused here (kept for interface compatibility).
        eps, min_samples: DBSCAN parameters (eps applies to centers
            normalized to [0, 1]).

    Returns:
        (cluster_labled_windows_df, interseced_clusters_df), or
        (None, None) when nothing survives filtering/clustering.
    """
    df = predicted_patterns.copy()

    # check if the probability_threshold is a list or a float
    if isinstance(probability_threshold, list):
        # the list contain the probability thresholds for each chart pattern
        # filter the dataframe for each probability threshold
        for i in range(len(probability_threshold)):
            pattern_name = get_patetrn_name_by_encoding(i)
            # Drop this pattern's rows that fall below its own threshold.
            df.drop(df[(df['Chart Pattern'] == pattern_name) & (df['Probability'] < probability_threshold[i])].index, inplace=True)
            # print(f"Filtered {pattern_name} with probability < {probability_threshold[i]}")

    else:
        # only get the rows that has a probability greater than the probability threshold
        df = df[df['Probability'] > probability_threshold]

    # Initialize a list to store merged clusters from all groups
    cluster_labled_windows = []
    interseced_clusters = []

    # Global min/max over ALL patterns, so normalization is shared
    # across the per-pattern groups below.
    min_center = df['Center'].min()
    max_center = df['Center'].max()

    # Group by 'Chart Pattern' and apply clustering to each group
    for pattern, group in df.groupby('Chart Pattern'):
        # Clustering
        centers = group['Center'].values.reshape(-1, 1)

        # centers normalization
        if min_center < max_center:  # Avoid division by zero
            norm_centers = (centers - min_center) / (max_center - min_center)
        else:
            # If all values are the same, set to constant (e.g., 0 or 1)
            norm_centers = np.ones_like(centers)

        db = DBSCAN(eps=eps, min_samples=min_samples).fit(norm_centers)
        # NOTE(review): `group` is a groupby slice; this assignment may
        # trigger pandas' SettingWithCopyWarning — verify if it appears.
        group['Cluster'] = db.labels_

        cluster_labled_windows.append(group)

        # Filter out noise (-1) and group by Cluster
        for cluster_id, cluster_group in group[group['Cluster'] != -1].groupby('Cluster'):

            # Expand every member window into its full calendar-date range.
            expanded_dates = []
            for _, row in cluster_group.iterrows():
                dates = pd.date_range(row["Start"], row["End"])
                expanded_dates.extend(dates)

            # Step 2: Count occurrences of each date
            date_counts = pd.Series(expanded_dates).value_counts().sort_index()

            # Step 3: Identify cluster start and end (where at least 2 windows overlap)
            cluster_start = date_counts[date_counts >= 2].index.min()
            cluster_end = date_counts[date_counts >= 2].index.max()

            interseced_clusters.append({
                'Chart Pattern': pattern,
                'Cluster': cluster_id,
                'Start': cluster_start,
                'End': cluster_end,
                'Seg_Start': cluster_group['Seg_Start'].iloc[0],
                'Seg_End': cluster_group['Seg_End'].iloc[0],
                'Avg_Probability': cluster_group['Probability'].mean(),
            })

    # If every window was noise (or the filter emptied the frame),
    # signal "no clusters" with a (None, None) pair.
    if len(cluster_labled_windows) == 0 or len(interseced_clusters) == 0:
        return None,None
    # Combine all merged clusters into a final DataFrame
    cluster_labled_windows_df = pd.concat(cluster_labled_windows)
    interseced_clusters_df = pd.DataFrame(interseced_clusters)

    # sort by the index (restores the original window ordering)
    cluster_labled_windows_df = cluster_labled_windows_df.sort_index()
    return cluster_labled_windows_df,interseced_clusters_df
# =========================Advance Locator ==========================
# Module-level configuration and model used by locate_patterns below.

pattern_encoding_reversed = get_reverse_pattern_encoding()
# load the joblib model at Models\Width Aug OHLC_mini_rocket_xgb.joblib to use
# (MiniRocket + XGBoost pipeline trained on width-augmented OHLC windows;
# loading happens at import time — requires the Models/ directory.)
model = joblib.load('Models/Width Aug OHLC_mini_rocket_xgb.joblib')
plot_count = 0

# Window-size divisors: 10 log-spaced values between 1 and 20, so the
# locator scans the segment at multiple scales (seg_len/1 ... seg_len/20).
win_size_proportions = np.round(np.logspace(0, np.log10(20), num=10), 2).tolist()
padding_proportion = 0.6   # how far each window is shifted back (fraction of window_size)
stride = 1                 # sliding-window step in rows
probab_threshold_list = 0.5  # minimum class probability to keep a window
prob_threshold_of_no_pattern_to_mark_as_no_pattern = 0.5
target_len = 30

eps=0.04 # in the dbscan clustering
min_samples=3 # in the dbscan clustering
win_width_proportion=10 # in the dbscan clustering from what amount to divide the width related feature
def locate_patterns(ohlc_data, patterns_to_return= None,model = model , pattern_encoding_reversed= pattern_encoding_reversed,plot_count = 10):
    """Locate chart patterns in an OHLC series at multiple window scales.

    Pipeline per window size (derived from module-level
    ``win_size_proportions``): slide-and-classify windows
    (parallel_process_sliding_window) -> add positional features
    (prepare_dataset_for_cluster) -> DBSCAN-merge overlapping windows
    (cluster_windows). Detections from all scales are then de-duplicated:
    overlapping detections (IoU > 0.6) of the same pattern prefer the
    larger window unless the smaller one's Avg_Probability beats it by
    more than 0.1.

    Parameters:
        ohlc_data: DataFrame with at least Date/Open/High/Low/Close/Volume.
        patterns_to_return: optional list of pattern names to keep.
        model, pattern_encoding_reversed: default to the module-level ones.
        plot_count: retained for compatibility; plotting is disabled here.

    Returns a DataFrame of detections (Start/End, Calc_Start/Calc_End,
    Chart Pattern, Avg_Probability, Window_Size, ...), or None when no
    window size yielded detections. Raises on an empty input segment.
    """
    ohlc_data_segment = ohlc_data.copy()
    # convert date to datetime
    ohlc_data_segment['Date'] = pd.to_datetime(ohlc_data_segment['Date'])
    seg_len = len(ohlc_data_segment)

    if ohlc_data_segment is None or len(ohlc_data_segment) == 0:
        print("OHLC Data segment is empty")
        raise Exception("OHLC Data segment is empty")

    win_results_for_each_size = []
    located_patterns_and_other_info_for_each_size = []
    cluster_labled_windows_list = []

    used_win_sizes = []
    # Running offset so cluster ids stay globally unique across scales.
    win_iteration = 0

    for win_size_proportion in win_size_proportions:
        window_size = seg_len // win_size_proportion
        # print(f"Win size : {window_size}")
        if window_size < 10:
            window_size = 10
        # elif window_size > 30:
        #     window_size = 30

        # convert to int
        window_size = int(window_size)
        # Different proportions can collapse to the same size; scan each once.
        if window_size in used_win_sizes:
            continue
        used_win_sizes.append(window_size)

        win_results_df = parallel_process_sliding_window(ohlc_data_segment, model, probab_threshold_list,stride, pattern_encoding_reversed,window_size, padding_proportion,prob_threshold_of_no_pattern_to_mark_as_no_pattern,parallel=True)

        if win_results_df is None or len(win_results_df) == 0:
            print("Window results dataframe is empty")
            continue
        win_results_df['Window_Size'] = window_size
        win_results_for_each_size.append(win_results_df)

        predicted_patterns = prepare_dataset_for_cluster(ohlc_data_segment, win_results_df)
        # NOTE(review): empty predicted_patterns is only reported, not
        # skipped — cluster_windows is still called below; confirm intended.
        if predicted_patterns is None or len(predicted_patterns) == 0:
            print("Predicted patterns dataframe is empty")
        cluster_labled_windows_df , interseced_clusters_df = cluster_windows(predicted_patterns, probab_threshold_list, window_size)
        if cluster_labled_windows_df is None or interseced_clusters_df is None or len(cluster_labled_windows_df) == 0 or len(interseced_clusters_df) == 0:
            print("Clustered windows dataframe is empty")
            continue
        # Shift this scale's cluster ids past all previously used ids
        # (noise rows labelled -1 are left untouched).
        mask = cluster_labled_windows_df['Cluster'] != -1
        cluster_labled_windows_df.loc[mask, 'Cluster'] = cluster_labled_windows_df.loc[mask, 'Cluster'].astype(int) + win_iteration
        interseced_clusters_df['Cluster'] = interseced_clusters_df['Cluster'].astype(int) + win_iteration
        num_of_unique_clusters = interseced_clusters_df[interseced_clusters_df['Cluster']!=-1]['Cluster'].nunique()
        win_iteration += num_of_unique_clusters
        cluster_labled_windows_list.append(cluster_labled_windows_df)
        # Calc_* mirror the cluster interval; downstream code keys on them.
        interseced_clusters_df['Calc_Start'] = interseced_clusters_df['Start']
        interseced_clusters_df['Calc_End'] = interseced_clusters_df['End']
        located_patterns_and_other_info = interseced_clusters_df.copy()

        if located_patterns_and_other_info is None or len(located_patterns_and_other_info) == 0:
            print("]Located patterns and other info dataframe is empty")
            continue
        located_patterns_and_other_info['Window_Size'] = window_size

        located_patterns_and_other_info_for_each_size.append(located_patterns_and_other_info)

    if located_patterns_and_other_info_for_each_size is None or len(located_patterns_and_other_info_for_each_size) == 0 or win_results_for_each_size is None or len(win_results_for_each_size) == 0:
        print("Located patterns and other info for each size is empty")
        return None
    located_patterns_and_other_info_for_each_size_df = pd.concat(located_patterns_and_other_info_for_each_size)
    win_results_for_each_size_df = pd.concat(win_results_for_each_size, ignore_index=True)

    # get the set of unique window sizes from located_patterns_and_other_info_for_each_size_df
    unique_window_sizes = located_patterns_and_other_info_for_each_size_df['Window_Size'].unique()
    unique_patterns = located_patterns_and_other_info_for_each_size_df['Chart Pattern'].unique()

    # sort the unique_window_sizes descending order (largest scale first)
    unique_window_sizes = np.sort(unique_window_sizes)[::-1]

    filtered_loc_pat_and_info_rows_list = []

    # Cross-scale de-duplication: keep one detection per overlapping group.
    for chart_pattern in unique_patterns:
        located_patterns_and_other_info_for_each_size_df_chart_pattern = located_patterns_and_other_info_for_each_size_df[located_patterns_and_other_info_for_each_size_df['Chart Pattern'] == chart_pattern]
        for win_size in unique_window_sizes:
            located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern = located_patterns_and_other_info_for_each_size_df_chart_pattern[located_patterns_and_other_info_for_each_size_df_chart_pattern['Window_Size'] == win_size]
            for idx , row in located_patterns_and_other_info_for_each_size_df_win_size_chart_pattern.iterrows():
                start_date = row['Calc_Start']
                end_date = row['Calc_End']
                is_already_included = False
                # check if there are any other rows that intersect with the start and end dates with the same chart pattern
                # (the row also matches itself here; with equal window sizes
                # and zero probability gap it falls through harmlessly)
                intersecting_rows = located_patterns_and_other_info_for_each_size_df_chart_pattern[
                    (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_Start'] <= end_date) &
                    (located_patterns_and_other_info_for_each_size_df_chart_pattern['Calc_End'] >= start_date)
                ]
                is_already_included = False
                for idx2, row2 in intersecting_rows.iterrows():
                    iou = intersection_over_union(start_date, end_date, row2['Calc_Start'], row2['Calc_End'])

                    if iou > 0.6:
                        # Case 1: Larger window already exists
                        if row2['Window_Size'] > row['Window_Size']:
                            # Case 1A: But smaller one has significantly higher probability, keep it instead
                            if (row['Avg_Probability'] - row2['Avg_Probability']) > 0.1:
                                is_already_included = False
                            else:
                                is_already_included = True
                                break  # Keep large, skip current(small)

                        # Case 2: Equal or smaller window exists, possibly overlapping
                        elif row['Window_Size'] >= row2['Window_Size']:
                            # If the other row has significantly better probability, drop the current one
                            if (row2['Avg_Probability'] - row['Avg_Probability']) > 0.1:
                                is_already_included = True
                                break # remove current (large) , keep small
                            else:
                                is_already_included = False
                            # break

                if not is_already_included:
                    filtered_loc_pat_and_info_rows_list.append(row)

    # convert the filtered_loc_pat_and_info_rows_list to a dataframe
    filtered_loc_pat_and_info_df = pd.DataFrame(filtered_loc_pat_and_info_rows_list)

    if cluster_labled_windows_list is None or len(cluster_labled_windows_list) == 0:
        print("Clustered windows list is empty")
    cluster_labled_windows_df_conc = pd.concat(cluster_labled_windows_list)
    # Remove plotting code
    """
    if plot_count > 0:
        plot_pattern_groups_and_finalized_sections(filtered_loc_pat_and_info_df, cluster_labled_windows_df_conc,ohcl_data_given=ohlc_data_segment)
        plot_count -= 1
    """

    if patterns_to_return is None or len(patterns_to_return) == 0:
        return filtered_loc_pat_and_info_df
    else:
        # filter the filtered_loc_pat_and_info_df based on the patterns_to_return
        filtered_loc_pat_and_info_df = filtered_loc_pat_and_info_df[filtered_loc_pat_and_info_df['Chart Pattern'].isin(patterns_to_return)]
        return filtered_loc_pat_and_info_df
|
utils/patternLocatingGemni.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import joblib
|
| 2 |
+
from utils.eval import intersection_over_union
|
| 3 |
+
from utils.formatAndPreprocessNewPatterns import get_patetrn_name_by_encoding, get_pattern_encoding_by_name, get_reverse_pattern_encoding
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import math
|
| 7 |
+
from sklearn.cluster import DBSCAN
|
| 8 |
+
from joblib import Parallel, delayed
|
| 9 |
+
# Remove matplotlib imports and plotting function import
|
| 10 |
+
# import matplotlib.pyplot as plt
|
| 11 |
+
# from utils.functionalPatternLocateAndPlot import plot_pattern_groups_and_finalized_sections
|
| 12 |
+
|
| 13 |
+
# --- Global Configuration & Model Loading ---
# Load the pre-trained model and pattern encodings.
# It's assumed 'Models/Width Aug OHLC_mini_rocket_xgb.joblib' is in the correct path
# relative to the working directory of the process importing this module.
MODEL_PATH = 'Models/Width Aug OHLC_mini_rocket_xgb.joblib'
try:
    # Loaded once at import time; used as the default model by locate_patterns().
    rocket_model_global = joblib.load(MODEL_PATH)
except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}. Please ensure the path is correct.")
    # You might want to exit or raise an exception here if the model is critical.
    # Downstream code checks for None and aborts gracefully.
    rocket_model_global = None

# Reverse mapping (encoding -> pattern name), used as the default by locate_patterns().
pattern_encoding_reversed_global = get_reverse_pattern_encoding()

# Default parameters for the pattern location logic.
# Log-spaced divisors of the segment length (1..20) -> 10 candidate window sizes.
WIN_SIZE_PROPORTIONS = np.round(np.logspace(0, np.log10(20), num=10), 2).tolist()
PADDING_PROPORTION = 0.6  # fraction of the window placed before the anchor index
STRIDE = 1  # sliding-window step, in candles
# Default probability thresholds for pattern identification.
# NOTE(review): presumably indexed by pattern encoding (0..N-1) — confirm against
# the encoding order returned by get_patetrn_name_by_encoding.
PROBABILITY_THRESHOLD_LIST = [0.8884, 0.8676, 0.5620, 0.5596, 0.5132, 0.8367, 0.7635]
PROB_THRESHOLD_NO_PATTERN = 0.5  # Threshold to mark as 'No Pattern'

# DBSCAN Clustering parameters
DBSCAN_EPS = 0.04
DBSCAN_MIN_SAMPLES = 3
|
| 38 |
+
# --- Private Helper Functions ---
|
| 39 |
+
|
| 40 |
+
def _process_window(i, ohlc_data_segment, rocket_model, probability_threshold, pattern_encoding_reversed, seg_start, seg_end, window_size, padding_proportion, prob_threshold_of_no_pattern_to_mark_as_no_pattern=1):
    """Classify a single padded window of OHLC data anchored at index *i*.

    The window starts ``ceil(window_size * padding_proportion)`` candles
    before *i* (clamped to the segment bounds) and spans ``window_size``
    candles. The model's most probable class is reported unless the
    'No Pattern' probability reaches the given threshold, in which case the
    window is labelled 'No Pattern'.

    Returns a result dict, or ``None`` when the window is empty or the
    model call fails. ``probability_threshold`` and
    ``pattern_encoding_reversed`` are accepted for interface compatibility
    but not used here.
    """
    # Anchor the window around i, shifted back by the padding fraction.
    lo = i - math.ceil(window_size * padding_proportion)
    hi = lo + window_size
    # Clamp to the segment; windows near the edges may be shorter than window_size.
    lo = max(lo, 0)
    hi = min(hi, len(ohlc_data_segment))

    window = ohlc_data_segment[lo:hi]
    if len(window) == 0:
        return None

    first_date = window['Date'].iloc[0]
    last_date = window['Date'].iloc[-1]

    # Shape (1, channels, timepoints) as the rocket model expects.
    features = window[['Open', 'High', 'Low', 'Close', 'Volume']].to_numpy().reshape(1, len(window), 5)
    features = np.transpose(features, (0, 2, 1))

    try:
        probas = rocket_model.predict_proba(features)
    except Exception:
        # Prediction failures are treated the same as empty windows.
        return None

    top_proba = np.max(probas)
    no_pattern_code = get_pattern_encoding_by_name('No Pattern')
    if no_pattern_code is None:  # 'No Pattern' missing from the encoding map
        no_pattern_proba = 0
    else:
        no_pattern_proba = probas[0][no_pattern_code]

    pred_pattern = get_patetrn_name_by_encoding(np.argmax(probas))
    pred_proba = top_proba

    # Override with 'No Pattern' when its probability clears the threshold.
    if no_pattern_proba >= prob_threshold_of_no_pattern_to_mark_as_no_pattern:
        pred_proba = no_pattern_proba
        pred_pattern = 'No Pattern'

    return {
        'Start': first_date, 'End': last_date, 'Chart Pattern': pred_pattern,
        'Seg_Start': seg_start, 'Seg_End': seg_end, 'Probability': pred_proba
    }
|
| 87 |
+
|
| 88 |
+
def _parallel_process_sliding_window(ohlc_data_segment, rocket_model, probability_threshold, stride, pattern_encoding_reversed, window_size, padding_proportion, prob_threshold_of_no_pattern_to_mark_as_no_pattern=1, parallel=True, num_cores=16, verbose_level=1):
    """Run :func:`_process_window` over the whole segment with a given stride.

    When *parallel* is True the windows are evaluated with joblib across
    *num_cores* workers; otherwise they run sequentially with simple progress
    printing. Returns a DataFrame containing every non-``None`` window result.
    """
    shared_kwargs = dict(
        ohlc_data_segment=ohlc_data_segment,
        rocket_model=rocket_model,
        probability_threshold=probability_threshold,
        pattern_encoding_reversed=pattern_encoding_reversed,
        window_size=window_size,
        seg_start=ohlc_data_segment['Date'].iloc[0],
        seg_end=ohlc_data_segment['Date'].iloc[-1],
        padding_proportion=padding_proportion,
        prob_threshold_of_no_pattern_to_mark_as_no_pattern=prob_threshold_of_no_pattern_to_mark_as_no_pattern,
    )

    anchors = range(0, len(ohlc_data_segment), stride)

    if parallel:
        with Parallel(n_jobs=num_cores, verbose=verbose_level) as executor:
            results = executor(
                delayed(_process_window)(i=i, **shared_kwargs) for i in anchors
            )
    else:
        results = []
        total = len(anchors)  # for progress reporting only
        for pos, anchor in enumerate(anchors):
            outcome = _process_window(i=anchor, **shared_kwargs)
            if outcome is not None:
                results.append(outcome)
            if verbose_level > 0:
                print(f"Processing window {pos + 1} of {total}...")

    # Parallel results may still contain None entries; drop them here.
    return pd.DataFrame([r for r in results if r is not None])
|
| 122 |
+
|
| 123 |
+
def _prepare_dataset_for_cluster(ohlc_data_segment, win_results_df):
|
| 124 |
+
"""Adds position-based features to window results for clustering."""
|
| 125 |
+
predicted_patterns = win_results_df.copy()
|
| 126 |
+
|
| 127 |
+
for index, row in predicted_patterns.iterrows():
|
| 128 |
+
pattern_start_date = row['Start']
|
| 129 |
+
pattern_end_date = row['End']
|
| 130 |
+
|
| 131 |
+
start_point_index = len(ohlc_data_segment[ohlc_data_segment['Date'] < pattern_start_date])
|
| 132 |
+
pattern_len = len(ohlc_data_segment[(ohlc_data_segment['Date'] >= pattern_start_date) & (ohlc_data_segment['Date'] <= pattern_end_date)])
|
| 133 |
+
|
| 134 |
+
pattern_mid_index = start_point_index + (pattern_len / 2.0) # Use float division
|
| 135 |
+
|
| 136 |
+
predicted_patterns.at[index, 'Center'] = pattern_mid_index
|
| 137 |
+
predicted_patterns.at[index, 'Pattern_Start_pos'] = start_point_index
|
| 138 |
+
predicted_patterns.at[index, 'Pattern_End_pos'] = start_point_index + pattern_len
|
| 139 |
+
return predicted_patterns
|
| 140 |
+
|
| 141 |
+
def _cluster_windows(predicted_patterns, probability_threshold, eps=0.05, min_samples_dbscan=2):
    """Clusters detected pattern windows using DBSCAN.

    Windows are first filtered by probability (a per-encoding list of
    thresholds, or a single float applied to all patterns). The survivors of
    each chart pattern are clustered on their normalized 'Center' position.
    For each non-noise DBSCAN cluster, an "intersected" span is derived: the
    date range covered by at least ``min_samples_dbscan`` member windows.

    min_samples_dbscan is the min_samples for DBSCAN algorithm itself.
    The overlap check for intersected_clusters will also use this value.

    Returns a tuple ``(cluster_labeled_windows, intersected_clusters)`` of
    DataFrames: the first holds every surviving window with a 'Cluster'
    label (-1 = noise); the second holds one row per cluster with its
    consolidated Start/End and average probability. Both are empty when no
    window survives the probability filter.
    """
    df = predicted_patterns.copy()

    # --- Probability filtering ---
    if isinstance(probability_threshold, list):
        temp_dfs = []
        # Ensure probability_threshold list length matches number of encodable patterns if used directly with get_patetrn_name_by_encoding(i)
        # Or, better, iterate through unique patterns present in df if threshold list is a dict or structured differently.
        # Assuming probability_threshold list is indexed corresponding to pattern encodings from 0 to N-1
        for i, p_thresh in enumerate(probability_threshold):
            pattern_name = get_patetrn_name_by_encoding(i)
            if pattern_name:
                temp_dfs.append(df[(df['Chart Pattern'] == pattern_name) & (df['Probability'] >= p_thresh)])
        if temp_dfs:
            df = pd.concat(temp_dfs) if temp_dfs else pd.DataFrame(columns=df.columns)
        else:
            df = pd.DataFrame(columns=df.columns)
    else:  # single float threshold
        df = df[df['Probability'] >= probability_threshold]  # Changed > to >=

    if df.empty:
        return pd.DataFrame(), pd.DataFrame()

    cluster_labled_windows_list = []
    interseced_clusters_list = []

    # Normalize 'Center' for DBSCAN if there's variance.
    # NOTE(review): min/max are computed over ALL patterns, then applied within
    # each per-pattern group below — confirm this global normalization is intended.
    min_center_val = df['Center'].min()
    max_center_val = df['Center'].max()

    for pattern, group in df.groupby('Chart Pattern'):
        if group.empty:
            continue

        # One feature per window: its (normalized) center position.
        centers = group['Center'].values.reshape(-1, 1)

        if min_center_val < max_center_val:  # Avoid division by zero if all centers are same
            norm_centers = (centers - min_center_val) / (max_center_val - min_center_val)
        elif len(centers) > 0:  # All centers are the same, no real distance variance
            norm_centers = np.zeros_like(centers)  # Treat as single point for clustering
        else:  # Empty group after filtering, should not happen if group.empty() check passed
            norm_centers = np.array([])

        if len(norm_centers) == 0:
            # Nothing to cluster: mark everything as noise and keep the rows.
            group['Cluster'] = -1
            cluster_labled_windows_list.append(group)
            continue

        # DBSCAN requires min_samples <= number of points.
        current_min_samples_for_dbscan = min(min_samples_dbscan, len(norm_centers))
        if current_min_samples_for_dbscan < 1 and len(norm_centers) > 0:
            current_min_samples_for_dbscan = 1
        elif len(norm_centers) == 0:
            group['Cluster'] = -1
            cluster_labled_windows_list.append(group)
            continue

        db = DBSCAN(eps=eps, min_samples=current_min_samples_for_dbscan).fit(norm_centers)
        group['Cluster'] = db.labels_
        cluster_labled_windows_list.append(group)

        # Build one consolidated entry per non-noise cluster.
        for cluster_id, cluster_group in group[group['Cluster'] != -1].groupby('Cluster'):
            expanded_dates = []
            for _, row_cg in cluster_group.iterrows():  # Renamed 'row' to 'row_cg' to avoid conflict
                # Ensure Start and End are valid datetime objects
                try:
                    dates = pd.date_range(start=pd.to_datetime(row_cg["Start"]), end=pd.to_datetime(row_cg["End"]))
                    expanded_dates.extend(dates)
                except Exception as e:
                    # print(f"Warning: Could not create date range for row: {row_cg}. Error: {e}") # Optional
                    continue

            if not expanded_dates:
                continue

            # Count, per calendar date, how many member windows cover it.
            date_counts = pd.Series(expanded_dates).value_counts().sort_index()

            # Use min_samples_dbscan for defining a significant overlap
            overlapping_dates = date_counts[date_counts >= min_samples_dbscan]
            if overlapping_dates.empty:
                continue

            # Cluster span = earliest..latest sufficiently-overlapped date.
            cluster_start = overlapping_dates.index.min()
            cluster_end = overlapping_dates.index.max()

            interseced_clusters_list.append({
                'Chart Pattern': pattern,
                'Cluster': cluster_id,  # This ID is local to the (pattern, window_size) batch
                'Start': cluster_start,
                'End': cluster_end,
                'Seg_Start': cluster_group['Seg_Start'].iloc[0],
                'Seg_End': cluster_group['Seg_End'].iloc[0],
                'Avg_Probability': cluster_group['Probability'].mean(),
            })

    final_cluster_labled_df = pd.concat(cluster_labled_windows_list) if cluster_labled_windows_list else pd.DataFrame(columns=df.columns if not df.empty else [])
    if 'Cluster' not in final_cluster_labled_df.columns and not final_cluster_labled_df.empty:
        final_cluster_labled_df['Cluster'] = -1  # Default if no clusters formed but df had data

    final_interseced_df = pd.DataFrame(interseced_clusters_list)

    return final_cluster_labled_df, final_interseced_df
|
| 246 |
+
|
| 247 |
+
# --- Public API Function ---
|
| 248 |
+
|
| 249 |
+
def locate_patterns(ohlc_data: pd.DataFrame,
                    patterns_to_return: list = None,
                    model=None,
                    pattern_encoding_reversed=None,
                    win_size_proportions: list = None,
                    padding_proportion: float = PADDING_PROPORTION,
                    stride: int = STRIDE,
                    probability_threshold=None,
                    prob_threshold_of_no_pattern_to_mark_as_no_pattern: float = PROB_THRESHOLD_NO_PATTERN,
                    dbscan_eps: float = DBSCAN_EPS,
                    dbscan_min_samples: int = DBSCAN_MIN_SAMPLES,
                    enable_plotting: bool = False,  # Keep parameter but ignore it
                    parallel_processing: bool = True,
                    num_cores_parallel: int = 16,
                    parallel_verbose_level: int = 1
                    ):
    """
    Locates financial chart patterns in OHLC data using a sliding window approach and clustering.

    Pipeline per candidate window size: run the model over sliding windows
    (`_parallel_process_sliding_window`), add positional features
    (`_prepare_dataset_for_cluster`), then group detections with DBSCAN
    (`_cluster_windows`). The per-size cluster results are merged and
    overlapping detections are resolved greedily, preferring larger window
    sizes unless a smaller-window detection is markedly more probable.

    Parameters fall back to the module-level defaults when left as None
    (model, encoding map, window-size proportions, probability thresholds).
    `enable_plotting` is accepted for backward compatibility but ignored.

    Returns a DataFrame of located patterns (with 'Chart Pattern',
    'Calc_Start'/'Calc_End', 'Window_Size', 'Avg_Probability', ...),
    optionally restricted to `patterns_to_return`; empty on failure or when
    nothing is found.
    """
    active_model = model if model is not None else rocket_model_global
    active_pattern_encoding_rev = pattern_encoding_reversed if pattern_encoding_reversed is not None else pattern_encoding_reversed_global
    active_win_size_proportions = win_size_proportions if win_size_proportions is not None else WIN_SIZE_PROPORTIONS
    active_probability_threshold = probability_threshold if probability_threshold is not None else PROBABILITY_THRESHOLD_LIST

    if active_model is None:
        print("Error: Pattern detection model is not loaded. Cannot proceed.")
        return pd.DataFrame()

    ohlc_data_segment = ohlc_data.copy()
    ohlc_data_segment['Date'] = pd.to_datetime(ohlc_data_segment['Date'])
    seg_len = len(ohlc_data_segment)

    if ohlc_data_segment.empty:
        return pd.DataFrame()

    win_results_for_each_size = []
    located_patterns_and_other_info_for_each_size = []
    cluster_labled_windows_list = []  # Stores all clustered windows from all iterations
    used_win_sizes = []
    global_cluster_id_offset = 0  # To ensure cluster IDs are unique across all window sizes and patterns

    for win_prop in active_win_size_proportions:
        # Window size = segment length divided by the proportion, floor 10 candles.
        window_size = seg_len // win_prop if win_prop > 0 else seg_len  # Avoid division by zero
        window_size = int(max(10, window_size))

        # Different proportions can round to the same size; evaluate each size once.
        if window_size in used_win_sizes:
            continue
        used_win_sizes.append(window_size)

        win_results_df = _parallel_process_sliding_window(
            ohlc_data_segment, active_model, active_probability_threshold, stride,
            active_pattern_encoding_rev, window_size, padding_proportion,
            prob_threshold_of_no_pattern_to_mark_as_no_pattern,
            parallel=parallel_processing, num_cores=num_cores_parallel,
            verbose_level=parallel_verbose_level  # Pass verbosity
        )

        if win_results_df.empty:
            continue
        win_results_df['Window_Size'] = window_size
        # win_results_for_each_size.append(win_results_df) # Not directly used later, can be omitted if not needed for debugging

        predicted_patterns_for_cluster = _prepare_dataset_for_cluster(ohlc_data_segment, win_results_df)
        if predicted_patterns_for_cluster.empty:
            continue

        # Pass dbscan_min_samples to _cluster_windows
        temp_cluster_labled_windows_df, temp_interseced_clusters_df = _cluster_windows(
            predicted_patterns_for_cluster, active_probability_threshold,
            eps=dbscan_eps, min_samples_dbscan=dbscan_min_samples  # Pass the parameter
        )

        if temp_cluster_labled_windows_df.empty or temp_interseced_clusters_df.empty:
            continue

        # Adjust cluster IDs to be globally unique before appending
        # For temp_cluster_labled_windows_df
        non_noise_clusters_mask_labeled = temp_cluster_labled_windows_df['Cluster'] != -1
        if non_noise_clusters_mask_labeled.any():
            temp_cluster_labled_windows_df.loc[non_noise_clusters_mask_labeled, 'Cluster'] = \
                temp_cluster_labled_windows_df.loc[non_noise_clusters_mask_labeled, 'Cluster'].astype(int) + global_cluster_id_offset

        # For temp_interseced_clusters_df
        # Note: 'Cluster' in temp_interseced_clusters_df is already filtered for non-noise by its creation logic
        if not temp_interseced_clusters_df.empty:
            temp_interseced_clusters_df['Cluster'] = temp_interseced_clusters_df['Cluster'].astype(int) + global_cluster_id_offset

        current_max_cluster_id_in_batch = -1
        if not temp_interseced_clusters_df.empty and 'Cluster' in temp_interseced_clusters_df.columns:
            valid_clusters = temp_interseced_clusters_df[temp_interseced_clusters_df['Cluster'] != -1]['Cluster']
            if not valid_clusters.empty:
                current_max_cluster_id_in_batch = valid_clusters.max()

        cluster_labled_windows_list.append(temp_cluster_labled_windows_df)

        # Calc_Start/Calc_End are the working copies used by the overlap filter below.
        temp_interseced_clusters_df['Calc_Start'] = temp_interseced_clusters_df['Start']
        temp_interseced_clusters_df['Calc_End'] = temp_interseced_clusters_df['End']
        located_patterns_info = temp_interseced_clusters_df.copy()
        located_patterns_info['Window_Size'] = window_size
        located_patterns_and_other_info_for_each_size.append(located_patterns_info)

        # Advance the offset past the largest cluster ID seen in this batch.
        if current_max_cluster_id_in_batch > -1:
            global_cluster_id_offset = current_max_cluster_id_in_batch + 1
        elif non_noise_clusters_mask_labeled.any():  # If intersected was empty but labeled had clusters
            max_labeled_cluster = temp_cluster_labled_windows_df.loc[non_noise_clusters_mask_labeled, 'Cluster'].max()
            global_cluster_id_offset = max_labeled_cluster + 1

    if not located_patterns_and_other_info_for_each_size:
        return pd.DataFrame()

    all_located_patterns_df = pd.concat(located_patterns_and_other_info_for_each_size, ignore_index=True)
    if all_located_patterns_df.empty:
        return pd.DataFrame()

    # Filter overlapping patterns (logic remains similar to previous version)
    unique_chart_patterns = all_located_patterns_df['Chart Pattern'].unique()
    # Sort window sizes descending to prioritize larger windows
    sorted_unique_window_sizes = np.sort(all_located_patterns_df['Window_Size'].unique())[::-1]

    final_filtered_patterns_list = []
    # Use a copy and mark 'taken' to handle overlaps systematically
    candidate_patterns_df = all_located_patterns_df.copy()
    # Ensure 'taken' column exists, default to False
    if 'taken' not in candidate_patterns_df.columns:
        candidate_patterns_df['taken'] = False
    else:  # if it somehow exists from a previous run (unlikely with .copy()), reset it
        candidate_patterns_df['taken'] = False

    for cp_val in unique_chart_patterns:
        for ws_val in sorted_unique_window_sizes:
            # Select current batch of patterns to consider
            current_batch_indices = candidate_patterns_df[
                (candidate_patterns_df['Chart Pattern'] == cp_val) &
                (candidate_patterns_df['Window_Size'] == ws_val) &
                (~candidate_patterns_df['taken'])
            ].index

            for current_idx in current_batch_indices:
                if candidate_patterns_df.loc[current_idx, 'taken']:  # Already claimed by a higher priority pattern
                    continue

                current_row_data = candidate_patterns_df.loc[current_idx]
                final_filtered_patterns_list.append(current_row_data.drop('taken'))  # Add to final list
                candidate_patterns_df.loc[current_idx, 'taken'] = True  # Mark as taken

                # Now, check for overlaps with other non-taken patterns and invalidate lower-priority ones
                # Lower priority: smaller window, or same window but processed later (which this loop structure handles),
                # or significantly lower probability.
                overlapping_candidates_indices = candidate_patterns_df[
                    (candidate_patterns_df.index != current_idx) &  # Don't compare with itself
                    (candidate_patterns_df['Chart Pattern'] == cp_val) &
                    (~candidate_patterns_df['taken']) &
                    (candidate_patterns_df['Calc_Start'] <= current_row_data['Calc_End']) &
                    (candidate_patterns_df['Calc_End'] >= current_row_data['Calc_Start'])
                ].index

                for ov_idx in overlapping_candidates_indices:
                    ov_row_data = candidate_patterns_df.loc[ov_idx]
                    iou = intersection_over_union(current_row_data['Calc_Start'], current_row_data['Calc_End'],
                                                  ov_row_data['Calc_Start'], ov_row_data['Calc_End'])
                    if iou > 0.6:  # Significant overlap
                        # current_row_data (from larger/earlier window) is preferred by default.
                        # ov_row_data (overlapping candidate) is discarded UNLESS:
                        # it's from a smaller window AND has significantly higher probability.
                        is_ov_preferred = (ov_row_data['Window_Size'] < current_row_data['Window_Size']) and \
                                          ((ov_row_data['Avg_Probability'] - current_row_data['Avg_Probability']) > 0.1)

                        if not is_ov_preferred:
                            candidate_patterns_df.loc[ov_idx, 'taken'] = True
                        # If ov_preferred, current_row_data was already added. The ov_row will be considered
                        # when its (smaller) window size turn comes, if not already taken.
                        # This implies a potential issue: if current_row is added, and a smaller, much better ov_row exists,
                        # current_row should ideally be removed. The current logic adds current_row first.
                        # For a more robust selection, decisions might need to be deferred or involve pairwise ranking.
                        # However, given the descending window size iteration, this greedy choice is often sufficient.
                        # Re-evaluating this complex interaction:
                        # If current_row (larger window) is chosen, and an ov_row (smaller window, much higher prob) exists,
                        # the current logic keeps current_row and marks ov_row as NOT taken, so ov_row can be picked later.
                        # This might lead to both being in the list if their IoU with *other* patterns doesn't disqualify them.
                        # The final drop_duplicates will handle exact overlaps.

    filtered_loc_pat_and_info_df = pd.DataFrame(final_filtered_patterns_list)
    if not filtered_loc_pat_and_info_df.empty:
        # Drop duplicates based on the defining characteristics of a pattern instance
        filtered_loc_pat_and_info_df = filtered_loc_pat_and_info_df.sort_values(
            by=['Chart Pattern', 'Calc_Start', 'Window_Size', 'Avg_Probability'],
            ascending=[True, True, False, False]  # Prioritize larger window, then higher prob for duplicates
        ).drop_duplicates(
            subset=['Chart Pattern', 'Calc_Start', 'Calc_End'],
            keep='first'  # Keep the one that came first after sorting (best according to sort)
        ).sort_values(by='Calc_Start').reset_index(drop=True)

    # Plotting was removed from this module; the flag is retained for API compatibility.
    if enable_plotting and not filtered_loc_pat_and_info_df.empty and cluster_labled_windows_list:
        # Remove plotting code
        pass

    if patterns_to_return and not filtered_loc_pat_and_info_df.empty:
        return filtered_loc_pat_and_info_df[filtered_loc_pat_and_info_df['Chart Pattern'].isin(patterns_to_return)]

    return filtered_loc_pat_and_info_df
|
| 452 |
+
|