Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""Train a Qwen behavioral-cloning model on the compact macro dataset.

This is the macro-policy training path. It consumes the compact
`macro_dataset.jsonl.gz` written directly by `scripts/collect_bot_data.py`.

Typical usage:
    python scripts/train_bc_qwen.py \
        --data-path data/macro/macro_dataset.jsonl.gz \
        --model Qwen/Qwen3-4B \
        --output-dir checkpoints/openra-bc-qwen

Recommended workflow:
    1. Collect the compact macro dataset with `collect_bot_data.py`
    2. Upload the compact dataset to Colab or use it locally
    3. Train this macro BC model
"""
from __future__ import annotations

import argparse
import gzip
import inspect
import json
import random
import sys
from pathlib import Path
from typing import Any
def open_text_reader(path: Path):
    """Open ``path`` for reading text, decompressing it when it ends in ``.gz``.

    Both branches decode with ``utf-8-sig`` so a leading byte-order mark is
    dropped. The caller owns the returned handle and must close it (for
    example by using it in a ``with`` statement).
    """
    is_gzip = path.suffix == ".gz"
    if not is_gzip:
        return open(path, "r", encoding="utf-8-sig")
    return gzip.open(path, "rt", encoding="utf-8-sig")
def load_macro_rows(
    data_path: Path,
    max_rows: int | None = None,
    max_episodes: int | None = None,
) -> list[dict[str, Any]]:
    """Load macro dataset rows from JSONL / JSONL.GZ.

    Every usable row is shallow-copied and gains a ``text`` field made of the
    stripped prompt and completion joined by a newline.

    Args:
        data_path: Dataset file; opened through ``open_text_reader``.
        max_rows: Stop once this many usable rows were collected
            (``None`` means no limit).
        max_episodes: Keep rows from at most this many distinct non-empty
            ``episode`` values (``None`` means no limit). Rows without an
            episode value are not subject to this cap.

    Returns:
        The list of usable rows, in file order.

    Exits:
        Calls ``sys.exit(1)`` when the file is missing, a line is not valid
        JSON, or no usable row was found.
    """
    if not data_path.exists():
        print(f"Dataset not found: {data_path}")
        print("Run scripts/collect_bot_data.py first to create macro_dataset.jsonl.gz.")
        sys.exit(1)

    rows: list[dict[str, Any]] = []
    seen_episodes: set[str] = set()
    skipped_missing = 0
    skipped_malformed = 0

    with open_text_reader(data_path) as f:
        for line_no, raw_line in enumerate(f, start=1):
            # Checked before any work so that max_rows=0 really loads nothing.
            if max_rows is not None and len(rows) >= max_rows:
                break

            line = raw_line.strip()
            if not line:
                continue

            try:
                row = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"Invalid JSONL at line {line_no}: {exc}")
                sys.exit(1)

            # Valid JSON that is not an object (list, number, string, ...)
            # has no fields to read; skip it instead of crashing on .get().
            if not isinstance(row, dict):
                skipped_malformed += 1
                continue

            episode = str(row.get("episode", "") or "")
            is_new_episode = bool(episode) and episode not in seen_episodes
            if (
                is_new_episode
                and max_episodes is not None
                and len(seen_episodes) >= max_episodes
            ):
                # Episode budget exhausted; rows of already-seen episodes
                # further down the file are still accepted.
                continue

            prompt = str(row.get("prompt", "") or "").strip()
            completion = str(row.get("completion", "") or "").strip()
            if not prompt or not completion:
                skipped_missing += 1
                continue

            # Register the episode only once it contributes a usable row, so
            # unusable episodes neither consume the budget nor inflate the
            # reported episode count.
            if is_new_episode:
                seen_episodes.add(episode)

            record = dict(row)
            record["text"] = f"{prompt}\n{completion}"
            rows.append(record)

    if not rows:
        print(f"No usable rows found in {data_path}")
        sys.exit(1)

    print(f"Loaded {len(rows)} rows from {data_path}")
    print(f" Episodes: {len(seen_episodes)}")
    if skipped_missing:
        print(f" Skipped rows missing prompt/completion: {skipped_missing}")
    if skipped_malformed:
        print(f" Skipped non-object JSON rows: {skipped_malformed}")
    return rows
def split_by_episode(
    rows: list[dict[str, Any]],
    val_ratio: float,
    seed: int,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Partition rows into train/eval sets without splitting any episode.

    Rows are grouped by their ``episode`` value (missing or falsy values share
    the empty-string group). The group keys are shuffled with a
    ``random.Random(seed)`` generator and the first ones become the eval
    episodes: at least one, and never all of them.

    Returns ``(rows, [])`` unchanged when ``val_ratio`` is not positive or
    when fewer than two episode groups exist.
    """
    if val_ratio <= 0:
        return rows, []

    grouped: dict[str, list[dict[str, Any]]] = {}
    for record in rows:
        key = str(record.get("episode", "") or "")
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(record)

    shuffled = list(grouped)
    total = len(shuffled)
    if total < 2:
        return rows, []

    random.Random(seed).shuffle(shuffled)

    # Clamp into [1, total - 1] so both splits end up non-empty.
    held_out_count = min(max(1, round(total * val_ratio)), total - 1)
    held_out = set(shuffled[:held_out_count])

    # Walk the groups in first-seen order so row order inside each split
    # follows the grouping, not the shuffle.
    train_rows = [
        record
        for key, members in grouped.items()
        if key not in held_out
        for record in members
    ]
    eval_rows = [
        record
        for key, members in grouped.items()
        if key in held_out
        for record in members
    ]
    return train_rows, eval_rows
