# Goshawk_Hedge_Pro / train.py
# GoshawkVortexAI's picture
# Create train.py
# cfb37af verified
"""
train.py β€” Full training pipeline. Run this script to train the model.
Usage:
python train.py --symbols BTC-USDT ETH-USDT SOL-USDT ... --bars 500
python train.py --use-defaults --bars 300
python train.py --data-dir ./historical_csv # load pre-saved CSVs
Pipeline:
1. Fetch OHLCV for all symbols
2. Run rule engine to extract features (no lookahead)
3. Label each signal bar with forward-looking outcome
4. Concatenate all symbols (adds cross-asset diversity)
5. Walk-forward validation β†’ choose threshold
6. Final model fit on full dataset
7. Save model + threshold + feature importances
"""
import argparse
import json
import logging
import sys
from pathlib import Path
import numpy as np
import pandas as pd
sys.path.insert(0, str(Path(__file__).parent))
from config import DEFAULT_SYMBOLS, TIMEFRAME, CANDLE_LIMIT
from data_fetcher import fetch_multiple
from regime import detect_regime
from volume_analysis import analyze_volume
from scorer import compute_structure_score, score_token
from veto import apply_veto
from feature_builder import build_feature_dict, validate_features
from labeler import label_dataframe, compute_label_stats
from walk_forward import run_walk_forward, summarize_walk_forward
from model_backend import ModelBackend
from ml_config import (
ML_DIR,
MODEL_PATH,
THRESHOLD_PATH,
FEATURE_IMP_PATH,
LABEL_PATH,
LGBM_PARAMS,
FEATURE_COLUMNS,
LABEL_FORWARD_BARS,
THRESHOLD_MIN,
THRESHOLD_MAX,
THRESHOLD_STEPS,
THRESHOLD_OBJECTIVE,
STOP_MULT,
)
# Configure root logging once at import time; log to stdout so output
# interleaves cleanly when the script is piped or captured.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    stream=sys.stdout,
)
# Module-level logger used by every function in this script.
logger = logging.getLogger("train")
def infer_direction(trend: str, breakout: int) -> int:
    """Map a trend label and breakout flag to a trade direction.

    Returns +1 for long (bullish trend or upside breakout), -1 for short
    (bearish trend or downside breakout), and 0 when neither side applies.
    The long check runs first, so a bullish trend wins over a -1 breakout.
    """
    long_bias = (trend == "bullish") or (breakout == 1)
    if long_bias:
        return 1
    short_bias = (trend == "bearish") or (breakout == -1)
    return -1 if short_bias else 0
