# rl_btc_v4_iql / train.py
# Uploaded by fbzu via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 22d888b (verified).
#!/usr/bin/env python3
"""
Train BTC v4 Offline IQL Trading Agent.
Usage:
python -m rl_btc_v4.train
python -m rl_btc_v4.train --data-path /path/to/parquet --epochs 100 --hidden-dim 256
python -m rl_btc_v4.train --help
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import pandas as pd
from rl_btc_v4.constants import (
DEFAULT_DATA_PATH,
MARKET_FEATURE_COLUMNS,
N_ACTIONS,
DEFAULT_HISTORY_LENGTH,
DEFAULT_EPISODE_SPAN_DAYS,
DEFAULT_EPISODE_STRIDE_DAYS,
TAKER_FEE_RATE,
)
from rl_btc_v4.dataset import build_offline_rl_dataset
from rl_btc_v4.iql_trainer import IQLTrainer, IQLConfig
def _json_default(obj):
    """Convert values the stdlib ``json`` encoder rejects into plain Python.

    Paths become strings.  Anything exposing ``tolist()`` (numpy scalars and
    arrays, torch tensors) is converted through it, so e.g. an ``np.float32``
    metric no longer aborts the report write.  Any other type still raises
    ``TypeError`` so genuinely unexpected objects are not silently mangled.
    """
    if isinstance(obj, Path):
        return str(obj)
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented JSON (numpy/torch aware)."""
    path.write_text(json.dumps(payload, indent=2, default=_json_default))


def _policy_agreement(trainer, states, actions, chunk_size: int = 4096) -> float | None:
    """Return the fraction of *states* where the greedy policy picks *actions*.

    The policy's prediction is ``argmax`` over ``trainer.policy_net`` logits.
    The network is switched to eval mode for the pass (so dropout is off and
    the metric is deterministic) and its previous mode is restored afterwards.
    States are fed in chunks of *chunk_size* rows to bound device memory.

    Returns ``None`` when *states* is empty instead of a NaN mean.

    Raises:
        ValueError: if the number of predictions and reference actions differ
            (e.g. actions stored one-hot instead of as indices).
    """
    import torch

    states = np.asarray(states, dtype=np.float32)
    # Flatten so (N,) and (N, 1) action arrays compare element-wise instead of
    # broadcasting into an (N, N) matrix.
    actions = np.asarray(actions).reshape(-1)
    if len(states) == 0:
        return None

    policy = trainer.policy_net
    was_training = policy.training
    policy.eval()
    predictions = []
    try:
        with torch.no_grad():
            for start in range(0, len(states), chunk_size):
                batch = torch.as_tensor(
                    states[start:start + chunk_size], device=trainer.device
                )
                logits = policy(batch)
                predictions.append(torch.argmax(logits, dim=-1).cpu().numpy())
    finally:
        policy.train(was_training)

    predicted = np.concatenate(predictions).reshape(-1)
    if predicted.shape != actions.shape:
        raise ValueError(
            f"Predicted {predicted.shape[0]} actions but the test set has "
            f"{actions.shape[0]} reference actions"
        )
    return float(np.mean(predicted == actions))


