GoshawkVortexAI commited on
Commit
1cdd6ba
·
verified ·
1 Parent(s): 9a1172d

Update regime.py

Browse files
Files changed (1) hide show
  1. regime.py +237 -34
regime.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, Any
2
 
3
  import numpy as np
@@ -9,32 +21,74 @@ from config import (
9
  STRUCTURE_CONFIRM_BARS,
10
  VOLATILITY_EXPANSION_MULT,
11
  VOLATILITY_CONTRACTION_MULT,
 
 
 
 
 
 
 
 
 
12
  )
13
 
14
 
15
  def compute_atr(df: pd.DataFrame, period: int = ATR_PERIOD) -> pd.Series:
16
  high, low, prev_close = df["high"], df["low"], df["close"].shift(1)
17
  tr = pd.concat(
18
- [
19
- high - low,
20
- (high - prev_close).abs(),
21
- (low - prev_close).abs(),
22
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  axis=1,
24
  ).max(axis=1)
25
- return tr.ewm(span=period, adjust=False).mean()
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  def compute_structure(df: pd.DataFrame, lookback: int = STRUCTURE_LOOKBACK) -> pd.Series:
29
  roll_high = df["high"].rolling(lookback).max()
30
  roll_low = df["low"].rolling(lookback).min()
31
- prev_roll_high = roll_high.shift(lookback // 2)
32
- prev_roll_low = roll_low.shift(lookback // 2)
 
33
 
34
- hh = roll_high > prev_roll_high
35
- hl = roll_low > prev_roll_low
36
- lh = roll_high < prev_roll_high
37
- ll = roll_low < prev_roll_low
38
 
39
  structure = pd.Series(0, index=df.index)
40
  structure[hh & hl] = 1
@@ -42,44 +96,176 @@ def compute_structure(df: pd.DataFrame, lookback: int = STRUCTURE_LOOKBACK) -> p
42
  return structure
43
 
44
 
45
- def compute_vol_ratio(df: pd.DataFrame, period: int = ATR_PERIOD) -> pd.Series:
46
- atr = compute_atr(df, period)
47
- atr_ma = atr.rolling(period * 2).mean().replace(0, np.nan)
48
- return atr / atr_ma
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
 
50
 
51
- def classify_trend(structure_series: pd.Series, lookback: int = STRUCTURE_CONFIRM_BARS) -> str:
52
- recent = structure_series.iloc[-lookback:]
53
- bullish = (recent == 1).sum()
54
- bearish = (recent == -1).sum()
55
- if bullish > bearish and bullish >= max(1, lookback // 2):
56
  return "bullish"
57
- if bearish > bullish and bearish >= max(1, lookback // 2):
58
  return "bearish"
59
  return "ranging"
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def detect_regime(df: pd.DataFrame) -> Dict[str, Any]:
63
  atr_series = compute_atr(df, ATR_PERIOD)
 
64
  structure_series = compute_structure(df, STRUCTURE_LOOKBACK)
65
- vol_ratio_series = compute_vol_ratio(df, ATR_PERIOD)
 
 
 
 
66
 
67
  last_atr = float(atr_series.iloc[-1])
68
- last_structure = int(structure_series.iloc[-1])
69
- last_vol_ratio = float(vol_ratio_series.iloc[-1]) if not np.isnan(vol_ratio_series.iloc[-1]) else 1.0
70
  last_close = float(df["close"].iloc[-1])
 
 
 
 
 
 
 
71
 
72
- trend = classify_trend(structure_series, STRUCTURE_CONFIRM_BARS)
73
- vol_expanding = last_vol_ratio > VOLATILITY_EXPANSION_MULT
74
- vol_contracting = last_vol_ratio < VOLATILITY_CONTRACTION_MULT
 
 
75
  atr_pct = last_atr / last_close if last_close > 0 else 0.0
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  if trend == "bullish" and not vol_expanding:
78
  regime_score = 1.0
79
  elif trend == "bullish" and vol_expanding:
80
  regime_score = 0.55
81
- elif trend == "ranging" and not vol_expanding and not vol_contracting:
82
- regime_score = 0.35
83
  elif trend == "ranging":
84
  regime_score = 0.25
85
  elif trend == "bearish" and not vol_expanding:
@@ -87,6 +273,11 @@ def detect_regime(df: pd.DataFrame) -> Dict[str, Any]:
87
  else:
88
  regime_score = 0.05
89
 
 
 
 
 
 
90
  if last_structure == 1:
91
  regime_score = min(1.0, regime_score + 0.1)
92
  elif last_structure == -1:
@@ -94,19 +285,31 @@ def detect_regime(df: pd.DataFrame) -> Dict[str, Any]:
94
 
95
  atr_ma_20 = atr_series.rolling(20).mean().iloc[-1]
96
  atr_ma_50 = atr_series.rolling(50).mean().iloc[-1] if len(df) >= 50 else atr_ma_20
97
- atr_trend = "rising" if atr_ma_20 > atr_ma_50 else "falling"
98
 
99
  return {
100
  "atr": last_atr,
101
  "atr_pct": atr_pct,
 
102
  "structure": last_structure,
103
  "trend": trend,
104
- "vol_ratio": last_vol_ratio,
105
  "vol_expanding": vol_expanding,
106
  "vol_contracting": vol_contracting,
107
- "atr_trend": atr_trend,
 
 
 
 
 
 
 
 
108
  "regime_score": round(float(np.clip(regime_score, 0.0, 1.0)), 4),
 
109
  "atr_series": atr_series,
110
  "structure_series": structure_series,
111
- "vol_ratio_series": vol_ratio_series,
 
 
112
  }
 
1
+ """
2
+ regime.py — Market regime detection with ADX, volatility compression,
3
+ distance-from-mean filter, and regime confidence scoring.
4
+
5
+ Key fixes vs prior version:
6
+ - STRUCTURE_LOOKBACK halved (10) to reduce entry lag
7
+ - True ATR (not EWM-only) with percentile-based compression detection
8
+ - ADX for objective trend strength (replaces pure HH/HL heuristic)
9
+ - Regime confidence: composite of trend + structure + vol alignment
10
+ - Distance-from-mean filter to avoid entering extended moves
11
+ """
12
+
13
  from typing import Dict, Any
14
 
15
  import numpy as np
 
21
  STRUCTURE_CONFIRM_BARS,
22
  VOLATILITY_EXPANSION_MULT,
23
  VOLATILITY_CONTRACTION_MULT,
24
+ VOL_COMPRESSION_LOOKBACK,
25
+ VOL_COMPRESSION_PERCENTILE,
26
+ VOL_EXPANSION_CONFIRM_MULT,
27
+ ADX_PERIOD,
28
+ ADX_TREND_THRESHOLD,
29
+ ADX_STRONG_THRESHOLD,
30
+ DIST_FROM_MEAN_MA,
31
+ DIST_FROM_MEAN_ATR_MAX,
32
+ REGIME_CONFIDENCE_MIN,
33
  )
34
 
35
 
36
def compute_atr(df: pd.DataFrame, period: int = ATR_PERIOD) -> pd.Series:
    """Average True Range smoothed with Wilder's RMA.

    True range per bar is the largest of: (high - low),
    |high - previous close|, |low - previous close|.  Smoothing with
    ``alpha = 1 / period`` reproduces Wilder's smoothing, matching the
    TradingView / classic ATR definition.  The first bar's previous close
    is NaN, so its true range falls back to high - low (NaN is skipped
    by the row-wise max).
    """
    prev_close = df["close"].shift(1)
    candidates = pd.DataFrame(
        {
            "hl": df["high"] - df["low"],
            "hc": (df["high"] - prev_close).abs(),
            "lc": (df["low"] - prev_close).abs(),
        }
    )
    true_range = candidates.max(axis=1)
    return true_range.ewm(alpha=1.0 / period, adjust=False).mean()
44
+
45
+
46
def compute_adx(df: pd.DataFrame, period: int = ADX_PERIOD) -> pd.DataFrame:
    """Compute ADX and directional indicators.

    Returns a DataFrame with columns ``adx``, ``di_plus``, ``di_minus``.
    Uses Wilder smoothing (alpha = 1/period) throughout to match the
    standard ADX definition.

    Fix vs prior version: Wilder's directional-movement rule zeroes BOTH
    +DM and -DM when the up-move equals the down-move; the old
    ``dm_plus >= dm_minus`` mask incorrectly credited +DM on ties.
    """
    high, low, close = df["high"], df["low"], df["close"]
    prev_high = high.shift(1)
    prev_low = low.shift(1)
    prev_close = close.shift(1)

    up_move = high - prev_high
    down_move = prev_low - low
    # +DM only when the up-move strictly dominates AND is positive;
    # symmetric for -DM.  Ties and first-bar NaNs both resolve to 0.0.
    dm_plus = up_move.where((up_move > down_move) & (up_move > 0), 0.0)
    dm_minus = down_move.where((down_move > up_move) & (down_move > 0), 0.0)

    tr = pd.concat(
        [high - low, (high - prev_close).abs(), (low - prev_close).abs()],
        axis=1,
    ).max(axis=1)

    alpha = 1.0 / period
    atr_w = tr.ewm(alpha=alpha, adjust=False).mean()
    sdm_plus = dm_plus.ewm(alpha=alpha, adjust=False).mean()
    sdm_minus = dm_minus.ewm(alpha=alpha, adjust=False).mean()

    # Guard divisions: a zero smoothed TR (flat, gapless bars) would
    # otherwise produce inf; NaN propagates harmlessly instead.
    di_plus = 100 * sdm_plus / atr_w.replace(0, np.nan)
    di_minus = 100 * sdm_minus / atr_w.replace(0, np.nan)
    dx = 100 * (di_plus - di_minus).abs() / (di_plus + di_minus).replace(0, np.nan)
    adx = dx.ewm(alpha=alpha, adjust=False).mean()

    return pd.DataFrame({"adx": adx, "di_plus": di_plus, "di_minus": di_minus})
79
 
80
 
81
def compute_structure(df: pd.DataFrame, lookback: int = STRUCTURE_LOOKBACK) -> pd.Series:
    """Classify price structure per bar: 1 bullish, -1 bearish, 0 neutral.

    Compares the rolling high/low of the last ``lookback`` bars against
    the same rolling extremes half a lookback earlier:
    higher-high + higher-low -> 1, lower-high + lower-low -> -1,
    anything mixed (or NaN warm-up) -> 0.
    """
    roll_high = df["high"].rolling(lookback).max()
    roll_low = df["low"].rolling(lookback).min()
    # Guard against lookback == 1, where lookback // 2 would be 0 and
    # shift(0) would compare each value against itself.
    half = max(1, lookback // 2)
    prev_high = roll_high.shift(half)
    prev_low = roll_low.shift(half)

    hh = roll_high > prev_high
    hl = roll_low > prev_low
    lh = roll_high < prev_high
    ll = roll_low < prev_low

    structure = pd.Series(0, index=df.index)
    structure[hh & hl] = 1
    structure[lh & ll] = -1
    return structure
97
 
98
 
99
def compute_volatility_compression(
    atr_series: pd.Series,
    lookback: int = VOL_COMPRESSION_LOOKBACK,
    percentile: float = VOL_COMPRESSION_PERCENTILE,
) -> pd.Series:
    """Flag bars where volatility is compressed (coiled).

    A bar is compressed when its ATR sits below the ``percentile``-th
    percentile of the trailing ``lookback`` ATR values.  NaN warm-up
    bars compare as False.
    """
    floor = atr_series.rolling(lookback).quantile(percentile / 100.0)
    return atr_series.lt(floor)
110
+
111
+
112
def compute_volatility_expanding_from_compression(
    atr_series: pd.Series,
    compressed_series: pd.Series,
    mult: float = VOL_EXPANSION_CONFIRM_MULT,
    lookback: int = 5,
) -> pd.Series:
    """Flag bars where ATR is breaking out of a recent compression.

    True where ATR now exceeds the prior ``lookback``-bar ATR minimum
    times ``mult`` AND the series was compressed at least once within
    the last ``lookback`` bars — the moment volatility leaves a base.

    Fix vs prior version: ``compressed_series`` is boolean, and
    ``shift(1)`` on a boolean Series yields object dtype
    (True/False/NaN), which pandas rolling aggregations reject.
    Casting to float first makes the rolling max well-defined.
    """
    recent_min_atr = atr_series.rolling(lookback).min().shift(1)
    expanding = atr_series > recent_min_atr * mult

    was_compressed = (
        compressed_series.astype(float)
        .shift(1)
        .rolling(lookback)
        .max()
        .fillna(0.0)
        > 0
    )
    return expanding & was_compressed
127
+
128
+
129
def compute_distance_from_mean(
    df: pd.DataFrame,
    atr_series: pd.Series,
    ma_period: int = DIST_FROM_MEAN_MA,
    atr_max: float = DIST_FROM_MEAN_ATR_MAX,
) -> pd.Series:
    """Distance of close from its SMA, expressed in ATR multiples.

    Positive values mean price sits above its mean; magnitudes above the
    caller's extension threshold indicate an over-extended move.  Note:
    ``atr_max`` is accepted for interface compatibility but is not used
    inside this function — callers compare the returned series against
    their own threshold.
    """
    mean_price = df["close"].rolling(ma_period).mean()
    # Zero ATR would divide to inf; NaN propagates instead.
    safe_atr = atr_series.replace(0, np.nan)
    return (df["close"] - mean_price) / safe_atr
142
+
143
+
144
def classify_trend(
    structure_series: pd.Series,
    adx_df: pd.DataFrame,
    lookback: int = STRUCTURE_CONFIRM_BARS,
) -> str:
    """Label the current market as 'bullish', 'bearish', or 'ranging'.

    A direction is declared only when ADX confirms a trend
    (adx >= ADX_TREND_THRESHOLD), the matching DI dominates, AND at
    least half of the last ``lookback`` structure readings agree.
    """

    def _last(col: str) -> float:
        # NaN-safe read of the most recent indicator value.
        value = adx_df[col].iloc[-1]
        return 0.0 if np.isnan(value) else float(value)

    window = structure_series.iloc[-lookback:]
    bull_count = (window == 1).sum()
    bear_count = (window == -1).sum()
    needed = max(1, lookback // 2)

    adx_val = _last("adx")
    di_plus = _last("di_plus")
    di_minus = _last("di_minus")

    if adx_val >= ADX_TREND_THRESHOLD:
        if di_plus > di_minus and bull_count >= needed:
            return "bullish"
        if di_minus > di_plus and bear_count >= needed:
            return "bearish"
    return "ranging"
164
 
165
 
166
def compute_regime_confidence(
    trend: str,
    adx_val: float,
    structure: int,
    vol_expanding_from_base: bool,
    vol_ratio: float,
    dist_atr: float,
) -> float:
    """Composite confidence score in [0, 1].

    Rewards alignment across four independent signals:
    ADX trend strength (max 0.35), structure/trend agreement (max 0.25),
    volatility expanding out of a compression base (max 0.25), and price
    not being extended from its mean (max 0.15).  Low confidence means
    the system should hold off even if other scores look good.
    """
    # ADX strength: strong trend, confirmed trend, or chop.
    if adx_val >= ADX_STRONG_THRESHOLD:
        adx_part = 0.35
    elif adx_val >= ADX_TREND_THRESHOLD:
        adx_part = 0.20
    else:
        adx_part = 0.05

    # Structure must point the same way as the declared trend; a flat
    # structure earns partial credit, a conflicting one earns none.
    aligned = (trend == "bullish" and structure == 1) or (
        trend == "bearish" and structure == -1
    )
    if aligned:
        struct_part = 0.25
    elif structure == 0:
        struct_part = 0.10
    else:
        struct_part = 0.0

    # Volatility: breakout from a base is the strongest signal; a mildly
    # above-average ATR (but below the expansion threshold) earns less.
    if vol_expanding_from_base:
        vol_part = 0.25
    elif 1.0 < vol_ratio < VOLATILITY_EXPANSION_MULT:
        vol_part = 0.10
    else:
        vol_part = 0.0

    # Extension from the mean, in ATRs; NaN (warm-up) treated as zero.
    extension = 0.0 if np.isnan(dist_atr) else abs(dist_atr)
    if extension < 1.0:
        dist_part = 0.15
    elif extension < DIST_FROM_MEAN_ATR_MAX:
        dist_part = 0.07
    else:
        dist_part = 0.0

    total = adx_part + struct_part + vol_part + dist_part
    return float(np.clip(total, 0.0, 1.0))
221
+
222
+
223
  def detect_regime(df: pd.DataFrame) -> Dict[str, Any]:
224
  atr_series = compute_atr(df, ATR_PERIOD)
225
+ adx_df = compute_adx(df, ADX_PERIOD)
226
  structure_series = compute_structure(df, STRUCTURE_LOOKBACK)
227
+ compressed_series = compute_volatility_compression(atr_series)
228
+ expanding_from_base = compute_volatility_expanding_from_compression(
229
+ atr_series, compressed_series
230
+ )
231
+ dist_atr_series = compute_distance_from_mean(df, atr_series)
232
 
233
  last_atr = float(atr_series.iloc[-1])
 
 
234
  last_close = float(df["close"].iloc[-1])
235
+ last_structure = int(structure_series.iloc[-1])
236
+ last_adx = float(adx_df["adx"].iloc[-1]) if not np.isnan(adx_df["adx"].iloc[-1]) else 0.0
237
+ last_di_plus = float(adx_df["di_plus"].iloc[-1]) if not np.isnan(adx_df["di_plus"].iloc[-1]) else 0.0
238
+ last_di_minus = float(adx_df["di_minus"].iloc[-1]) if not np.isnan(adx_df["di_minus"].iloc[-1]) else 0.0
239
+ last_compressed = bool(compressed_series.iloc[-1])
240
+ last_expanding_from_base = bool(expanding_from_base.iloc[-1])
241
+ last_dist_atr = float(dist_atr_series.iloc[-1]) if not np.isnan(dist_atr_series.iloc[-1]) else 0.0
242
 
243
+ atr_ma = atr_series.rolling(ATR_PERIOD * 2).mean()
244
+ last_atr_ma = float(atr_ma.iloc[-1]) if not np.isnan(atr_ma.iloc[-1]) else last_atr
245
+ vol_ratio = last_atr / last_atr_ma if last_atr_ma > 0 else 1.0
246
+ vol_expanding = vol_ratio > VOLATILITY_EXPANSION_MULT
247
+ vol_contracting = vol_ratio < VOLATILITY_CONTRACTION_MULT
248
  atr_pct = last_atr / last_close if last_close > 0 else 0.0
249
 
250
+ trend = classify_trend(structure_series, adx_df, STRUCTURE_CONFIRM_BARS)
251
+
252
+ price_too_extended_long = last_dist_atr > DIST_FROM_MEAN_ATR_MAX
253
+ price_too_extended_short = last_dist_atr < -DIST_FROM_MEAN_ATR_MAX
254
+
255
+ regime_confidence = compute_regime_confidence(
256
+ trend=trend,
257
+ adx_val=last_adx,
258
+ structure=last_structure,
259
+ vol_expanding_from_base=last_expanding_from_base,
260
+ vol_ratio=vol_ratio,
261
+ dist_atr=last_dist_atr,
262
+ )
263
+
264
+ # Regime score: raw directional quality
265
  if trend == "bullish" and not vol_expanding:
266
  regime_score = 1.0
267
  elif trend == "bullish" and vol_expanding:
268
  regime_score = 0.55
 
 
269
  elif trend == "ranging":
270
  regime_score = 0.25
271
  elif trend == "bearish" and not vol_expanding:
 
273
  else:
274
  regime_score = 0.05
275
 
276
+ if last_adx >= ADX_STRONG_THRESHOLD:
277
+ regime_score = min(1.0, regime_score + 0.1)
278
+ elif last_adx < ADX_TREND_THRESHOLD:
279
+ regime_score = max(0.0, regime_score - 0.15)
280
+
281
  if last_structure == 1:
282
  regime_score = min(1.0, regime_score + 0.1)
283
  elif last_structure == -1:
 
285
 
286
  atr_ma_20 = atr_series.rolling(20).mean().iloc[-1]
287
  atr_ma_50 = atr_series.rolling(50).mean().iloc[-1] if len(df) >= 50 else atr_ma_20
288
+ atr_trend_dir = "rising" if atr_ma_20 > atr_ma_50 else "falling"
289
 
290
  return {
291
  "atr": last_atr,
292
  "atr_pct": atr_pct,
293
+ "atr_pct_pct": round(atr_pct * 100, 3),
294
  "structure": last_structure,
295
  "trend": trend,
296
+ "vol_ratio": round(vol_ratio, 3),
297
  "vol_expanding": vol_expanding,
298
  "vol_contracting": vol_contracting,
299
+ "vol_compressed": last_compressed,
300
+ "vol_expanding_from_base": last_expanding_from_base,
301
+ "adx": round(last_adx, 2),
302
+ "di_plus": round(last_di_plus, 2),
303
+ "di_minus": round(last_di_minus, 2),
304
+ "dist_atr": round(last_dist_atr, 3),
305
+ "price_extended_long": price_too_extended_long,
306
+ "price_extended_short": price_too_extended_short,
307
+ "regime_confidence": round(regime_confidence, 4),
308
  "regime_score": round(float(np.clip(regime_score, 0.0, 1.0)), 4),
309
+ "atr_trend": atr_trend_dir,
310
  "atr_series": atr_series,
311
  "structure_series": structure_series,
312
+ "adx_series": adx_df,
313
+ "compressed_series": compressed_series,
314
+ "dist_atr_series": dist_atr_series,
315
  }