GoshawkVortexAI committed on
Commit
e365f22
·
verified ·
1 Parent(s): 0a82a80

Create threshold_optimizer.py

Browse files
Files changed (1) hide show
  1. threshold_optimizer.py +196 -0
threshold_optimizer.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ threshold_optimizer.py — Post-training threshold calibration tool.
3
+
4
+ Run this standalone to re-optimize the probability threshold on new data
5
+ WITHOUT retraining the model. Useful for:
6
+ - Adapting to regime changes without full retraining
7
+ - Testing different optimization objectives
8
+ - Out-of-sample threshold validation
9
+
10
+ The threshold search maximizes expectancy or Sharpe over a held-out dataset.
11
+
12
+ Usage:
13
+ python threshold_optimizer.py --symbols BTC-USDT ETH-USDT --bars 200
14
+ python threshold_optimizer.py --objective sharpe
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import logging
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import pandas as pd
25
+ import matplotlib
26
+ matplotlib.use("Agg") # non-interactive backend
27
+ import matplotlib.pyplot as plt
28
+
29
+ sys.path.insert(0, str(Path(__file__).parent))
30
+
31
+ from ml_config import (
32
+ THRESHOLD_PATH,
33
+ THRESHOLD_MIN,
34
+ THRESHOLD_MAX,
35
+ THRESHOLD_STEPS,
36
+ THRESHOLD_OBJECTIVE,
37
+ TARGET_RR,
38
+ ROUND_TRIP_COST,
39
+ FEATURE_COLUMNS,
40
+ ML_DIR,
41
+ )
42
+ from ml_filter import TradeFilter
43
+ from feature_builder import build_feature_dict, validate_features
44
+ from train import build_dataset
45
+
46
+ logger = logging.getLogger(__name__)
47
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
48
+
49
+
50
def compute_threshold_curve(
    probs: np.ndarray,
    y_true: np.ndarray,
    rr: float = TARGET_RR,
    cost: float = ROUND_TRIP_COST,
    min_samples: int = 5,
) -> pd.DataFrame:
    """
    Sweep the threshold grid and compute trade metrics at each threshold.

    The grid is ``np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS)``.
    For each threshold ``t`` the samples with ``probs >= t`` are treated as
    taken trades. Each trade is scored as ``+rr`` when its label equals 1 and
    ``-1`` otherwise, minus ``cost``.

    Args:
        probs: Predicted probabilities, one per sample. Any array-like is
            accepted and flattened to 1-D.
        y_true: Labels aligned with ``probs`` (1 counts as a win).
        rr: Reward earned on a winning trade, in units of the loss on a
            losing trade.
        cost: Cost subtracted from every trade's outcome.
        min_samples: Minimum number of selected trades required before
            metrics are computed; below it the metric columns are NaN.

    Returns:
        DataFrame with one row per threshold and the columns ``threshold``,
        ``n_trades``, ``win_rate``, ``expectancy``, ``sharpe``, ``precision``
        and ``coverage`` (all floats rounded to 4 decimals).

    Raises:
        ValueError: If ``probs`` and ``y_true`` differ in length.
    """
    probs = np.asarray(probs, dtype=np.float64).ravel()
    y_true = np.asarray(y_true).ravel()
    if probs.shape != y_true.shape:
        raise ValueError(
            "probs and y_true must have the same length, "
            f"got {probs.size} and {y_true.size}"
        )

    total = y_true.size
    # Never compute statistics on an empty selection, whatever min_samples is.
    required = max(int(min_samples), 1)
    annualization = np.sqrt(252)
    records = []

    for t in np.linspace(THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS):
        mask = probs >= t
        n = int(mask.sum())

        # Every row gets a rounded threshold and its true coverage, so sparse
        # rows are directly comparable with fully populated ones.
        row = {
            "threshold": round(float(t), 4),
            "n_trades": n,
            "win_rate": np.nan,
            "expectancy": np.nan,
            "sharpe": np.nan,
            "precision": np.nan,
            "coverage": round(n / total, 4) if total else 0.0,
        }

        if n >= required:
            y_f = y_true[mask]
            wr = float(y_f.mean())
            pnl = np.where(y_f == 1, rr, -1.0) - cost
            pnl_std = pnl.std()
            # A (near-)constant P&L series has no meaningful Sharpe; report 0.
            sh = float(pnl.mean() / pnl_std * annualization) if pnl_std > 1e-9 else 0.0
            row.update(
                win_rate=round(wr, 4),
                expectancy=round(wr * rr - (1 - wr) * 1.0 - cost, 4),
                sharpe=round(sh, 4),
                # Precision of the "take trade" decision equals the win rate.
                precision=round(wr, 4),
            )

        records.append(row)

    return pd.DataFrame(records)