def describe_rows(rows: list[dict[str, Any]], label: str) -> None:
    """Print a short summary of ``rows`` under the heading ``label``.

    The summary lists the row count followed by the most frequent values of
    the ``primary_intent``, ``phase`` and ``episode_result`` fields. Missing
    or falsy values are counted as ``unknown``.
    """

    def tally(field: str) -> dict[str, int]:
        # Frequency of each stringified value of `field` across all rows.
        counts: dict[str, int] = {}
        for record in rows:
            value = str(record.get(field, "") or "unknown")
            counts[value] = counts.get(value, 0) + 1
        return counts

    def top_counts(counts: dict[str, int], limit: int = 8) -> str:
        # Highest count first; ties are broken alphabetically by value.
        ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
        shown = ranked[:limit]
        if not shown:
            return "none"
        return ", ".join(f"{name}:{count}" for name, count in shown)

    primary_intents = tally("primary_intent")
    phases = tally("phase")
    results = tally("episode_result")

    print(f"\n{label}: {len(rows)} rows")
    print(f" Primary intents: {top_counts(primary_intents)}")
    print(f" Phases: {top_counts(phases)}")
    print(f" Episode results: {top_counts(results)}")
def build_peft_config(args):
    """Create the ``LoraConfig`` described by the parsed CLI arguments.

    ``args.lora_target_modules`` is a comma-separated string; surrounding
    whitespace is trimmed from each entry and empty entries are dropped.
    """
    # Imported lazily so --prepare-only runs do not require peft.
    from peft import LoraConfig

    trimmed = (name.strip() for name in args.lora_target_modules.split(","))
    target_modules = [name for name in trimmed if name]

    lora_settings = {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": target_modules,
    }
    return LoraConfig(**lora_settings)
def load_model_and_tokenizer(args):
    """Load the tokenizer and causal-LM weights for ``args.model``.

    The tokenizer falls back to its EOS token when no pad token is defined
    and is set to pad on the right. With ``args.load_in_4bit`` the model is
    loaded through an NF4 double-quantised ``BitsAndBytesConfig``; otherwise,
    when CUDA is available, it is loaded with ``torch_dtype="auto"``. Both of
    those paths use ``device_map="auto"``; without either, the library
    defaults apply.

    Returns:
        ``(model, tokenizer)``.
    """
    # Heavy dependencies are imported lazily, only when a model is needed.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    model_kwargs: dict[str, Any] = {"trust_remote_code": True}

    if args.load_in_4bit:
        # Prefer bfloat16 compute where the GPU supports it, else float16.
        bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16,
        )
        model_kwargs["device_map"] = "auto"
    elif torch.cuda.is_available():
        model_kwargs.update(torch_dtype="auto", device_map="auto")

    model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs)
    return model, tokenizer
def _pick_supported_kwarg(target, preferred: str, legacy: str) -> str:
    """Return ``preferred`` if ``target`` accepts it as a parameter, else ``legacy``."""
    try:
        parameters = inspect.signature(target).parameters
    except (TypeError, ValueError):
        # No introspectable signature; keep the name the script always used.
        return legacy
    return preferred if preferred in parameters else legacy


