"""
train.py — Full training pipeline.

Run this script to train the model.

Usage:
    python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500
    python train.py --use-defaults --bars 300
    python train.py --data-dir ./historical_csv   # load pre-saved CSVs

Pipeline:
    1. Fetch OHLCV for all symbols (or load CSVs from --data-dir)
    2. Run rule engine to extract features (no lookahead)
    3. Label each signal bar with forward-looking outcome
    4. Concatenate all symbols (adds cross-asset diversity)
    5. Walk-forward validation → choose threshold
    6. Final model fit (the chronologically last 15% of samples is passed to
       the model backend as its validation set)
    7. Save model + threshold + feature importances + label stats
"""

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Make sibling modules importable when the script is run from another cwd.
sys.path.insert(0, str(Path(__file__).parent))

from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT
from data_fetcher import fetch_multiple
from regime import detect_regime
from volume_analysis import analyze_volume
from scorer import compute_structure_score, score_token
from veto import apply_veto
from feature_builder import build_feature_dict, validate_features
from labeler import label_dataframe, compute_label_stats
from walk_forward import run_walk_forward, summarize_walk_forward
from model_backend import ModelBackend
from ml_config import (
    ML_DIR, MODEL_PATH, THRESHOLD_PATH, FEATURE_IMP_PATH, LABEL_PATH,
    LGBM_PARAMS, FEATURE_COLUMNS, LABEL_FORWARD_BARS,
    THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEPS, THRESHOLD_OBJECTIVE,
    STOP_MULT,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger("train")

# Symbols with fewer bars than this are skipped (also passed to the fetcher).
MIN_BARS = 60
# Leading bars that are never evaluated as candidate signals. Presumably meant
# to cover the ATR / indicator warm-up (ATR_PERIOD) — TODO confirm it is at
# least ATR_PERIOD in the regime module.
WARMUP_BARS = 30
# A non-vetoed bar counts as a signal only above this regime confidence.
MIN_REGIME_CONFIDENCE = 0.3


def infer_direction(trend: str, breakout: int) -> int:
    """
    Map the rule engine's trend label and breakout flag to a direction.

    Returns 1 (bullish), -1 (bearish) or 0 (neither). Bullish evidence is
    checked first, so a bullish trend wins over a bearish breakout.
    """
    if trend == "bullish" or breakout == 1:
        return 1
    if trend == "bearish" or breakout == -1:
        return -1
    return 0


def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
    *,
    min_bars: int = MIN_BARS,
    warmup_bars: int = WARMUP_BARS,
) -> pd.DataFrame:
    """
    Run the rule engine over ``df`` bar by bar and label the signal bars.

    For every bar ``i >= warmup_bars`` the regime / volume / score pipeline is
    re-run on ``df.iloc[: i + 1]``, so features for bar ``i`` only see data up
    to and including bar ``i``. This costs O(n^2) rule-engine work in the
    number of bars, deliberately: it is what rules out lookahead. Labels are
    computed separately by ``label_dataframe`` from forward bars.

    Args:
        symbol: Symbol name; stored in the ``_symbol`` column and used in logs.
        df: OHLCV frame indexed by timestamp, expected to be sorted ascending.
        min_bars: Frames shorter than this are skipped (empty result).
        warmup_bars: Number of leading bars never evaluated.

    Returns:
        One row per labeled signal bar: the dict produced by
        ``build_feature_dict`` plus ``_symbol``, ``_bar_idx`` (position in
        ``df``), ``_timestamp``, ``_is_signal``, ``_direction``, ``_atr`` and an
        int ``label`` column. An empty DataFrame is returned when the frame is
        too short, the rule engine fails on the full series, or no bar yields
        valid features.
    """
    if len(df) < min_bars:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()

    # Run the engine once on the full series: if it cannot process this
    # symbol at all, bail out early. The ATR series is reused for labeling.
    try:
        regime_data = detect_regime(df)
        analyze_volume(df, atr_series=regime_data["atr_series"])
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()

    atr_series = regime_data["atr_series"]

    rows = []
    n_skipped = 0
    for i in range(warmup_bars, len(df)):
        # Slice up to bar i (inclusive) — simulate running bar by bar
        df_i = df.iloc[: i + 1]
        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception:
            # Best effort: skip bars the engine cannot process; the count is
            # reported after the loop instead of being silently dropped.
            n_skipped += 1
            continue

        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)

        # Only label bars that the rule engine would have flagged as signals
        is_signal = (
            not vetoed and r_i["regime_confidence"] > MIN_REGIME_CONFIDENCE
        )

        # Scores are computed as if not vetoed so every bar gets features.
        scores = score_token(r_i, v_i, vetoed=False)
        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError):
            n_skipped += 1
            continue

        if not validate_features(feat):
            n_skipped += 1
            continue

        feat["_symbol"] = symbol
        feat["_bar_idx"] = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"] = float(r_i["atr"])
        rows.append(feat)

    if n_skipped:
        logger.debug(
            f"{symbol}: skipped {n_skipped} bars "
            f"(rule engine error or invalid features)"
        )

    if not rows:
        return pd.DataFrame()

    result = pd.DataFrame(rows)

    # Full-length signal mask / direction series for the labeler. _bar_idx is
    # a position in df, so assign positionally.
    signal_rows = result[result["_is_signal"] == 1]
    bar_pos = signal_rows["_bar_idx"].to_numpy()
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full = pd.Series(0, index=df.index)
    signal_mask_full.iloc[bar_pos] = True
    direction_full.iloc[bar_pos] = signal_rows["_direction"].to_numpy()

    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_series,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )

    # Merge labels back by timestamp. set_index names the index "_timestamp",
    # so reset_index restores the column under the same name.
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    result = result.reset_index()

    # Keep only signal bars with valid labels
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)

    logger.info(
        f"{symbol}: {len(result)} labeled signals — "
        f"wr={result['label'].mean():.3f}"
    )
    return result