93
+
94
+
95
def find_optimal_threshold(
    curve: pd.DataFrame,
    objective: str = THRESHOLD_OBJECTIVE,
    min_trades: int = 20,
) -> float:
    """
    Pick the threshold that maximizes ``objective`` on a threshold curve.

    Only rows with at least ``min_trades`` trades and a non-NaN objective
    value are considered. On ties the first such row in ``curve`` wins.

    Args:
        curve: DataFrame with ``threshold`` and ``n_trades`` columns plus a
            column named after ``objective``.
        objective: Name of the metric column to maximize.
        min_trades: Minimum ``n_trades`` a row needs to be eligible.

    Returns:
        The best threshold, or the default 0.55 when no row is eligible.

    Raises:
        KeyError: If a non-empty ``curve`` lacks a required column.
    """
    fallback = 0.55

    # An empty curve (possibly without any columns) has nothing to select from.
    if curve.empty:
        logger.warning("No valid threshold found — using default 0.55")
        return fallback

    missing = [
        col for col in ("threshold", "n_trades", objective)
        if col not in curve.columns
    ]
    if missing:
        raise KeyError(f"threshold curve is missing required column(s): {missing}")

    valid = curve[curve["n_trades"] >= min_trades].dropna(subset=[objective])
    if valid.empty:
        logger.warning("No valid threshold found — using default 0.55")
        return fallback

    best_row = valid.loc[valid[objective].idxmax()]
    return float(best_row["threshold"])