def train(args) -> None:
    """Run the full macro BC pipeline: load data, split, and fine-tune.

    Steps:
        1. Load rows with ``load_macro_rows`` and split them by episode.
        2. Print split summaries and one example prompt/completion.
        3. Stop here when ``args.prepare_only`` is set.
        4. Load the model/tokenizer, build the optional LoRA config, and
           train with TRL's ``SFTTrainer`` on the ``text`` field.
        5. Save the model and tokenizer to ``args.output_dir``.

    Args:
        args: Parsed CLI namespace as produced by ``main``.

    Exits:
        Calls ``sys.exit(1)`` when the train split is empty or when
        ``--load-in-4bit`` is combined with ``--no-lora``.
    """
    rows = load_macro_rows(
        data_path=args.data_path,
        max_rows=args.max_rows,
        max_episodes=args.max_episodes,
    )
    train_rows, eval_rows = split_by_episode(rows, val_ratio=args.val_ratio, seed=args.seed)
    if not train_rows:
        print("No training rows after split. Exiting.")
        sys.exit(1)

    describe_rows(train_rows, "Train split")
    if eval_rows:
        describe_rows(eval_rows, "Eval split")
    else:
        print("\nEval split: disabled")

    print("\nExample prompt:")
    print(train_rows[0]["prompt"][:800])
    print("\nExample completion:")
    print(train_rows[0]["completion"][:400])

    if args.prepare_only:
        print("\nPrepare-only mode; not starting training.")
        return

    # A 4-bit quantized base model has no trainable weights of its own, so it
    # needs adapters. Fail before the (slow) model load instead of after it.
    if args.load_in_4bit and args.no_lora:
        print("--load-in-4bit cannot be combined with --no-lora: a quantized model needs LoRA adapters to train.")
        sys.exit(1)

    # Heavy dependencies are imported lazily so --prepare-only needs none of them.
    import torch
    from datasets import Dataset
    from peft import prepare_model_for_kbit_training
    from trl import SFTConfig, SFTTrainer

    train_dataset = Dataset.from_list(train_rows)
    eval_dataset = Dataset.from_list(eval_rows) if eval_rows else None
    has_eval = eval_dataset is not None

    print(f"\nLoading model: {args.model}")
    model, tokenizer = load_model_and_tokenizer(args)
    peft_config = build_peft_config(args) if not args.no_lora else None
    if args.load_in_4bit:
        model = prepare_model_for_kbit_training(model)

    use_bf16 = bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())
    use_fp16 = bool(torch.cuda.is_available() and not use_bf16)

    config_kwargs: dict[str, Any] = {
        "output_dir": str(args.output_dir),
        "num_train_epochs": args.epochs,
        "per_device_train_batch_size": args.batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "gradient_accumulation_steps": args.grad_accum,
        "learning_rate": args.lr,
        "logging_steps": args.logging_steps,
        "save_steps": args.save_steps,
        "eval_steps": args.eval_steps if has_eval else None,
        "save_total_limit": args.save_total_limit,
        "warmup_ratio": args.warmup_ratio,
        "lr_scheduler_type": args.lr_scheduler_type,
        "dataset_text_field": "text",
        "report_to": "none",
        "bf16": use_bf16,
        "fp16": use_fp16,
        "gradient_checkpointing": args.gradient_checkpointing,
    }
    # These keywords were renamed in newer transformers / TRL releases
    # (max_seq_length -> max_length, evaluation_strategy -> eval_strategy).
    # Use whichever spelling the installed version accepts so the script works
    # on both sides of the rename.
    seq_length_key = _pick_supported_kwarg(SFTConfig, "max_length", "max_seq_length")
    eval_strategy_key = _pick_supported_kwarg(SFTConfig, "eval_strategy", "evaluation_strategy")
    config_kwargs[seq_length_key] = args.max_seq_length
    config_kwargs[eval_strategy_key] = "steps" if has_eval else "no"
    training_args = SFTConfig(**config_kwargs)

    print("\nStarting macro BC training...")
    print(f" Output: {args.output_dir}")
    print(f" Train rows: {len(train_dataset)}")
    print(f" Eval rows: {len(eval_dataset) if eval_dataset is not None else 0}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch size: {args.batch_size}")
    print(f" Grad accum: {args.grad_accum}")
    print(f" Learning rate: {args.lr}")
    print(f" Max seq length: {args.max_seq_length}")
    print(f" LoRA: {'off' if args.no_lora else 'on'}")
    if not args.no_lora:
        print(
            " LoRA targets: "
            f"{args.lora_target_modules} (r={args.lora_r}, alpha={args.lora_alpha}, dropout={args.lora_dropout})"
        )

    trainer_kwargs: dict[str, Any] = {
        "model": model,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "args": training_args,
        "peft_config": peft_config,
    }
    # Same rename story for the trainer: tokenizer -> processing_class.
    tokenizer_key = _pick_supported_kwarg(SFTTrainer, "processing_class", "tokenizer")
    trainer_kwargs[tokenizer_key] = tokenizer
    trainer = SFTTrainer(**trainer_kwargs)
    trainer.train()

    print(f"\nSaving model to {args.output_dir}")
    trainer.save_model(str(args.output_dir))
    tokenizer.save_pretrained(str(args.output_dir))
    print("Training complete.")