def extract_features_and_labels(
    symbol: str,
    df: pd.DataFrame,
) -> pd.DataFrame:
    """
    Run the full rule engine over a DataFrame, bar by bar (forward-only).
    Returns a DataFrame with feature columns plus bookkeeping columns
    ('_symbol', '_bar_idx', '_timestamp', '_is_signal', '_direction',
    '_atr') and an integer 'label' column.

    Implementation note: we compute regime/volume/scores using the full
    historical series up to each bar — no information from future bars
    is ever used. The label is computed separately using FORWARD bars only.

    Returns an empty DataFrame when the series is too short, the rule
    engine raises, or no bar yields a valid feature row.
    """
    # Need a minimum history for indicators to stabilize.
    if len(df) < 60:
        logger.warning(f"{symbol}: too short ({len(df)} bars), skipping")
        return pd.DataFrame()
    # Compute full-series regime and volume (these use only past data internally).
    # NOTE(review): volume_data is unused below; the call mostly serves as an
    # early sanity check that the engine runs on this series.
    try:
        regime_data = detect_regime(df)
        volume_data = analyze_volume(df, atr_series=regime_data["atr_series"])
    except Exception as e:
        logger.error(f"{symbol}: rule engine error: {e}")
        return pd.DataFrame()
    atr_series = regime_data["atr_series"]
    # Build per-bar feature rows for all bars with valid ATR (skip first ATR_PERIOD)
    rows = []
    n = len(df)
    for i in range(30, n):
        # Slice up to bar i (inclusive) — simulate running bar by bar.
        # NOTE(review): re-running the full rule engine per bar is O(n^2);
        # fine for a few hundred bars per symbol, slow at larger scales.
        df_i = df.iloc[: i + 1]
        try:
            r_i = detect_regime(df_i)
            v_i = analyze_volume(df_i, atr_series=r_i["atr_series"])
        except Exception:
            # Bars where the engine cannot produce indicators are skipped.
            continue
        sc_i = compute_structure_score(r_i)
        direction = infer_direction(r_i["trend"], v_i["breakout"])
        vetoed, _ = apply_veto(r_i, v_i, sc_i, direction=direction)
        # Only label bars that the rule engine would have flagged as signals
        is_signal = not vetoed and r_i["regime_confidence"] > 0.3
        scores = score_token(r_i, v_i, vetoed=False)  # compute scores even if vetoed
        try:
            feat = build_feature_dict(r_i, v_i, scores)
        except (KeyError, ValueError):
            continue
        if not validate_features(feat):
            continue
        # Bookkeeping columns (underscore prefix = not model features).
        feat["_symbol"] = symbol
        feat["_bar_idx"] = i
        feat["_timestamp"] = df.index[i]
        feat["_is_signal"] = int(is_signal)
        feat["_direction"] = direction
        feat["_atr"] = float(r_i["atr"])
        rows.append(feat)
    if not rows:
        return pd.DataFrame()
    result = pd.DataFrame(rows)
    # Label: compute forward outcomes for signal bars. Build full-length
    # mask/direction series aligned to df's index for the labeler.
    signal_mask_full = pd.Series(False, index=df.index)
    direction_full = pd.Series(0, index=df.index)
    atr_full = atr_series
    for row in rows:
        if row["_is_signal"]:
            idx = df.index[row["_bar_idx"]]
            signal_mask_full[idx] = True
            direction_full[idx] = row["_direction"]
    labels = label_dataframe(
        df=df,
        signal_mask=signal_mask_full,
        atr_series=atr_full,
        direction_series=direction_full,
        forward_bars=LABEL_FORWARD_BARS,
    )
    # Merge labels back into result by timestamp.
    # NOTE(review): reindex assumes timestamps are unique; duplicate index
    # values in df would break the alignment — confirm upstream data is deduped.
    result = result.set_index("_timestamp")
    result["label"] = labels.reindex(result.index)
    result = result.reset_index().rename(columns={"index": "_timestamp"})
    # Keep only signal bars with valid (non-NaN) labels.
    result = result[result["_is_signal"] == 1].copy()
    result = result.dropna(subset=["label"])
    result["label"] = result["label"].astype(int)
    logger.info(
        f"{symbol}: {len(result)} labeled signals β€” "
        f"wr={result['label'].mean():.3f}"
    )
    return result
def build_dataset(
    symbols: list,
    bars: int = CANDLE_LIMIT,
    data_dir: Path = None,
) -> pd.DataFrame:
    """Fetch (or load) OHLCV data and build the full labeled feature dataset.

    Args:
        symbols: Instrument identifiers, e.g. ["BTC-USDT", "ETH-USDT"].
        bars: Bars to fetch per symbol (ignored when loading from data_dir).
        data_dir: Optional directory of pre-saved CSVs, one per symbol,
            with the filename stem used as the symbol name. Takes
            precedence over live fetching when it exists.

    Returns:
        Concatenated per-symbol labeled frames, sorted by '_timestamp'.

    Raises:
        ValueError: if no symbol produced any labeled rows.
    """
    all_frames = []
    # Fix: previously a nonexistent data_dir fell through *silently* to live
    # fetching; warn so a typo'd --data-dir doesn't surprise the user with
    # unexpected network calls.
    if data_dir is not None and not data_dir.exists():
        logger.warning(f"Data dir {data_dir} does not exist; falling back to live fetch")
    if data_dir and data_dir.exists():
        logger.info(f"Loading CSVs from {data_dir}")
        for csv_path in sorted(data_dir.glob("*.csv")):
            sym = csv_path.stem
            df = pd.read_csv(csv_path, index_col=0, parse_dates=True)
            # Normalize to tz-aware UTC so all symbols share one time axis.
            df.index = pd.to_datetime(df.index, utc=True)
            df.sort_index(inplace=True)
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    else:
        logger.info(f"Fetching OHLCV for {len(symbols)} symbols ({bars} bars each)")
        ohlcv_map = fetch_multiple(symbols, limit=bars, min_bars=60)
        for sym, df in ohlcv_map.items():
            frame = extract_features_and_labels(sym, df)
            if not frame.empty:
                all_frames.append(frame)
    if not all_frames:
        raise ValueError("No labeled data produced. Check symbols and API connectivity.")
    combined = pd.concat(all_frames, ignore_index=True)
    # Chronological order across symbols — required for time-based splits
    # downstream in walk-forward validation.
    combined.sort_values("_timestamp", inplace=True)
    combined.reset_index(drop=True, inplace=True)
    logger.info(
        f"Dataset: {len(combined)} samples across {len(all_frames)} symbols | "
        f"overall wr={combined['label'].mean():.3f}"
    )
    return combined