106
+
107
+
108
def plot_threshold_curves(curve: pd.DataFrame, optimal: float, save_path: Path) -> None:
    """
    Render a 2x2 grid of threshold-vs-metric plots and save it to disk.

    The panels show expectancy, Sharpe, win rate and trade count against the
    threshold, each with a dashed vertical line at ``optimal``.

    Args:
        curve: DataFrame with a ``threshold`` column and the four plotted
            metric columns; NaN rows are dropped per metric.
        optimal: Threshold to highlight on every panel.
        save_path: Destination image path passed to ``Figure.savefig``.
    """
    panels = (
        ("expectancy", "Expectancy per Trade"),
        ("sharpe", "Annualized Sharpe"),
        ("win_rate", "Win Rate"),
        ("n_trades", "# Trades"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    try:
        fig.suptitle("Threshold Optimization", fontsize=14, fontweight="bold")

        for ax, (metric, title) in zip(axes.flatten(), panels):
            valid = curve.dropna(subset=[metric])
            ax.plot(valid["threshold"], valid[metric], lw=2, color="#1a6bff")
            ax.axvline(optimal, color="orange", linestyle="--", lw=1.5, label=f"Optimal={optimal:.3f}")
            ax.axhline(0, color="gray", linestyle=":", lw=0.8)
            ax.set_title(title, fontsize=11)
            ax.set_xlabel("Threshold")
            ax.legend(fontsize=9)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_path, dpi=120, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed, so a
        # caller that swallows the error does not accumulate open figures.
        plt.close(fig)

    logger.info(f"Threshold curve plot saved → {save_path}")
129
+
130
+
131
def main(args: argparse.Namespace) -> None:
    """
    Re-optimize the probability threshold and persist the result.

    Steps: load the trained filter, build a dataset for the requested
    symbols, score it, sweep the threshold grid, pick the best threshold for
    ``args.objective``, then write the threshold JSON, the curve CSV and a
    plot into ``ML_DIR``.

    Args:
        args: Parsed CLI arguments providing ``symbols``, ``bars`` and
            ``objective``.

    Exits with status 1 when no trained model is available, the dataset is
    empty, or the selected threshold has no row on the swept curve (in which
    case the existing threshold file is left untouched).
    """
    trade_filter = TradeFilter.load_or_none()
    if trade_filter is None:
        logger.error("No trained model found. Run train.py first.")
        sys.exit(1)

    symbols = args.symbols or ["BTC-USDT", "ETH-USDT", "SOL-USDT", "BNB-USDT"]
    dataset = build_dataset(symbols, bars=args.bars)
    if dataset is None or len(dataset) == 0:
        logger.error("Dataset is empty — cannot optimize threshold.")
        sys.exit(1)

    y = dataset["label"].values.astype(np.int32)

    # One {feature: float} dict per row, built in a single vectorized pass.
    feature_dicts = dataset[FEATURE_COLUMNS].astype(np.float64).to_dict("records")
    # Coerce so the array methods below work even if a plain list comes back.
    probs = np.asarray(trade_filter.predict_batch(feature_dicts), dtype=np.float64)

    logger.info(f"Generated {len(probs)} predictions | mean_prob={probs.mean():.4f}")

    curve = compute_threshold_curve(probs, y)
    optimal = find_optimal_threshold(curve, objective=args.objective)

    # The optimizer may return its fallback default, which need not lie on
    # the swept grid; guard the lookup instead of crashing on .iloc[0].
    matches = curve[curve["threshold"].round(4) == round(optimal, 4)]
    if matches.empty:
        logger.error(
            f"Threshold {optimal:.4f} is not on the evaluated grid — "
            f"leaving {THRESHOLD_PATH} unchanged."
        )
        sys.exit(1)
    best_row = matches.iloc[0]

    logger.info("\n=== THRESHOLD OPTIMIZATION RESULT ===")
    logger.info(f" Objective: {args.objective}")
    logger.info(f" Optimal threshold: {optimal:.4f}")
    logger.info(f" Win rate: {best_row['win_rate']:.4f}")
    logger.info(f" Expectancy: {best_row['expectancy']:.4f}")
    logger.info(f" Sharpe: {best_row['sharpe']:.4f}")
    logger.info(f" # Trades: {int(best_row['n_trades'])}")
    logger.info(f" Coverage: {best_row['coverage']:.2%}")

    # Update threshold file
    ML_DIR.mkdir(parents=True, exist_ok=True)
    thresh_data = {
        "threshold": optimal,
        "objective": args.objective,
        "win_rate_at_threshold": float(best_row["win_rate"]),
        "expectancy_at_threshold": float(best_row["expectancy"]),
        "sharpe_at_threshold": float(best_row["sharpe"]),
        "n_trades_at_threshold": int(best_row["n_trades"]),
    }
    with open(THRESHOLD_PATH, "w", encoding="utf-8") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold updated → {THRESHOLD_PATH}")

    # Save curve CSV
    curve_path = ML_DIR / "threshold_curve.csv"
    curve.to_csv(curve_path, index=False)

    # Plot (best effort: the threshold is already saved, so a rendering
    # failure is reported but does not fail the run)
    plot_path = ML_DIR / "threshold_curve.png"
    try:
        plot_threshold_curves(curve, optimal, plot_path)
    except Exception as e:
        logger.warning(f"Plot failed: {e}")
188
+
189
+
190
if __name__ == "__main__":
    # Command-line entry point: collect the options and hand them to main().
    cli = argparse.ArgumentParser(description="Optimize probability threshold")
    # Symbols to evaluate; when omitted, main() falls back to its own list.
    cli.add_argument("--symbols", nargs="+", default=None)
    # Bar count forwarded to build_dataset().
    cli.add_argument("--bars", type=int, default=200)
    # Metric column to maximize during the threshold search.
    cli.add_argument(
        "--objective",
        choices=["expectancy", "sharpe", "win_rate"],
        default=THRESHOLD_OBJECTIVE,
    )
    main(cli.parse_args())