def _load_csv_ohlcv(csv_path: Path) -> pd.DataFrame:
    """Load one saved OHLCV CSV, UTC-indexed, sorted, duplicates removed."""
    df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
    df.index = pd.to_datetime(df.index, utc=True)
    # Stable sort keeps file order among equal timestamps, so keep="last"
    # below retains the most recently written row for a duplicated bar.
    df = df.sort_index(kind="mergesort")
    dup_mask = df.index.duplicated(keep="last")
    if dup_mask.any():
        # Duplicate timestamps would break label reindexing downstream.
        logger.warning(
            f"{csv_path.stem}: dropping {int(dup_mask.sum())} duplicate timestamps"
        )
        df = df[~dup_mask]
    return df


def build_dataset(
    symbols: list,
    bars: int = CANDLE_LIMIT,
    data_dir: Path = None,
) -> pd.DataFrame:
    """
    Fetch or load OHLCV data and build the combined labeled feature dataset.

    If ``data_dir`` exists, every ``*.csv`` in it is loaded (the file stem is
    used as the symbol) and ``symbols`` / ``bars`` are ignored. Otherwise
    ``bars`` bars are fetched for each entry of ``symbols``.

    Returns:
        All symbols' labeled rows concatenated and stably sorted by
        ``_timestamp``, with a fresh RangeIndex.

    Raises:
        ValueError: if no symbol produced any labeled rows.
    """
    all_frames = []

    if data_dir is not None and not data_dir.exists():
        # A user-supplied directory that is missing should not be ignored
        # silently; fall back to fetching but say so.
        logger.warning(
            f"Data dir {data_dir} does not exist — fetching from API instead"
        )

    if data_dir is not None and data_dir.exists():
        logger.info(f"Loading CSVs from {data_dir}")
        csv_paths = sorted(data_dir.glob("*.csv"))
        if not csv_paths:
            logger.warning(f"No CSV files found in {data_dir}")
        for csv_path in csv_paths:
            frame = extract_features_and_labels(
                csv_path.stem, _load_csv_ohlcv(csv_path)
            )
            if not frame.empty:
                all_frames.append(frame)
    else:
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        ohlcv_map = fetch_multiple(symbols, limit=bars, min_bars=MIN_BARS)
        for sym, df in ohlcv_map.items():
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)

    if not all_frames:
        raise ValueError("No labeled data produced. Check symbols and API connectivity.")

    combined = pd.concat(all_frames, ignore_index=True)
    # Stable sort: rows sharing a timestamp keep concatenation order, so the
    # walk-forward folds are reproducible run to run.
    combined = combined.sort_values("_timestamp", kind="mergesort")
    combined = combined.reset_index(drop=True)

    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined


def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """
    Fit the final model with a chronological train/validation split.

    The first ``1 - val_frac`` of the rows (assumed time-ordered) are used
    for training and the remainder is passed to the backend as validation
    data. When the training positive rate is within (0.05, 0.95), samples
    get inverse-class-frequency weights.

    Raises:
        ValueError: if ``val_frac`` is not in (0, 1) or there are too few
            samples for both splits to be non-empty.
    """
    if not 0.0 < val_frac < 1.0:
        raise ValueError(f"val_frac must be in (0, 1), got {val_frac}")

    split = int(len(X) * (1 - val_frac))
    if split <= 0 or split >= len(X):
        raise ValueError(
            f"Not enough samples ({len(X)}) for a train/validation split "
            f"with val_frac={val_frac}"
        )

    X_tr, y_tr = X[:split], y[:split]
    X_va, y_va = X[split:], y[split:]

    pos_frac = y_tr.mean()
    sample_weight = None
    # Balance classes unless one of them is extremely rare, where the weights
    # would explode.
    if 0.05 < pos_frac < 0.95:
        sample_weight = np.where(y_tr == 1, 1.0 / pos_frac, 1.0 / (1 - pos_frac))

    backend = ModelBackend(params=params, calibrate=True)
    backend.fit(X_tr, y_tr, X_va, y_va, sample_weight=sample_weight)
    logger.info(f"Final model: {backend.n_iter_} boosting rounds, backend={backend.backend_name}")
    return backend


