File size: 11,950 Bytes
cfb37af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
"""
train.py β€” Full training pipeline. Run this script to train the model.

Usage:
    python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500
    python train.py --use-defaults --bars 300
    python train.py --data-dir ./historical_csv   # load pre-saved CSVs

Pipeline:
    1. Fetch OHLCV for all symbols
    2. Run rule engine to extract features (no lookahead)
    3. Label each signal bar with forward-looking outcome
    4. Concatenate all symbols (adds cross-asset diversity)
    5. Walk-forward validation β†’ choose threshold
    6. Final model fit on full dataset
    7. Save model + threshold + feature importances
"""

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np
import pandas as pd

sys.path.insert(0, str(Path(__file__).parent))

from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT
from data_fetcher import fetch_multiple
from regime import detect_regime
from volume_analysis import analyze_volume
from scorer import compute_structure_score, score_token
from veto import apply_veto
from feature_builder import build_feature_dict, validate_features
from labeler import label_dataframe, compute_label_stats
from walk_forward import run_walk_forward, summarize_walk_forward
from model_backend import ModelBackend
from ml_config import (
    ML_DIR,
    MODEL_PATH,
    THRESHOLD_PATH,
    FEATURE_IMP_PATH,
    LABEL_PATH,
    LGBM_PARAMS,
    FEATURE_COLUMNS,
    LABEL_FORWARD_BARS,
    THRESHOLD_MIN,
    THRESHOLD_MAX,
    THRESHOLD_STEPS,
    THRESHOLD_OBJECTIVE,
    STOP_MULT,
)

# Route logs to stdout (logging's default is stderr) with timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger("train")


def infer_direction(trend: str, breakout: int) -> int:
    """Map a regime trend label and a breakout flag to a trade direction.

    Returns +1 (long) when the trend is bullish or an upside breakout fired,
    -1 (short) when the trend is bearish or a downside breakout fired,
    and 0 when neither side wins. The bullish check runs first, so on a
    conflicting pair (e.g. bullish trend + downside breakout) long wins.
    """
    long_side = trend == "bullish" or breakout == 1
    if long_side:
        return 1
    short_side = trend == "bearish" or breakout == -1
    return -1 if short_side else 0


