GoshawkVortexAI committed on
Commit
cfb37af
·
verified ·
1 Parent(s): 39bdaba

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +345 -0
train.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train.py — Full training pipeline. Run this script to train the model.
3
+
4
+ Usage:
5
+ python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500
6
+ python train.py --use-defaults --bars 300
7
+ python train.py --data-dir ./historical_csv # load pre-saved CSVs
8
+
9
+ Pipeline:
10
+ 1. Fetch OHLCV for all symbols
11
+ 2. Run rule engine to extract features (no lookahead)
12
+ 3. Label each signal bar with forward-looking outcome
13
+ 4. Concatenate all symbols (adds cross-asset diversity)
14
+ 5. Walk-forward validation → choose threshold
15
+ 6. Final model fit on full dataset (with an internal validation split)
16
+ 7. Save model + threshold + feature importances
17
+ """
18
+
19
+ import argparse
20
+ import json
21
+ import logging
22
+ import sys
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+
28
+ sys.path.insert(0, str(Path(__file__).parent))
29
+
30
+ from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT
31
+ from data_fetcher import fetch_multiple
32
+ from regime import detect_regime
33
+ from volume_analysis import analyze_volume
34
+ from scorer import compute_structure_score, score_token
35
+ from veto import apply_veto
36
+ from feature_builder import build_feature_dict, validate_features
37
+ from labeler import label_dataframe, compute_label_stats
38
+ from walk_forward import run_walk_forward, summarize_walk_forward
39
+ from model_backend import ModelBackend
40
+ from ml_config import (
41
+ ML_DIR,
42
+ MODEL_PATH,
43
+ THRESHOLD_PATH,
44
+ FEATURE_IMP_PATH,
45
+ LABEL_PATH,
46
+ LGBM_PARAMS,
47
+ FEATURE_COLUMNS,
48
+ LABEL_FORWARD_BARS,
49
+ THRESHOLD_MIN,
50
+ THRESHOLD_MAX,
51
+ THRESHOLD_STEPS,
52
+ THRESHOLD_OBJECTIVE,
53
+ STOP_MULT,
54
+ )
55
+
56
# Root logging is configured once at import time: INFO and above go to stdout.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Module-wide logger used by every function in this script.
logger = logging.getLogger("train")
62
+
63
+
64
def infer_direction(trend: str, breakout: int) -> int:
    """
    Map a trend label and a breakout flag to a trade direction.

    Returns 1 when the trend is "bullish" or the breakout flag is 1,
    otherwise -1 when the trend is "bearish" or the breakout flag is -1,
    otherwise 0. The long test is applied first, so it wins when the two
    inputs disagree.
    """
    long_bias = trend == "bullish" or breakout == 1
    if long_bias:
        return 1

    short_bias = trend == "bearish" or breakout == -1
    return -1 if short_bias else 0
70
+
71
+
72
def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Run the rule engine over ``df`` bar by bar and label the resulting signals.

    For every bar index ``i`` from 30 onward the engine is re-run on the
    prefix ``df.iloc[: i + 1]``, so the features of bar ``i`` are computed
    without access to later rows. Labels are produced afterwards by
    ``label_dataframe`` with ``forward_bars=LABEL_FORWARD_BARS``.

    Args:
        symbol: Identifier stored in the ``_symbol`` column and used in logs.
        df: Price frame for one symbol; its index values become ``_timestamp``.

    Returns:
        One row per labeled signal bar, holding the feature dict produced by
        ``build_feature_dict`` plus ``_symbol``, ``_bar_idx``, ``_timestamp``,
        ``_is_signal``, ``_direction``, ``_atr`` and an integer ``label``.
        An empty DataFrame is returned when ``df`` has fewer than 60 rows,
        when the full-series engine run fails, or when no bar yields a valid
        feature row.
    """
    if len(df) < 60:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()

    # Full-series pass: supplies the ATR series handed to the labeler. The
    # volume result itself is not needed, but the call is kept so that a
    # failure here still skips the whole symbol.
    try:
        regime_data = detect_regime(df)
        atr_series = regime_data["atr_series"]
        analyze_volume(df, atr_series=atr_series)
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()

    rows = []
    skipped_bars = 0  # bars dropped because the engine or feature builder raised

    for i in range(30, len(df)):
        # Slice up to bar i (inclusive) to simulate running bar by bar
        df_i = df.iloc[: i + 1]

        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception as exc:
            # Best-effort: one bad bar must not abort the symbol, but the
            # failure is recorded instead of disappearing silently.
            skipped_bars += 1
            logger.debug(f"{symbol}: bar {i} skipped, rule engine error: {exc}")
            continue

        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)

        # Only bars that pass the veto with regime confidence above 0.3
        # count as signals and are labeled later.
        is_signal = not vetoed and r_i["regime_confidence"] > 0.3

        # Scores are computed as if not vetoed so every bar gets a full feature set.
        scores = score_token(r_i, v_i, vetoed=False)

        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError) as exc:
            skipped_bars += 1
            logger.debug(f"{symbol}: bar {i} skipped, feature build error: {exc}")
            continue

        if not validate_features(feat):
            continue

        feat["_symbol"] = symbol
        feat["_bar_idx"] = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"] = float(r_i["atr"])
        rows.append(feat)

    if skipped_bars:
        logger.info(f"{symbol}: {skipped_bars} bars skipped due to rule engine errors")

    if not rows:
        return pd.DataFrame()

    result = pd.DataFrame(rows)

    # Label: build full-length signal/direction series aligned with df.index
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full = pd.Series(0, index=df.index)

    for row in rows:
        if row["_is_signal"]:
            idx = df.index[row["_bar_idx"]]
            signal_mask_full[idx] = True
            direction_full[idx] = row["_direction"]

    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_series,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )

    # Merge labels back into result, aligning on the bar timestamp
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    result = result.reset_index().rename(columns={"index": "_timestamp"})

    # Keep only signal bars with valid labels
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)

    logger.info(
        f"{symbol}: {len(result)} labeled signals — "
        f"wr={result['label'].mean():.3f}"
    )
    return result
