| | """ |
| | train.py β Full training pipeline. Run this script to train the model. |
| | |
| | Usage: |
| | python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500 |
| | python train.py --use-defaults --bars 300 |
| | python train.py --data-dir ./historical_csv # load pre-saved CSVs |
| | |
| | Pipeline: |
| | 1. Fetch OHLCV for all symbols |
| | 2. Run rule engine to extract features (no lookahead) |
| | 3. Label each signal bar with forward-looking outcome |
| | 4. Concatenate all symbols (adds cross-asset diversity) |
| | 5. Walk-forward validation β choose threshold |
| | 6. Final model fit on full dataset |
| | 7. Save model + threshold + feature importances |
| | """ |

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd

sys.path.insert(0, str(Path(__file__).parent))

from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT
from data_fetcher import fetch_multiple
from regime import detect_regime
from volume_analysis import analyze_volume
from scorer import compute_structure_score, score_token
from veto import apply_veto
from feature_builder import build_feature_dict, validate_features
from labeler import label_dataframe, compute_label_stats
from walk_forward import run_walk_forward, summarize_walk_forward
from model_backend import ModelBackend
from ml_config import (
    ML_DIR,
    MODEL_PATH,
    THRESHOLD_PATH,
    FEATURE_IMP_PATH,
    LABEL_PATH,
    LGBM_PARAMS,
    FEATURE_COLUMNS,
    LABEL_FORWARD_BARS,
    THRESHOLD_MIN,
    THRESHOLD_MAX,
    THRESHOLD_STEPS,
    THRESHOLD_OBJECTIVE,
    STOP_MULT,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger("train")


def infer_direction(trend: str, breakout: int) -> int:
    """Map the rule-engine trend label and breakout flag to a trade direction (1, -1, or 0)."""
    if trend == "bullish" or breakout == 1:
        return 1
    if trend == "bearish" or breakout == -1:
        return -1
    return 0


def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Run the full rule engine over a DataFrame, bar by bar (forward-only).
    Returns a DataFrame with feature columns plus 'label', '_direction', and '_timestamp'.

    Implementation note: regime/volume/scores are computed from the historical
    series up to each bar only - no information from future bars is ever used.
    The label is computed separately, using FORWARD bars only.
    """
    if len(df) < 60:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()

    # Full-history pass: fail fast if the rule engine cannot handle this symbol,
    # and keep the ATR series for the forward-looking labeling step below.
    try:
        regime_data = detect_regime(df)
        volume_data = analyze_volume(df, atr_series=regime_data["atr_series"])
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()

    atr_series = regime_data["atr_series"]

    rows = []
    n = len(df)

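    # For every bar, rerun the rule engine on the truncated frame df.iloc[:i + 1].
    # This is O(n^2) in the number of bars, but it guarantees that the features
    # for bar i can never see data from bars i+1 onward (no lookahead).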
    for i in range(30, n):
        # Only the bars up to and including i are visible at this point.
        df_i = df.iloc[: i + 1]

        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception:
            continue

        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)

        # A bar counts as a signal only if it survives the veto and the regime
        # call is made with some confidence.
        is_signal = not vetoed and r_i["regime_confidence"] > 0.3

        # Score with vetoed=False so the features carry the raw score,
        # independent of the veto decision captured above.
        scores = score_token(r_i, v_i, vetoed=False)

        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError):
            continue

        if not validate_features(feat):
            continue

        feat["_symbol"] = symbol
        feat["_bar_idx"] = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"] = float(r_i["atr"])
        rows.append(feat)

    if not rows:
        return pd.DataFrame()

    result = pd.DataFrame(rows)

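    # Re-expand the per-row flags onto the full bar index so the labeler can
    # look forward from each signal bar using the complete OHLCV series.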
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full = pd.Series(0, index=df.index)
    atr_full = atr_series

    for row in rows:
        if row["_is_signal"]:
            idx = df.index[row["_bar_idx"]]
            signal_mask_full[idx] = True
            direction_full[idx] = row["_direction"]

    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_full,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )

    # Attach labels by timestamp.
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    result = result.reset_index()

    # Keep only signal bars that actually received a label.
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)

    logger.info(
        f"{symbol}: {len(result)} labeled signals | "
        f"wr={result['label'].mean():.3f}"
    )
    return result


def build_dataset(
    symbols: list,
    bars: int = CANDLE_LIMIT,
    data_dir: Path = None,
) -> pd.DataFrame:
    """Fetch data and build full labeled feature dataset."""
    all_frames = []

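    # Each CSV is assumed to hold one symbol's OHLCV history, named <SYMBOL>.csv,
    # with the timestamp in the first column (it is re-parsed as UTC below).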
    if data_dir and data_dir.exists():
        logger.info(f"Loading CSVs from {data_dir}")
        for csv_path in sorted(data_dir.glob("*.csv")):
            sym = csv_path.stem
            df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
            df.index = pd.to_datetime(df.index, utc=True)
            df.sort_index(inplace=True)
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    else:
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        ohlcv_map = fetch_multiple(symbols, limit=bars, min_bars=60)
        for sym, df in ohlcv_map.items():
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)

    if not all_frames:
        raise ValueError("No labeled data produced. Check symbols and API connectivity.")

    combined = pd.concat(all_frames, ignore_index=True)
    combined.sort_values("_timestamp", inplace=True)
    combined.reset_index(drop=True, inplace=True)
    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined


def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """Fit final model on full dataset with internal validation split."""
    split = int(len(X) * (1 - val_frac))
    X_tr, y_tr = X[:split], y[:split]
    X_va, y_va = X[split:], y[split:]

    # Inverse-frequency sample weights to balance classes, skipped when one
    # class is so rare (<5%) that reweighting would dominate the fit.
    pos_frac = y_tr.mean()
    sample_weight = None
    if 0.05 < pos_frac < 0.95:
        sample_weight = np.where(y_tr == 1, 1.0 / pos_frac, 1.0 / (1 - pos_frac))

    backend = ModelBackend(params=params, calibrate=True)
    backend.fit(X_tr, y_tr, X_va, y_va, sample_weight=sample_weight)
    logger.info(f"Final model: {backend.n_iter_} boosting rounds, backend={backend.backend_name}")
    return backend


def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    import joblib

    ML_DIR.mkdir(parents=True, exist_ok=True)

    # 1) Fitted model (the whole ModelBackend wrapper, pickled via joblib).
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved -> {MODEL_PATH}")

    # 2) Chosen probability threshold plus the walk-forward metrics behind it.
    thresh_data = {
        "threshold": threshold,
        "objective": THRESHOLD_OBJECTIVE,
        "n_folds_used": summary.get("n_folds", 0),
        "mean_test_expectancy": summary.get("mean_expectancy"),
        "mean_test_sharpe": summary.get("mean_sharpe"),
        "mean_test_precision": summary.get("mean_precision"),
    }
    with open(THRESHOLD_PATH, "w") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold saved -> {THRESHOLD_PATH} (value={threshold:.4f})")

    # 3) Feature importances from the final fitted model.
    imp_df = pd.DataFrame({
        "feature": FEATURE_COLUMNS,
        "importance": backend.feature_importances_,
    }).sort_values("importance", ascending=False)
    imp_df.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved -> {FEATURE_IMP_PATH}")

    # 4) Label distribution stats for the training set.
    label_stats = compute_label_stats(pd.Series(dataset["label"].values))
    with open(LABEL_PATH, "w") as f:
        json.dump(label_stats, f, indent=2)
    logger.info(f"Label stats: {label_stats}")


def main(args):
    logger.info("=" * 60)
    logger.info("OKX TRADE FILTER - TRAINING PIPELINE")
    logger.info("=" * 60)

    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    elif args.symbols:
        symbols = args.symbols
    else:
        symbols = DEFAULT_SYMBOLS[:20]

    data_dir = Path(args.data_dir) if args.data_dir else None
    dataset = build_dataset(symbols, bars=args.bars, data_dir=data_dir)

    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values

    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")

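    # Walk-forward validation: time-ordered folds, fitting on earlier data and
    # testing on later data; per-fold thresholds and metrics are averaged below.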
| | logger.info("Running walk-forward validation...") |
| | wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS) |
| | summary = summarize_walk_forward(wf_results) |
| |
|
| | logger.info("\n=== WALK-FORWARD SUMMARY ===") |
| | logger.info(f" Folds: {summary['n_folds']}") |
| | logger.info(f" Mean threshold: {summary['mean_threshold']:.4f} Β± {summary['std_threshold']:.4f}") |
| | logger.info(f" Mean expectancy: {summary['mean_expectancy']}") |
| | logger.info(f" Mean sharpe: {summary['mean_sharpe']}") |
| | logger.info(f" Mean precision: {summary['mean_precision']}") |

    if summary.get("mean_expectancy") is not None and summary["mean_expectancy"] < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")

    # Deploy the average of the per-fold optimal thresholds.
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")

    # Console view of the strongest features, averaged across folds.
    imp_arr = np.array(summary["avg_feature_importance"])
    imp_pairs = sorted(zip(FEATURE_COLUMNS, imp_arr), key=lambda x: x[1], reverse=True)
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in imp_pairs[:15]:
        bar = "█" * int(imp / imp_arr.max() * 30) if imp_arr.max() > 0 else ""
        logger.info(f"  {feat:<28} {imp:>8.2f} {bar}")

    # Refit on the full dataset for deployment, then persist all artifacts.
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)

    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\n✓ Training complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train OKX trade probability filter")
    parser.add_argument("--symbols", nargs="+", default=None, help="Symbol list, e.g. BTC-USDT ETH-USDT")
    parser.add_argument("--use-defaults", action="store_true", help="Use all DEFAULT_SYMBOLS from config")
    parser.add_argument("--bars", type=int, default=300, help="OHLCV bars to fetch per symbol")
    parser.add_argument("--data-dir", type=str, default=None, help="Directory of pre-saved CSV files")
    args = parser.parse_args()
    main(args)