openra-rl-challenge / scripts /train_bc_qwen.py
github-actions[bot]
Sync Space files from 04ee2ab23ee580fa351550303a8efd99a52df7e2
82d84b1
#!/usr/bin/env python3
"""Train a Qwen behavioral-cloning model on the compact macro dataset.
This is the macro-policy training path. It consumes the compact
`macro_dataset.jsonl.gz` written directly by `scripts/collect_bot_data.py`.
Typical usage:
python scripts/train_bc_qwen.py \
--data-path data/macro/macro_dataset.jsonl.gz \
--model Qwen/Qwen3-4B \
--output-dir checkpoints/openra-bc-qwen
Recommended workflow:
1. Collect the compact macro dataset with `collect_bot_data.py`
2. Upload the compact dataset to Colab or use it locally
3. Train this macro BC model
"""
from __future__ import annotations
import argparse
import gzip
import json
import random
import sys
from pathlib import Path
from typing import Any
def open_text_reader(path: Path):
if path.suffix == ".gz":
return gzip.open(path, "rt", encoding="utf-8-sig")
return open(path, "r", encoding="utf-8-sig")
def load_macro_rows(
data_path: Path,
max_rows: int | None = None,
max_episodes: int | None = None,
) -> list[dict[str, Any]]:
"""Load macro dataset rows from JSONL / JSONL.GZ."""
if not data_path.exists():
print(f"Dataset not found: {data_path}")
print("Run scripts/collect_bot_data.py first to create macro_dataset.jsonl.gz.")
sys.exit(1)
rows: list[dict[str, Any]] = []
seen_episodes: set[str] = set()
skipped_missing = 0
with open_text_reader(data_path) as f:
for line_no, raw_line in enumerate(f, start=1):
line = raw_line.strip()
if not line:
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
print(f"Invalid JSONL at line {line_no}: {exc}")
sys.exit(1)
episode = str(row.get("episode", "") or "")
if max_episodes is not None and episode and episode not in seen_episodes:
if len(seen_episodes) >= max_episodes:
continue
seen_episodes.add(episode)
elif episode:
seen_episodes.add(episode)
prompt = str(row.get("prompt", "") or "").strip()
completion = str(row.get("completion", "") or "").strip()
if not prompt or not completion:
skipped_missing += 1
continue
record = dict(row)
record["text"] = f"{prompt}\n{completion}"
rows.append(record)
if max_rows is not None and len(rows) >= max_rows:
break
if not rows:
print(f"No usable rows found in {data_path}")
sys.exit(1)
print(f"Loaded {len(rows)} rows from {data_path}")
print(f" Episodes: {len(seen_episodes)}")
if skipped_missing:
print(f" Skipped rows missing prompt/completion: {skipped_missing}")
return rows
def split_by_episode(
rows: list[dict[str, Any]],
val_ratio: float,
seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Split train/eval at the episode level to avoid leakage."""
if val_ratio <= 0:
return rows, []
episode_to_rows: dict[str, list[dict[str, Any]]] = {}
for row in rows:
episode = str(row.get("episode", "") or "")
episode_to_rows.setdefault(episode, []).append(row)
episodes = list(episode_to_rows)
if len(episodes) < 2:
return rows, []
rng = random.Random(seed)
rng.shuffle(episodes)
n_eval = max(1, round(len(episodes) * val_ratio))
if n_eval >= len(episodes):
n_eval = len(episodes) - 1
eval_episodes = set(episodes[:n_eval])
train_rows: list[dict[str, Any]] = []
eval_rows: list[dict[str, Any]] = []
for episode, episode_rows in episode_to_rows.items():
if episode in eval_episodes:
eval_rows.extend(episode_rows)
else:
train_rows.extend(episode_rows)
return train_rows, eval_rows
def describe_rows(rows: list[dict[str, Any]], label: str) -> None:
primary_intents: dict[str, int] = {}
phases: dict[str, int] = {}
results: dict[str, int] = {}
for row in rows:
intent = str(row.get("primary_intent", "") or "unknown")
phase = str(row.get("phase", "") or "unknown")
result = str(row.get("episode_result", "") or "unknown")
primary_intents[intent] = primary_intents.get(intent, 0) + 1
phases[phase] = phases.get(phase, 0) + 1
results[result] = results.get(result, 0) + 1
def top_counts(counts: dict[str, int], limit: int = 8) -> str:
ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))[:limit]
return ", ".join(f"{k}:{v}" for k, v in ordered) if ordered else "none"
print(f"\n{label}: {len(rows)} rows")
print(f" Primary intents: {top_counts(primary_intents)}")
print(f" Phases: {top_counts(phases)}")
print(f" Episode results: {top_counts(results)}")
def build_peft_config(args):
from peft import LoraConfig
target_modules = [module.strip() for module in args.lora_target_modules.split(",") if module.strip()]
return LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=target_modules,
)
def load_model_and_tokenizer(args):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
quantization_config = None
if args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16,
)
model_kwargs: dict[str, Any] = {
"trust_remote_code": True,
}
if quantization_config is not None:
model_kwargs["quantization_config"] = quantization_config
model_kwargs["device_map"] = "auto"
elif torch.cuda.is_available():
model_kwargs["torch_dtype"] = "auto"
model_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
return model, tokenizer
def train(args) -> None:
rows = load_macro_rows(
data_path=args.data_path,
max_rows=args.max_rows,
max_episodes=args.max_episodes,
)
train_rows, eval_rows = split_by_episode(rows, val_ratio=args.val_ratio, seed=args.seed)
if not train_rows:
print("No training rows after split. Exiting.")
sys.exit(1)
describe_rows(train_rows, "Train split")
if eval_rows:
describe_rows(eval_rows, "Eval split")
else:
print("\nEval split: disabled")
print("\nExample prompt:")
print(train_rows[0]["prompt"][:800])
print("\nExample completion:")
print(train_rows[0]["completion"][:400])
if args.prepare_only:
print("\nPrepare-only mode; not starting training.")
return
import torch
from datasets import Dataset
from peft import prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
train_dataset = Dataset.from_list(train_rows)
eval_dataset = Dataset.from_list(eval_rows) if eval_rows else None
print(f"\nLoading model: {args.model}")
model, tokenizer = load_model_and_tokenizer(args)
peft_config = build_peft_config(args) if not args.no_lora else None
if args.load_in_4bit:
model = prepare_model_for_kbit_training(model)
use_bf16 = bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())
use_fp16 = bool(torch.cuda.is_available() and not use_bf16)
training_args = SFTConfig(
output_dir=str(args.output_dir),
num_train_epochs=args.epochs,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.eval_batch_size,
gradient_accumulation_steps=args.grad_accum,
learning_rate=args.lr,
max_seq_length=args.max_seq_length,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
eval_steps=args.eval_steps if eval_dataset is not None else None,
save_total_limit=args.save_total_limit,
warmup_ratio=args.warmup_ratio,
lr_scheduler_type=args.lr_scheduler_type,
dataset_text_field="text",
report_to="none",
bf16=use_bf16,
fp16=use_fp16,
gradient_checkpointing=args.gradient_checkpointing,
evaluation_strategy="steps" if eval_dataset is not None else "no",
)
print("\nStarting macro BC training...")
print(f" Output: {args.output_dir}")
print(f" Train rows: {len(train_dataset)}")
print(f" Eval rows: {len(eval_dataset) if eval_dataset is not None else 0}")
print(f" Epochs: {args.epochs}")
print(f" Batch size: {args.batch_size}")
print(f" Grad accum: {args.grad_accum}")
print(f" Learning rate: {args.lr}")
print(f" Max seq length: {args.max_seq_length}")
print(f" LoRA: {'off' if args.no_lora else 'on'}")
if not args.no_lora:
print(
" LoRA targets: "
f"{args.lora_target_modules} (r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout})"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
peft_config=peft_config,
)
trainer.train()
print(f"\nSaving model to {args.output_dir}")
trainer.save_model(str(args.output_dir))
tokenizer.save_pretrained(str(args.output_dir))
print("Training complete.")
def main() -> None:
parser = argparse.ArgumentParser(
description="Train Qwen on the macro-action dataset for behavior cloning."
)
parser.add_argument(
"--data-path",
type=Path,
default=Path("data/macro/macro_dataset.jsonl.gz"),
help="Path to macro dataset JSONL / JSONL.GZ (default: data/macro/macro_dataset.jsonl.gz)",
)
parser.add_argument(
"--model",
default="Qwen/Qwen3-4B",
help="Base model to fine-tune (default: Qwen/Qwen3-4B)",
)
parser.add_argument(
"--output-dir",
type=Path,
default=Path("checkpoints/openra-bc-qwen"),
help="Output directory for the trained model (default: checkpoints/openra-bc-qwen)",
)
parser.add_argument(
"--max-rows",
type=int,
default=None,
help="Maximum macro rows to load (default: all)",
)
parser.add_argument(
"--max-episodes",
type=int,
default=None,
help="Maximum unique episodes to use (default: all)",
)
parser.add_argument(
"--val-ratio",
type=float,
default=0.1,
help="Episode-level validation split ratio (default: 0.1)",
)
parser.add_argument(
"--seed",
type=int,
default=7,
help="Random seed for episode split (default: 7)",
)
parser.add_argument(
"--epochs",
type=int,
default=3,
help="Number of training epochs (default: 3)",
)
parser.add_argument(
"--batch-size",
type=int,
default=2,
help="Per-device train batch size (default: 2)",
)
parser.add_argument(
"--eval-batch-size",
type=int,
default=2,
help="Per-device eval batch size (default: 2)",
)
parser.add_argument(
"--grad-accum",
type=int,
default=4,
help="Gradient accumulation steps (default: 4)",
)
parser.add_argument(
"--lr",
type=float,
default=2e-5,
help="Learning rate (default: 2e-5)",
)
parser.add_argument(
"--max-seq-length",
type=int,
default=1024,
help="Maximum sequence length (default: 1024)",
)
parser.add_argument(
"--logging-steps",
type=int,
default=10,
help="Log every N steps (default: 10)",
)
parser.add_argument(
"--save-steps",
type=int,
default=500,
help="Save every N steps (default: 500)",
)
parser.add_argument(
"--eval-steps",
type=int,
default=200,
help="Evaluate every N steps when eval split exists (default: 200)",
)
parser.add_argument(
"--save-total-limit",
type=int,
default=2,
help="Maximum number of checkpoints to keep (default: 2)",
)
parser.add_argument(
"--warmup-ratio",
type=float,
default=0.03,
help="Warmup ratio (default: 0.03)",
)
parser.add_argument(
"--lr-scheduler-type",
default="cosine",
help="Learning rate scheduler type (default: cosine)",
)
parser.add_argument(
"--gradient-checkpointing",
action="store_true",
help="Enable gradient checkpointing",
)
parser.add_argument(
"--prepare-only",
action="store_true",
help="Load and split the dataset, then stop before training",
)
parser.add_argument(
"--load-in-4bit",
action="store_true",
help="Load the base model with 4-bit quantization",
)
parser.add_argument(
"--no-lora",
action="store_true",
help="Disable LoRA and fine-tune the base model directly",
)
parser.add_argument(
"--lora-r",
type=int,
default=16,
help="LoRA rank (default: 16)",
)
parser.add_argument(
"--lora-alpha",
type=int,
default=32,
help="LoRA alpha (default: 32)",
)
parser.add_argument(
"--lora-dropout",
type=float,
default=0.05,
help="LoRA dropout (default: 0.05)",
)
parser.add_argument(
"--lora-target-modules",
default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
help="Comma-separated LoRA target modules (default: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj)",
)
args = parser.parse_args()
train(args)
if __name__ == "__main__":
main()