def main() -> None:
    """Parse the command-line options and hand them to ``train``."""
    parser = argparse.ArgumentParser(
        description="Train Qwen on the macro-action dataset for behavior cloning."
    )

    # One (flag, add_argument keyword arguments) entry per option, registered
    # in this order so the --help listing keeps its layout.
    option_specs: list[tuple[str, dict[str, Any]]] = [
        (
            "--data-path",
            dict(
                type=Path,
                default=Path("data/macro/macro_dataset.jsonl.gz"),
                help="Path to macro dataset JSONL / JSONL.GZ (default: data/macro/macro_dataset.jsonl.gz)",
            ),
        ),
        (
            "--model",
            dict(
                default="Qwen/Qwen3-4B",
                help="Base model to fine-tune (default: Qwen/Qwen3-4B)",
            ),
        ),
        (
            "--output-dir",
            dict(
                type=Path,
                default=Path("checkpoints/openra-bc-qwen"),
                help="Output directory for the trained model (default: checkpoints/openra-bc-qwen)",
            ),
        ),
        (
            "--max-rows",
            dict(type=int, default=None, help="Maximum macro rows to load (default: all)"),
        ),
        (
            "--max-episodes",
            dict(type=int, default=None, help="Maximum unique episodes to use (default: all)"),
        ),
        (
            "--val-ratio",
            dict(type=float, default=0.1, help="Episode-level validation split ratio (default: 0.1)"),
        ),
        (
            "--seed",
            dict(type=int, default=7, help="Random seed for episode split (default: 7)"),
        ),
        (
            "--epochs",
            dict(type=int, default=3, help="Number of training epochs (default: 3)"),
        ),
        (
            "--batch-size",
            dict(type=int, default=2, help="Per-device train batch size (default: 2)"),
        ),
        (
            "--eval-batch-size",
            dict(type=int, default=2, help="Per-device eval batch size (default: 2)"),
        ),
        (
            "--grad-accum",
            dict(type=int, default=4, help="Gradient accumulation steps (default: 4)"),
        ),
        (
            "--lr",
            dict(type=float, default=2e-5, help="Learning rate (default: 2e-5)"),
        ),
        (
            "--max-seq-length",
            dict(type=int, default=1024, help="Maximum sequence length (default: 1024)"),
        ),
        (
            "--logging-steps",
            dict(type=int, default=10, help="Log every N steps (default: 10)"),
        ),
        (
            "--save-steps",
            dict(type=int, default=500, help="Save every N steps (default: 500)"),
        ),
        (
            "--eval-steps",
            dict(
                type=int,
                default=200,
                help="Evaluate every N steps when eval split exists (default: 200)",
            ),
        ),
        (
            "--save-total-limit",
            dict(type=int, default=2, help="Maximum number of checkpoints to keep (default: 2)"),
        ),
        (
            "--warmup-ratio",
            dict(type=float, default=0.03, help="Warmup ratio (default: 0.03)"),
        ),
        (
            "--lr-scheduler-type",
            dict(default="cosine", help="Learning rate scheduler type (default: cosine)"),
        ),
        (
            "--gradient-checkpointing",
            dict(action="store_true", help="Enable gradient checkpointing"),
        ),
        (
            "--prepare-only",
            dict(
                action="store_true",
                help="Load and split the dataset, then stop before training",
            ),
        ),
        (
            "--load-in-4bit",
            dict(action="store_true", help="Load the base model with 4-bit quantization"),
        ),
        (
            "--no-lora",
            dict(
                action="store_true",
                help="Disable LoRA and fine-tune the base model directly",
            ),
        ),
        (
            "--lora-r",
            dict(type=int, default=16, help="LoRA rank (default: 16)"),
        ),
        (
            "--lora-alpha",
            dict(type=int, default=32, help="LoRA alpha (default: 32)"),
        ),
        (
            "--lora-dropout",
            dict(type=float, default=0.05, help="LoRA dropout (default: 0.05)"),
        ),
        (
            "--lora-target-modules",
            dict(
                default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj",
                help="Comma-separated LoRA target modules (default: q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj)",
            ),
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    train(parser.parse_args())


if __name__ == "__main__":
    main()