def train(
    *,
    data_path: str | None = None,
    outdir: str = "artifacts_rl_btc_v4_iql",
    history_length: int = DEFAULT_HISTORY_LENGTH,
    episode_span_days: int = DEFAULT_EPISODE_SPAN_DAYS,
    episode_stride_days: int = DEFAULT_EPISODE_STRIDE_DAYS,
    hidden_dim: int = 256,
    num_layers: int = 2,
    dropout: float = 0.0,
    expectile: float = 0.7,
    temperature: float = 3.0,
    gamma: float = 0.99,
    tau: float = 0.005,
    learning_rate: float = 3e-4,
    batch_size: int = 256,
    num_epochs: int = 100,
    weight_decay: float = 1e-4,
    risk_lambda: float = 1.0,
    soft_dd_penalty: float = 0.50,
    behavioral_policy_mode: str = "conservative",
    behavioral_temperature: float = 1.0,
    min_trade_edge: float = 0.005,
    behavioral_epsilon: float = 0.03,
    taker_fee_rate: float = TAKER_FEE_RATE,
    test_fraction: float = 0.2,
    seed: int = 42,
    device: str | None = None,
) -> dict:
    """Run the full offline IQL training pipeline and write its artifacts.

    Steps:
      1. Build train/test offline-RL datasets with ``build_offline_rl_dataset``.
      2. Train an ``IQLTrainer`` on the train split, passing the test split's
         states/rewards as evaluation data.
      3. Save into *outdir*: the model (via ``trainer.save``), ``scaler.npz``
         (state ``mean``/``std`` plus ``reward_mean``/``reward_std``),
         ``dataset_info.json`` and ``train_report.json``.
      4. Measure how often the greedy policy (argmax of the policy logits)
         picks the dataset's behavioral action on the test split.

    Args:
        data_path: Parquet dataset path; ``None``/empty falls back to
            ``DEFAULT_DATA_PATH``.
        outdir: Output directory, created if missing.
        history_length, episode_span_days, episode_stride_days, risk_lambda,
        soft_dd_penalty, behavioral_policy_mode, behavioral_temperature,
        min_trade_edge, behavioral_epsilon, taker_fee_rate, test_fraction:
            Forwarded to ``build_offline_rl_dataset``.
        hidden_dim, num_layers, dropout, expectile, temperature, gamma, tau,
        learning_rate, batch_size, num_epochs, weight_decay:
            Forwarded to ``IQLConfig``.
        seed: Forwarded to both the dataset builder and ``IQLConfig``.
        device: Torch device string; ``None`` selects ``"cuda"`` when
            available, otherwise ``"cpu"``.

    Returns:
        The report dict written to ``train_report.json``. It includes
        ``test_agreement``, which is ``None`` when the test split is empty.
    """
    out_path = Path(outdir)
    out_path.mkdir(parents=True, exist_ok=True)
    data_path = Path(data_path) if data_path else DEFAULT_DATA_PATH

    import torch

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # ---- Dataset ---------------------------------------------------------
    print(f"[v4 IQL] Building offline RL dataset from {data_path}")
    print(f" History length: {history_length}")
    print(f" Episode span: {episode_span_days}d, stride: {episode_stride_days}d")
    print(f" Risk lambda: {risk_lambda}, DD penalty: {soft_dd_penalty}")
    train_dataset, test_dataset = build_offline_rl_dataset(
        data_path=data_path,
        history_length=history_length,
        episode_span_days=episode_span_days,
        episode_stride_days=episode_stride_days,
        risk_lambda=risk_lambda,
        soft_dd_penalty=soft_dd_penalty,
        behavioral_policy_mode=behavioral_policy_mode,
        behavioral_temperature=behavioral_temperature,
        min_trade_edge=min_trade_edge,
        behavioral_epsilon=behavioral_epsilon,
        test_fraction=test_fraction,
        seed=seed,
        taker_fee_rate=taker_fee_rate,
    )
    train_info = train_dataset.to_dict()
    test_info = test_dataset.to_dict()
    print(f"[v4 IQL] Train dataset: {train_info}")
    print(f"[v4 IQL] Test dataset: {test_info}")
    _write_json(out_path / "dataset_info.json", {"train": train_info, "test": test_info})

    # ---- IQL training ----------------------------------------------------
    state_dim = train_dataset.states.shape[1]
    config = IQLConfig(
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        expectile=expectile,
        temperature=temperature,
        gamma=gamma,
        tau=tau,
        learning_rate=learning_rate,
        batch_size=batch_size,
        num_epochs=num_epochs,
        weight_decay=weight_decay,
        device=device,
        seed=seed,
    )
    print(f"\n[v4 IQL] Training IQL on {device}")
    print(f" State dim: {state_dim}, Action dim: {N_ACTIONS}")
    print(f" Hidden: {hidden_dim}, Layers: {num_layers}")
    print(f" Expectile: {expectile}, Temperature: {temperature}")
    print(f" LR: {learning_rate}, Batch: {batch_size}, Epochs: {num_epochs}")
    trainer = IQLTrainer(state_dim=state_dim, action_dim=N_ACTIONS, config=config)

    t_start = time.time()

    def progress_fn(epoch: int, metrics: dict) -> None:
        # Log every ``config.eval_freq`` epochs with wall-clock time since start.
        if epoch % config.eval_freq == 0:
            elapsed = time.time() - t_start
            print(f" [{elapsed:.0f}s] Epoch {epoch}: "
                  f"Q={metrics['q_loss']:.4f} V={metrics['v_loss']:.4f} "
                  f"π={metrics['policy_loss']:.4f} "
                  f"Adv={metrics['advantage']:.4f}")

    result = trainer.train(
        states=train_dataset.states,
        actions=train_dataset.actions,
        rewards=train_dataset.rewards,
        next_states=train_dataset.next_states,
        dones=train_dataset.dones,
        eval_states=test_dataset.states,
        eval_rewards=test_dataset.rewards,
        progress_fn=progress_fn,
    )
    t_elapsed = time.time() - t_start
    print(f"\n[v4 IQL] Training complete in {t_elapsed:.1f}s")
    print(f"[v4 IQL] Final metrics: {result['final_metrics']}")

    # ---- Artifacts -------------------------------------------------------
    trainer.save(out_path)
    np.savez(
        out_path / "scaler.npz",
        mean=train_dataset.mean,
        std=train_dataset.std,
        reward_mean=result["reward_mean"],
        reward_std=result["reward_std"],
    )
    report = {
        "algorithm": "IQL",
        "config": dict(vars(config)),
        "dataset": {
            "path": str(data_path),
            "history_length": history_length,
            "episode_span_days": episode_span_days,
            "episode_stride_days": episode_stride_days,
            "risk_lambda": risk_lambda,
            "soft_dd_penalty": soft_dd_penalty,
            "behavioral_policy_mode": behavioral_policy_mode,
            "behavioral_temperature": behavioral_temperature,
            "min_trade_edge": min_trade_edge,
            "behavioral_epsilon": behavioral_epsilon,
            "taker_fee_rate": taker_fee_rate,
            # Needed to reproduce the train/test split.
            "test_fraction": test_fraction,
        },
        "results": result,
        "training_time_seconds": t_elapsed,
        "device": device,
    }
    report_path = out_path / "train_report.json"
    # Written once before evaluation so a failure there still leaves a report.
    _write_json(report_path, report)

    # ---- Quick policy evaluation on the test split -----------------------
    agreement = _policy_agreement(trainer, test_dataset.states, test_dataset.actions)
    if agreement is None:
        print("\n[v4 IQL] Test set is empty; skipping policy-behavior agreement.")
    else:
        print(f"\n[v4 IQL] Policy-behavior agreement on test set: {agreement:.4f}")
    report["test_agreement"] = agreement
    _write_json(report_path, report)

    print(f"\n[v4 IQL] Artifacts saved to {out_path}")
    print(" - iql_model.pt")
    print(" - scaler.npz")
    print(" - train_report.json")
    print(" - dataset_info.json")
    return report
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``.

    Each option maps onto the same-named keyword argument of ``train``
    (``--epochs`` maps to ``num_epochs``), and the defaults mirror ``train``'s.
    ``ArgumentDefaultsHelpFormatter`` appends every default to ``--help``.
    Argument groups only affect the help layout; parsed values are unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Train BTC v4 IQL Trading Agent.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    io = parser.add_argument_group("input/output")
    io.add_argument("--data-path", default=None,
                    help="Path to parquet dataset (None uses the package DEFAULT_DATA_PATH)")
    io.add_argument("--outdir", default="artifacts_rl_btc_v4_iql",
                    help="Directory for the model, scaler.npz and JSON reports")

    dataset = parser.add_argument_group("dataset construction")
    dataset.add_argument("--history-length", type=int, default=DEFAULT_HISTORY_LENGTH,
                         help="History length passed to the dataset builder")
    dataset.add_argument("--episode-span-days", type=int, default=DEFAULT_EPISODE_SPAN_DAYS,
                         help="Episode span in days")
    dataset.add_argument("--episode-stride-days", type=int, default=DEFAULT_EPISODE_STRIDE_DAYS,
                         help="Episode stride in days")
    dataset.add_argument("--risk-lambda", type=float, default=1.0,
                         help="Risk penalty coefficient (risk_lambda) for the dataset builder")
    dataset.add_argument("--soft-dd-penalty", type=float, default=0.50,
                         help="Soft drawdown penalty for the dataset builder")
    dataset.add_argument("--behavioral-policy-mode", choices=("conservative", "softmax"),
                         default="conservative",
                         help="Behavioral policy used to label the offline actions")
    dataset.add_argument("--behavioral-temperature", type=float, default=1.0,
                         help="Temperature of the behavioral policy")
    dataset.add_argument("--min-trade-edge", type=float, default=0.005,
                         help="min_trade_edge setting of the behavioral policy")
    dataset.add_argument("--behavioral-epsilon", type=float, default=0.03,
                         help="Epsilon setting of the behavioral policy")
    dataset.add_argument("--taker-fee-rate", type=float, default=TAKER_FEE_RATE,
                         help="Taker fee rate used when building the dataset")
    dataset.add_argument("--test-fraction", type=float, default=0.2,
                         help="Fraction of the data held out as the test split")

    model = parser.add_argument_group("IQL model and optimisation")
    model.add_argument("--hidden-dim", type=int, default=256,
                       help="Hidden width of the IQL networks")
    model.add_argument("--num-layers", type=int, default=2,
                       help="Number of hidden layers in the IQL networks")
    model.add_argument("--dropout", type=float, default=0.0,
                       help="Dropout probability in the IQL networks")
    model.add_argument("--expectile", type=float, default=0.7,
                       help="IQL expectile")
    model.add_argument("--temperature", type=float, default=3.0,
                       help="IQL policy-extraction temperature")
    model.add_argument("--gamma", type=float, default=0.99,
                       help="Discount factor")
    model.add_argument("--tau", type=float, default=0.005,
                       help="Target-network soft-update rate")
    model.add_argument("--learning-rate", type=float, default=3e-4,
                       help="Optimizer learning rate")
    model.add_argument("--batch-size", type=int, default=256,
                       help="Training batch size")
    model.add_argument("--epochs", type=int, default=100,
                       help="Number of training epochs (train's num_epochs)")
    model.add_argument("--weight-decay", type=float, default=1e-4,
                       help="Optimizer weight decay")

    runtime = parser.add_argument_group("runtime")
    runtime.add_argument("--seed", type=int, default=42,
                         help="Random seed for dataset construction and training")
    runtime.add_argument("--device", default=None,
                         help="Torch device such as 'cpu' or 'cuda'; None auto-selects "
                              "cuda when available, else cpu")
    return parser
def main(argv: list[str] | None = None) -> None:
    """CLI entry point: parse *argv*, run ``train`` and print its report as JSON.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    report = train(
        # where to read from / write to, and on what hardware
        data_path=args.data_path,
        outdir=args.outdir,
        device=args.device,
        seed=args.seed,
        # offline dataset construction
        history_length=args.history_length,
        episode_span_days=args.episode_span_days,
        episode_stride_days=args.episode_stride_days,
        risk_lambda=args.risk_lambda,
        soft_dd_penalty=args.soft_dd_penalty,
        behavioral_policy_mode=args.behavioral_policy_mode,
        behavioral_temperature=args.behavioral_temperature,
        min_trade_edge=args.min_trade_edge,
        behavioral_epsilon=args.behavioral_epsilon,
        taker_fee_rate=args.taker_fee_rate,
        test_fraction=args.test_fraction,
        # IQL network and optimiser
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        expectile=args.expectile,
        temperature=args.temperature,
        gamma=args.gamma,
        tau=args.tau,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        num_epochs=args.epochs,  # CLI flag is --epochs
        weight_decay=args.weight_decay,
    )
    print(json.dumps(report, indent=2))


if __name__ == "__main__":
    main()