def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Run the full rule engine over a DataFrame, bar by bar (forward-only).
    Returns a DataFrame with feature columns + 'label' + 'direction' + 'timestamp'.

    Implementation note: we compute regime/volume/scores using the full
    historical series up to each bar β€” no information from future bars
    is ever used. The label is computed separately using FORWARD bars only.
    """
    if len(df) < 60:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()

    # Compute full-series regime and volume (these use only past data internally)
    try:
        regime_data  = detect_regime(df)
        volume_data  = analyze_volume(df, atr_series=regime_data["atr_series"])
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()

    atr_series = regime_data["atr_series"]

    # Build per-bar feature rows for all bars with valid ATR (skip first ATR_PERIOD)
    rows = []
    n = len(df)

    for i in range(30, n):
        # Slice up to bar i (inclusive) β€” simulate running bar by bar
        df_i = df.iloc[: i + 1]

        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception:
            continue

        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)

        # Only label bars that the rule engine would have flagged as signals
        is_signal = not vetoed and r_i["regime_confidence"] > 0.3

        scores = score_token(r_i, v_i, vetoed=False)  # compute scores even if vetoed

        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError):
            continue

        if not validate_features(feat):
            continue

        feat["_symbol"]    = symbol
        feat["_bar_idx"]   = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"]       = float(r_i["atr"])
        rows.append(feat)

    if not rows:
        return pd.DataFrame()

    result = pd.DataFrame(rows)

    # Label: compute forward outcomes for signal bars
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full   = pd.Series(0, index=df.index)
    atr_full         = atr_series

    for row in rows:
        if row["_is_signal"]:
            idx = df.index[row["_bar_idx"]]
            signal_mask_full[idx] = True
            direction_full[idx]   = row["_direction"]

    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_full,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )

    # Merge labels back into result
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    result = result.reset_index().rename(columns={"index": "_timestamp"})

    # Keep only signal bars with valid labels
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)

    logger.info(
        f"{symbol}: {len(result)} labeled signals β€” "
        f"wr={result['label'].mean():.3f}"
    )
    return result


def build_dataset(
    symbols: "list[str]",
    bars: int = CANDLE_LIMIT,
    data_dir: "Path | None" = None,
) -> pd.DataFrame:
    """Fetch data and build full labeled feature dataset.

    Parameters
    ----------
    symbols : symbols to fetch from the API; ignored when `data_dir` exists.
    bars : OHLCV bars to request per symbol (API path only).
    data_dir : if given and existing, load one CSV per symbol from this
        directory instead of fetching (filename stem is taken as the symbol).

    Raises ValueError when no symbol yields any labeled rows.
    """
    all_frames = []

    if data_dir and data_dir.exists():
        logger.info(f"Loading CSVs from {data_dir}")
        for csv_path in sorted(data_dir.glob("*.csv")):
            sym = csv_path.stem
            df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
            # Normalize to a sorted, UTC-aware index so CSV data behaves
            # the same as API-fetched data downstream.
            df.index = pd.to_datetime(df.index, utc=True)
            df.sort_index(inplace=True)
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    else:
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        ohlcv_map = fetch_multiple(symbols, limit=bars, min_bars=60)
        for sym, df in ohlcv_map.items():
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)

    if not all_frames:
        raise ValueError("No labeled data produced. Check symbols and API connectivity.")

    # Global chronological order across symbols — walk-forward splits rely on it.
    combined = pd.concat(all_frames, ignore_index=True)
    combined.sort_values("_timestamp", inplace=True)
    combined.reset_index(drop=True, inplace=True)
    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined


def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """Fit the final model on the full dataset.

    The last `val_frac` fraction of rows (chronological tail, no shuffle)
    is held out as an internal validation set for the backend's early
    stopping / calibration. Inverse-frequency class weights are applied
    only when the positive rate is between 5% and 95%.
    """
    cut = int(len(X) * (1 - val_frac))
    X_train, y_train = X[:cut], y[:cut]
    X_val, y_val = X[cut:], y[cut:]

    # Balance classes via inverse-frequency weights, unless one class is
    # so rare that the weights would explode.
    weights = None
    positive_rate = y_train.mean()
    if 0.05 < positive_rate < 0.95:
        weights = np.where(
            y_train == 1,
            1.0 / positive_rate,
            1.0 / (1 - positive_rate),
        )

    model = ModelBackend(params=params, calibrate=True)
    model.fit(X_train, y_train, X_val, y_val, sample_weight=weights)
    logger.info(f"Final model: {model.n_iter_} boosting rounds, backend={model.backend_name}")
    return model


def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    import joblib

    ML_DIR.mkdir(parents=True, exist_ok=True)

    # Save model
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved β†’ {MODEL_PATH}")

    # Save threshold
    thresh_data = {
        "threshold": threshold,
        "objective": THRESHOLD_OBJECTIVE,
        "n_folds_used": summary.get("n_folds", 0),
        "mean_test_expectancy": summary.get("mean_expectancy"),
        "mean_test_sharpe": summary.get("mean_sharpe"),
        "mean_test_precision": summary.get("mean_precision"),
    }
    with open(THRESHOLD_PATH, "w") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold saved β†’ {THRESHOLD_PATH} (value={threshold:.4f})")

    # Save feature importances
    imp_df = pd.DataFrame({
        "feature": FEATURE_COLUMNS,
        "importance": backend.feature_importances_,
    }).sort_values("importance", ascending=False)
    imp_df.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved β†’ {FEATURE_IMP_PATH}")

    # Save label stats
    label_stats = compute_label_stats(pd.Series(dataset["label"].values))
    with open(LABEL_PATH, "w") as f:
        json.dump(label_stats, f, indent=2)
    logger.info(f"Label stats: {label_stats}")


def main(args):
    logger.info("=" * 60)
    logger.info("OKX TRADE FILTER β€” TRAINING PIPELINE")
    logger.info("=" * 60)

    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    elif args.symbols:
        symbols = args.symbols
    else:
        symbols = DEFAULT_SYMBOLS[:20]  # safe default for quick runs

    data_dir = Path(args.data_dir) if args.data_dir else None
    dataset  = build_dataset(symbols, bars=args.bars, data_dir=data_dir)

    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values

    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")

    # Walk-forward validation
    logger.info("Running walk-forward validation...")
    wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS)
    summary    = summarize_walk_forward(wf_results)

    logger.info("\n=== WALK-FORWARD SUMMARY ===")
    logger.info(f"  Folds:            {summary['n_folds']}")
    logger.info(f"  Mean threshold:   {summary['mean_threshold']:.4f} Β± {summary['std_threshold']:.4f}")
    logger.info(f"  Mean expectancy:  {summary['mean_expectancy']}")
    logger.info(f"  Mean sharpe:      {summary['mean_sharpe']}")
    logger.info(f"  Mean precision:   {summary['mean_precision']}")

    if summary.get("mean_expectancy") is not None and summary["mean_expectancy"] < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")

    # Choose final threshold: mean of walk-forward optimal thresholds
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")

    # Feature importance report
    imp_arr = np.array(summary["avg_feature_importance"])
    imp_pairs = sorted(zip(FEATURE_COLUMNS, imp_arr), key=lambda x: x[1], reverse=True)
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in imp_pairs[:15]:
        bar = "β–ˆ" * int(imp / imp_arr.max() * 30) if imp_arr.max() > 0 else ""
        logger.info(f"  {feat:<28} {imp:>8.2f}  {bar}")

    # Fit final model on all data
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)

    # Save everything
    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\nβœ“ Training complete.")


# CLI entry point — see module docstring for usage examples.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train OKX trade probability filter")
    parser.add_argument("--symbols", nargs="+", default=None, help="Symbol list, e.g. BTC-USDT ETH-USDT")
    parser.add_argument("--use-defaults", action="store_true", help="Use all DEFAULT_SYMBOLS from config")
    parser.add_argument("--bars", type=int, default=300, help="OHLCV bars to fetch per symbol")
    parser.add_argument("--data-dir", type=str, default=None, help="Directory of pre-saved CSV files")
    args = parser.parse_args()
    main(args)