176
+
177
+
178
def _iter_csv_ohlcv(data_dir: Path):
    """Yield ``(symbol, frame)`` for each ``*.csv`` in ``data_dir``, in sorted path order."""
    for csv_path in sorted(data_dir.glob("*.csv")):
        df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
        # Normalise the index to tz-aware UTC timestamps in ascending order.
        df.index = pd.to_datetime(df.index, utc=True)
        df.sort_index(inplace=True)
        # The file name without extension is used as the symbol.
        yield csv_path.stem, df


def build_dataset(
    symbols: list,
    bars: int = CANDLE_LIMIT,
    data_dir: Path = None,
) -> pd.DataFrame:
    """
    Build the combined labeled feature dataset for all symbols.

    Args:
        symbols: Symbols to fetch; ignored when CSVs are loaded from ``data_dir``.
        bars: Number of bars requested per symbol when fetching.
        data_dir: Optional directory of pre-saved CSV files. When it exists,
            every ``*.csv`` in it is loaded instead of fetching. When it is
            given but missing, a warning is logged and data is fetched.

    Returns:
        All per-symbol frames from ``extract_features_and_labels``
        concatenated, sorted by ``_timestamp`` and re-indexed from 0.

    Raises:
        ValueError: If no symbol produced any labeled rows.
    """
    from_csv = bool(data_dir) and data_dir.exists()

    if from_csv:
        logger.info(f"Loading CSVs from {data_dir}")
        sources = _iter_csv_ohlcv(data_dir)
    else:
        if data_dir:
            # Previously this fallback happened without any notice.
            logger.warning(
                f"Data directory {data_dir} does not exist; fetching data instead"
            )
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        sources = fetch_multiple(symbols, limit=bars, min_bars=60).items()

    all_frames = []
    for sym, df in sources:
        frame = extract_features_and_labels(sym, df)
        if not frame.empty:
            all_frames.append(frame)

    if not all_frames:
        if from_csv:
            hint = f"Check the CSV files in {data_dir}."
        else:
            hint = "Check symbols and API connectivity."
        raise ValueError(f"No labeled data produced. {hint}")

    combined = pd.concat(all_frames, ignore_index=True)
    combined.sort_values("_timestamp", inplace=True)
    combined.reset_index(drop=True, inplace=True)
    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined
215
+
216
+
217
def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """
    Fit the final calibrated model using a head/tail train/validation split.

    The first ``1 - val_frac`` share of rows is used for training and the
    remaining rows are passed to the backend as validation data. When the
    positive rate of the training labels lies strictly between 0.05 and
    0.95, each training sample is weighted by the inverse frequency of its
    class; otherwise no sample weights are used.
    """
    cut = int(len(X) * (1 - val_frac))
    train_X, valid_X = X[:cut], X[cut:]
    train_y, valid_y = y[:cut], y[cut:]

    positive_rate = train_y.mean()
    if 0.05 < positive_rate < 0.95:
        # Inverse-frequency weights: the rarer class gets the larger weight.
        weights = np.where(
            train_y == 1,
            1.0 / positive_rate,
            1.0 / (1 - positive_rate),
        )
    else:
        weights = None

    backend = ModelBackend(params=params, calibrate=True)
    backend.fit(train_X, train_y, valid_X, valid_y, sample_weight=weights)
    logger.info(f"Final model: {backend.n_iter_} boosting rounds, backend={backend.backend_name}")
    return backend