def _json_default(obj):
    """``json.dump`` fallback: convert NumPy scalars/arrays to Python types."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    """
    Persist the model, threshold metadata, feature importances and label stats.

    Writes to MODEL_PATH (joblib), THRESHOLD_PATH (JSON), FEATURE_IMP_PATH
    (CSV) and LABEL_PATH (JSON), creating ML_DIR if needed.
    """
    import joblib

    ML_DIR.mkdir(parents=True, exist_ok=True)

    # Save model
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved → {MODEL_PATH}")

    # Save threshold (summary values may be NumPy scalars → _json_default)
    thresh_data = {
        "threshold": threshold,
        "objective": THRESHOLD_OBJECTIVE,
        "n_folds_used": summary.get("n_folds", 0),
        "mean_test_expectancy": summary.get("mean_expectancy"),
        "mean_test_sharpe": summary.get("mean_sharpe"),
        "mean_test_precision": summary.get("mean_precision"),
    }
    with open(THRESHOLD_PATH, "w") as f:
        json.dump(thresh_data, f, indent=2, default=_json_default)
    logger.info(f"Threshold saved → {THRESHOLD_PATH} (value={threshold:.4f})")

    # Save feature importances
    imp_df = pd.DataFrame({
        "feature": FEATURE_COLUMNS,
        "importance": backend.feature_importances_,
    }).sort_values("importance", ascending=False)
    imp_df.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved → {FEATURE_IMP_PATH}")

    # Save label stats
    label_stats = compute_label_stats(pd.Series(dataset["label"].values))
    with open(LABEL_PATH, "w") as f:
        json.dump(label_stats, f, indent=2, default=_json_default)
    logger.info(f"Label stats: {label_stats}")


def main(args):
    """Run the full pipeline: dataset → walk-forward → final fit → save."""
    logger.info("=" * 60)
    logger.info("OKX TRADE FILTER — TRAINING PIPELINE")
    logger.info("=" * 60)

    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    elif args.symbols:
        symbols = args.symbols
    else:
        symbols = DEFAULT_SYMBOLS[:20]  # safe default for quick runs

    data_dir = Path(args.data_dir) if args.data_dir else None
    dataset = build_dataset(symbols, bars=args.bars, data_dir=data_dir)

    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values

    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")

    # Walk-forward validation
    logger.info("Running walk-forward validation...")
    wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS)
    summary = summarize_walk_forward(wf_results)

    logger.info("\n=== WALK-FORWARD SUMMARY ===")
    logger.info(f"  Folds:           {summary['n_folds']}")
    logger.info(f"  Mean threshold:  {summary['mean_threshold']:.4f} ± {summary['std_threshold']:.4f}")
    logger.info(f"  Mean expectancy: {summary['mean_expectancy']}")
    logger.info(f"  Mean sharpe:     {summary['mean_sharpe']}")
    logger.info(f"  Mean precision:  {summary['mean_precision']}")

    if summary.get("mean_expectancy") is not None and summary["mean_expectancy"] < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")

    # Choose final threshold: mean of walk-forward optimal thresholds
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")

    # Feature importance report (bars scaled to the largest importance)
    imp_arr = np.asarray(summary["avg_feature_importance"], dtype=np.float64)
    imp_max = float(imp_arr.max()) if imp_arr.size else 0.0
    imp_pairs = sorted(zip(FEATURE_COLUMNS, imp_arr), key=lambda x: x[1], reverse=True)
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in imp_pairs[:15]:
        bar = "█" * int(imp / imp_max * 30) if imp_max > 0 else ""
        logger.info(f"  {feat:<28} {imp:>8.2f}  {bar}")

    # Fit final model on all data
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)

    # Save everything
    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\n✓ Training complete.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train OKX trade probability filter")
    parser.add_argument("--symbols", nargs="+", default=None,
                        help="Symbol list, e.g. BTC-USDT ETH-USDT")
    parser.add_argument("--use-defaults", action="store_true",
                        help="Use all DEFAULT_SYMBOLS from config")
    parser.add_argument("--bars", type=int, default=300,
                        help="OHLCV bars to fetch per symbol")
    parser.add_argument("--data-dir", type=str, default=None,
                        help="Directory of pre-saved CSV files")
    args = parser.parse_args()
    main(args)