def fit_final_model(
    X: np.ndarray,
    y: np.ndarray,
    params: dict,
    val_frac: float = 0.15,
) -> ModelBackend:
    """Fit the final model on the full dataset.

    The chronological tail (last val_frac of rows) is held out as an
    internal validation set for the backend's fit routine; the rest is
    used for training, with optional inverse-frequency class weights.
    """
    cut = int(len(X) * (1 - val_frac))
    X_train, X_valid = X[:cut], X[cut:]
    y_train, y_valid = y[:cut], y[cut:]
    # Re-weight classes only when both are reasonably represented;
    # under extreme imbalance weighting stays disabled.
    weights = None
    pos_rate = y_train.mean()
    if 0.05 < pos_rate < 0.95:
        weights = np.where(y_train == 1, 1.0 / pos_rate, 1.0 / (1 - pos_rate))
    model = ModelBackend(params=params, calibrate=True)
    model.fit(X_train, y_train, X_valid, y_valid, sample_weight=weights)
    logger.info(f"Final model: {model.n_iter_} boosting rounds, backend={model.backend_name}")
    return model
def save_artifacts(
    backend: ModelBackend,
    threshold: float,
    summary: dict,
    dataset: pd.DataFrame,
):
    """Persist all training artifacts to the paths configured in ml_config.

    Writes: the fitted model (joblib), the decision threshold + walk-forward
    summary metrics (JSON), per-feature importances (CSV), and label
    distribution stats (JSON). Creates ML_DIR if it does not exist.
    """
    # Local import keeps joblib off the import path of the rest of the module.
    import joblib
    ML_DIR.mkdir(parents=True, exist_ok=True)
    # Save model
    joblib.dump(backend, MODEL_PATH)
    logger.info(f"Model saved β†’ {MODEL_PATH}")
    # Save threshold alongside the walk-forward context that produced it,
    # so inference can sanity-check the numbers it is loading.
    thresh_data = {
        "threshold": threshold,
        "objective": THRESHOLD_OBJECTIVE,
        "n_folds_used": summary.get("n_folds", 0),
        "mean_test_expectancy": summary.get("mean_expectancy"),
        "mean_test_sharpe": summary.get("mean_sharpe"),
        "mean_test_precision": summary.get("mean_precision"),
    }
    with open(THRESHOLD_PATH, "w") as f:
        json.dump(thresh_data, f, indent=2)
    logger.info(f"Threshold saved β†’ {THRESHOLD_PATH} (value={threshold:.4f})")
    # Save feature importances, highest first.
    # NOTE(review): assumes backend.feature_importances_ aligns 1:1 with
    # FEATURE_COLUMNS order — confirm in ModelBackend.
    imp_df = pd.DataFrame({
        "feature": FEATURE_COLUMNS,
        "importance": backend.feature_importances_,
    }).sort_values("importance", ascending=False)
    imp_df.to_csv(FEATURE_IMP_PATH, index=False)
    logger.info(f"Feature importances saved β†’ {FEATURE_IMP_PATH}")
    # Save label stats (class balance etc.) for later drift comparison.
    label_stats = compute_label_stats(pd.Series(dataset["label"].values))
    with open(LABEL_PATH, "w") as f:
        json.dump(label_stats, f, indent=2)
    logger.info(f"Label stats: {label_stats}")