237
+
238
+
239
def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    """
    Write the trained model and its companion files to disk.

    Creates ``ML_DIR`` if needed, then writes, in this order: the model
    (``MODEL_PATH``, via joblib), the threshold with walk-forward metrics
    (``THRESHOLD_PATH``, JSON), the feature importances sorted from highest
    to lowest (``FEATURE_IMP_PATH``, CSV) and the label statistics
    (``LABEL_PATH``, JSON).
    """
    import joblib

    def _dump_json(path, payload):
        # Shared writer for the two JSON artifacts (2-space indentation).
        with open(path, "w") as fh:
            json.dump(payload, fh, indent=2)

    ML_DIR.mkdir(parents=True, exist_ok=True)

    # Model
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved → {MODEL_PATH}")

    # Threshold plus the walk-forward metrics it was derived from
    _dump_json(
        THRESHOLD_PATH,
        {
            "threshold": threshold,
            "objective": THRESHOLD_OBJECTIVE,
            "n_folds_used": summary.get("n_folds", 0),
            "mean_test_expectancy": summary.get("mean_expectancy"),
            "mean_test_sharpe": summary.get("mean_sharpe"),
            "mean_test_precision": summary.get("mean_precision"),
        },
    )
    logger.info(f"Threshold saved → {THRESHOLD_PATH} (value={threshold:.4f})")

    # Feature importances, most important first
    importance_table = pd.DataFrame({
        "feature": FEATURE_COLUMNS,
        "importance": backend.feature_importances_,
    })
    importance_table = importance_table.sort_values("importance", ascending=False)
    importance_table.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved → {FEATURE_IMP_PATH}")

    # Label statistics
    label_stats = compute_label_stats(pd.Series(dataset["label"].values))
    _dump_json(LABEL_PATH, label_stats)
    logger.info(f"Label stats: {label_stats}")
279
+
280
+
281
def main(args):
    """
    Run the training pipeline for the parsed command-line ``args``.

    Steps: pick the symbol list, build the labeled dataset, run walk-forward
    validation and log its summary, take the mean walk-forward threshold as
    the final threshold, log the top feature importances, fit the final
    model and save all artifacts.
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("OKX TRADE FILTER — TRAINING PIPELINE")
    logger.info(banner)

    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    else:
        # Without an explicit list, fall back to the first 20 defaults
        # (safe default for quick runs).
        symbols = args.symbols if args.symbols else DEFAULT_SYMBOLS[:20]

    data_dir = None if not args.data_dir else Path(args.data_dir)
    dataset = build_dataset(symbols, bars=args.bars, data_dir=data_dir)

    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values

    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")

    # Walk-forward validation
    logger.info("Running walk-forward validation...")
    wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS)
    summary = summarize_walk_forward(wf_results)

    logger.info("\n=== WALK-FORWARD SUMMARY ===")
    logger.info(f" Folds: {summary['n_folds']}")
    logger.info(f" Mean threshold: {summary['mean_threshold']:.4f} ± {summary['std_threshold']:.4f}")
    for metric in ("expectancy", "sharpe", "precision"):
        logger.info(f" Mean {metric}: {summary['mean_' + metric]}")

    expectancy = summary.get("mean_expectancy")
    if expectancy is not None and expectancy < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")

    # Final threshold is the mean of the walk-forward optimal thresholds
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")

    # Feature importance report
    imp_arr = np.array(summary["avg_feature_importance"])
    ranked = sorted(
        zip(FEATURE_COLUMNS, imp_arr),
        key=lambda pair: pair[1],
        reverse=True,
    )
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in ranked[:15]:
        peak = imp_arr.max()
        if peak > 0:
            # Bar length is scaled so the largest importance spans 30 cells.
            bar = "█" * int(imp / peak * 30)
        else:
            bar = ""
        logger.info(f" {feat:<28} {imp:>8.2f} {bar}")

    # Final fit and persistence
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)

    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\n✓ Training complete.")
336
+
337
+
338
if __name__ == "__main__":
    # Command-line entry point: declare the flags, parse them, run the pipeline.
    parser = argparse.ArgumentParser(description="Train OKX trade probability filter")
    option_specs = (
        ("--symbols", dict(nargs="+", default=None, help="Symbol list, e.g. BTC-USDT ETH-USDT")),
        ("--use-defaults", dict(action="store_true", help="Use all DEFAULT_SYMBOLS from config")),
        ("--bars", dict(type=int, default=300, help="OHLCV bars to fetch per symbol")),
        ("--data-dir", dict(type=str, default=None, help="Directory of pre-saved CSV files")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    main(args)