def main(args):
    """Orchestrate the full pipeline: dataset build, walk-forward
    validation, threshold selection, final fit, and artifact saving.

    Args:
        args: argparse.Namespace with symbols, use_defaults, bars, data_dir.
    """
    logger.info("=" * 60)
    logger.info("OKX TRADE FILTER β€” TRAINING PIPELINE")
    logger.info("=" * 60)
    # Symbol selection priority: --use-defaults > --symbols > trimmed default.
    if args.use_defaults:
        symbols = DEFAULT_SYMBOLS
    elif args.symbols:
        symbols = args.symbols
    else:
        symbols = DEFAULT_SYMBOLS[:20]  # safe default for quick runs
    data_dir = Path(args.data_dir) if args.data_dir else None
    dataset = build_dataset(symbols, bars=args.bars, data_dir=data_dir)
    X = dataset[FEATURE_COLUMNS].values.astype(np.float64)
    y = dataset["label"].values.astype(np.int32)
    timestamps = dataset["_timestamp"].values
    logger.info(f"Feature matrix: {X.shape} | Positive rate: {y.mean():.4f}")
    # Walk-forward validation
    logger.info("Running walk-forward validation...")
    wf_results = run_walk_forward(X, y, timestamps=timestamps, params=LGBM_PARAMS)
    summary = summarize_walk_forward(wf_results)
    logger.info("\n=== WALK-FORWARD SUMMARY ===")
    logger.info(f"  Folds: {summary['n_folds']}")
    logger.info(f"  Mean threshold: {summary['mean_threshold']:.4f} Β± {summary['std_threshold']:.4f}")
    logger.info(f"  Mean expectancy: {summary['mean_expectancy']}")
    logger.info(f"  Mean sharpe: {summary['mean_sharpe']}")
    logger.info(f"  Mean precision: {summary['mean_precision']}")
    if summary.get("mean_expectancy") is not None and summary["mean_expectancy"] < 0:
        logger.warning("Negative mean expectancy! Model may not generalize. Check data quality.")
    # Choose final threshold: mean of walk-forward optimal thresholds
    final_threshold = summary["mean_threshold"]
    logger.info(f"\nFinal threshold: {final_threshold:.4f}")
    # Feature importance report (ASCII bar chart scaled to the max importance)
    imp_arr = np.array(summary["avg_feature_importance"])
    imp_pairs = sorted(zip(FEATURE_COLUMNS, imp_arr), key=lambda x: x[1], reverse=True)
    logger.info("\n=== TOP 15 FEATURES BY IMPORTANCE ===")
    for feat, imp in imp_pairs[:15]:
        bar = "β–ˆ" * int(imp / imp_arr.max() * 30) if imp_arr.max() > 0 else ""
        logger.info(f"  {feat:<28} {imp:>8.2f} {bar}")
    # Fit final model on all data
    logger.info("\nFitting final model on full dataset...")
    final_backend = fit_final_model(X, y, LGBM_PARAMS, val_frac=0.15)
    # Save everything
    save_artifacts(final_backend, final_threshold, summary, dataset)
    logger.info("\nβœ“ Training complete.")
if __name__ == "__main__":
    # CLI entry point: define options, parse, and hand off to main().
    cli = argparse.ArgumentParser(description="Train OKX trade probability filter")
    cli.add_argument("--symbols", nargs="+", default=None, help="Symbol list, e.g. BTC-USDT ETH-USDT")
    cli.add_argument("--use-defaults", action="store_true", help="Use all DEFAULT_SYMBOLS from config")
    cli.add_argument("--bars", type=int, default=300, help="OHLCV bars to fetch per symbol")
    cli.add_argument("--data-dir", type=str, default=None, help="Directory of pre-saved CSV files")
    main(cli.parse_args())