#!/usr/bin/env python
# train_battleground_rlaif.py
#
# SFT + GRPO (RLAIF style) on synthetic Hearthstone Battlegrounds data.
# Dataset format: RL/datasets/battleground_rlaif_multicandidate.jsonl
# Each row: { game_id, step_id, turn, phase, state, candidates[3], meta }
# candidates: [{role, action{...}, reward}, ...]

import argparse
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Optional, Dict, Any

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    dataset_state_to_game_state,
    game_state_to_natural_language,
)

# ================== Model paths & defaults ==================

LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    return "Qwen/Qwen3-4B-Instruct-2507"


DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./battleground_rlaif_qwen"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate_expert1_med0_bad-0_5.jsonl"
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# ================== Config dataclass ==================


@dataclass
class PipelineConfig:
    model_name_or_path: str = DEFAULT_MODEL_ID
    output_dir: str = DEFAULT_OUTPUT_DIR
    data_file: str = DEFAULT_DATA_FILE
    input_mode: str = "json"
    max_seq_length: int = 1024
    sft_epochs: int = 3
    grpo_epochs: int = 3
    bf16: bool = True
    per_device_batch_size: int = 4  # A800 80GB can handle larger batches
    grad_accum_steps: int = 4
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6
    max_completion_length: int = 128
    num_generations: int = 3
    steps_per_generation: int = 1  # kept for symmetry; not directly used in this script
    target_modules: Optional[List[str]] = None
    skip_sft: bool = False
    skip_grpo: bool = False
    train_on_all_data: bool = False
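
# A minimal sketch of one dataset row, matching the schema in the header
# comment. All field values, card names, roles, and reward values below are
# illustrative placeholders, not taken from a real generated game.
_EXAMPLE_ROW: Dict[str, Any] = {
    "game_id": "g-0001",
    "step_id": 3,
    "turn": 2,
    "phase": "RECRUIT",
    "state": {"gold": 4, "tavern": [], "hand": [], "board": []},
    "candidates": [
        {
            "role": "expert",
            "action": {
                "type": "BUY_FROM_TAVERN",
                "tavern_index": 0,
                "hand_index": None,
                "board_index": None,
                "card_name": "Alleycat",
            },
            "reward": 1.0,
        },
        {
            "role": "medium",
            "action": {
                "type": "ROLL",
                "tavern_index": None,
                "hand_index": None,
                "board_index": None,
                "card_name": None,
            },
            "reward": 0.5,
        },
        {
            "role": "bad",
            "action": {
                "type": "END_TURN",
                "tavern_index": None,
                "hand_index": None,
                "board_index": None,
                "card_name": None,
            },
            "reward": 0.0,
        },
    ],
    "meta": {},
}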


def parse_args() -> PipelineConfig:
    parser = argparse.ArgumentParser(
        description="Run SFT + GRPO (RLAIF) on Battlegrounds synthetic dataset."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help="Model id or local path for the Qwen instruct checkpoint.",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help="Path to JSONL file with multi-candidate Battlegrounds data.",
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Input format for game state: 'json' uses raw JSON, 'nl' uses a natural language description.",
    )
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--sft-epochs", type=int, default=35)
    parser.add_argument("--grpo-epochs", type=int, default=10)
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=4,
        help="Batch size per device (default: 4 for A800 80GB)",
    )
    parser.add_argument("--grad-accum-steps", type=int, default=4)
    parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
    parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=128)
    parser.add_argument("--num-generations", type=int, default=3)
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
    parser.add_argument(
        "--train-on-all-data",
        action="store_true",
        help="Use all rows as training data (no hold-out split); SFT eval runs on the same data.",
    )
    args = parser.parse_args()

    target_modules = (
        [m.strip() for m in args.target_modules.split(",") if m.strip()]
        if args.target_modules
        else None
    )

    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        data_file=args.data_file,
        input_mode=args.input_mode,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
        train_on_all_data=args.train_on_all_data,
    )


# ================== Data: Battlegrounds formatting ==================

INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"...","tavern_index":...,"hand_index":...,"board_index":...,"card_name":...}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON:
"""
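
# Illustrative example of a completion that satisfies the rules above. The
# indices and card name are placeholders echoing _EXAMPLE_ROW, not a real
# game. Used only by the sanity-check helpers below, never during training.
_EXAMPLE_COMPLETION = (
    '{"action":{"type":"BUY_FROM_TAVERN","tavern_index":0,'
    '"hand_index":null,"board_index":null,"card_name":"Alleycat"}}'
)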

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"...","tavern_index":...,"hand_index":...,"board_index":...,"card_name":...}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state:
"""


def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """
    Pack the state into a single prompt. In "json" mode the payload is:
      { "task": "battlegrounds_policy_v1", "phase": ..., "turn": ..., "state": {...} }
    In "nl" mode the state is rendered as a natural language description.
    """
    if input_mode == "nl":
        game_state = dataset_state_to_game_state(example)
        nl_state = game_state_to_natural_language(game_state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": example["phase"],
            "turn": example["turn"],
            "state": example["state"],
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text


def _build_completion_from_action(action: Dict[str, Any]) -> str:
    """
    Serialize the action as a JSON completion:
      { "action": { ...action fields... } }
    """
    return json.dumps({"action": action}, separators=(",", ":"), ensure_ascii=False)


def load_battleground_rlaif(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    train_on_all_data: bool = False,
    input_mode: str = "json",
):
    """
    Read the JSONL file and build:
      - SFT dataset: prompt + completion (expert action only)
      - RL dataset:  prompt + candidates (all candidates, for the reward_fn)
    """
    raw = load_dataset(
        "json",
        data_files={"train": data_file},
    )["train"]

    # Train / eval split (by state row)
    if train_on_all_data:
        raw_train = raw
        raw_eval = raw
    else:
        split = raw.train_test_split(test_size=test_size, seed=seed)
        raw_train = split["train"]
        raw_eval = split["test"]

    def to_sft(example):
        # Pick the expert candidate; if none is marked, fall back to the
        # highest-reward candidate.
        candidates = example["candidates"]
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))

        prompt = _build_prompt(example, input_mode=input_mode)
        completion = _build_completion_from_action(expert["action"])
        return {
            "prompt": prompt,
            "completion": completion,
        }

    def to_rl(example):
        prompt = _build_prompt(example, input_mode=input_mode)
        # Keep the candidates column; the reward_fn consumes it.
        return {
            "prompt": prompt,
            "candidates": example["candidates"],
        }

    sft_train = raw_train.map(to_sft, remove_columns=raw_train.column_names)
    sft_eval = raw_eval.map(to_sft, remove_columns=raw_eval.column_names)
    rl_train = raw_train.map(to_rl, remove_columns=raw_train.column_names)

    return sft_train, sft_eval, rl_train


# ================== Reward function for GRPO (RLAIF style) ==================


def _parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
    """
    Try to parse the model's completion as a JSON action.
    Expected format: {"action": {...}} or a bare {...} action object.
    """
    text = text.strip()
    try:
        obj = json.loads(text)
    except Exception:
        return None
    if isinstance(obj, dict):
        if "action" in obj and isinstance(obj["action"], dict):
            return obj["action"]
        return obj
    return None
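
# Minimal sanity check for the parser above; a convenience sketch that is
# defined here for manual testing and never called during training.
def _demo_parse() -> None:
    # A wrapped {"action": {...}} object unwraps to the inner action dict.
    assert _parse_action_from_completion('{"action":{"type":"END_TURN"}}') == {
        "type": "END_TURN"
    }
    # Anything that is not valid JSON parses to None (reward 0.0 downstream).
    assert _parse_action_from_completion("not valid json") is None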

def _actions_equal(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
    """
    Simple dict equality check; assumes both sides carry the same field set.
    For something more robust, compare only type/tavern_index/hand_index/
    board_index/card_name.
    """
    return a == b


def battleground_rlaif_reward(
    completions: List[str],
    candidates: List[List[Dict[str, Any]]],
    **kwargs,
) -> List[float]:
    """
    RLAIF-style reward function for GRPOTrainer.

    For each completion (an action JSON string):
      1. Parse it into an action dict.
      2. Compare it against the sample's candidate actions.
      3. If it exactly matches some candidate.action, return that candidate's
         reward (e.g. 1.0 / 0.5 / 0.0).
      4. Otherwise reward = 0.0 (i.e. "not any action we labeled").

    Note: TRL replicates the dataset's candidates column across the batch,
    so candidates has the same length as completions, aligned one-to-one.
    """
    rewards: List[float] = []
    for comp_text, cand_list in zip(completions, candidates):
        act = _parse_action_from_completion(comp_text)
        if act is None:
            rewards.append(0.0)
            continue
        # Collect the rewards of all exactly-matching candidates and keep the
        # best one as-is (it may be negative for "bad" candidates); actions
        # that match no candidate fall back to 0.0.
        matched = [
            float(cand.get("reward", 0.0))
            for cand in cand_list
            if _actions_equal(act, cand.get("action", {}))
        ]
        rewards.append(max(matched) if matched else 0.0)
    return rewards
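
# Illustrative call, mirroring how GRPOTrainer feeds the reward function: one
# candidates list per completion. Uses the hypothetical module-level examples
# above; defined for sanity checking only, never called during training.
def _demo_reward() -> List[float]:
    return battleground_rlaif_reward(
        completions=[_EXAMPLE_COMPLETION],
        candidates=[_EXAMPLE_ROW["candidates"]],
    )  # -> [1.0]: the completion matches the expert candidate exactly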

# ================== SFT phase ==================


def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
    """Run a short supervised fine-tuning pass with LoRA adapters (prompt→action JSON)."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )

    sft_config = SFTConfig(
        output_dir=os.path.join(cfg.output_dir, "sft"),
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        learning_rate=cfg.sft_learning_rate,
        num_train_epochs=cfg.sft_epochs,
        logging_steps=10,
        save_steps=200,
        eval_steps=200,
        eval_strategy="steps",
        save_total_limit=2,
        max_length=cfg.max_seq_length,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        report_to=["none"],
    )

    trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # model id / path; SFTTrainer loads it itself
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()

    save_path = os.path.join(cfg.output_dir, "sft_model")
    trainer.save_model(save_path)
    return trainer.model  # PEFT-wrapped model instance


# ================== GRPO phase ==================


def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
    """Run a GRPO RLAIF loop on top of the (optionally) SFT-initialized model."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    # If the model already carries LoRA adapters from SFT, do not stack a
    # second PEFT config on top of them.
    if hasattr(base_model, "peft_config"):
        peft_config = None
    else:
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            target_modules=target_modules,
            task_type="CAUSAL_LM",
        )

    generation_batch_size = cfg.per_device_batch_size * cfg.num_generations

    grpo_config = GRPOConfig(
        output_dir=os.path.join(cfg.output_dir, "grpo"),
        num_train_epochs=cfg.grpo_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=cfg.grpo_epochs,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        learning_rate=cfg.grpo_learning_rate,
        max_prompt_length=cfg.max_seq_length,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        generation_batch_size=generation_batch_size,
        report_to=["none"],
    )

    trainer = GRPOTrainer(
        model=base_model,
        args=grpo_config,
        processing_class=tokenizer,
        reward_funcs=battleground_rlaif_reward,
        train_dataset=rl_dataset,
        peft_config=peft_config,  # None is fine when the model is already PEFT-wrapped
    )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))


# ================== Main ==================


def main():
    cfg = parse_args()
    os.makedirs(cfg.output_dir, exist_ok=True)

    print(f"Using model: {cfg.model_name_or_path}")
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # For GRPO, we want left padding
    tokenizer.padding_side = "left"

    print(f"Loading Battlegrounds dataset from: {cfg.data_file}")
    sft_train, sft_eval, rl_train = load_battleground_rlaif(
        cfg.data_file,
        train_on_all_data=cfg.train_on_all_data,
        input_mode=cfg.input_mode,
    )

    # ----- SFT -----
    if cfg.skip_sft:
        print("Skipping SFT phase; loading base model directly.")
        dtype = (
            torch.bfloat16
            if cfg.bf16 and torch.cuda.is_available()
            else (torch.float16 if torch.cuda.is_available() else torch.float32)
        )
        model_kwargs = {
            "torch_dtype": dtype,
            "trust_remote_code": True,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        base_model = AutoModelForCausalLM.from_pretrained(
            cfg.model_name_or_path, **model_kwargs
        )
    else:
        print("Running SFT phase...")
        base_model = run_sft(sft_train, sft_eval, tokenizer, cfg)

    # ----- GRPO -----
    if cfg.skip_grpo:
        print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
    else:
        print("Running GRPO (RLAIF) phase...")
        run_grpo(rl_train, base_model, tokenizer, cfg)

    print("All done. Check outputs under:", cfg.output_dir)


if __name__ == "__main__":
    main()
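
# Example invocations (all flags are defined in parse_args above):
#   python train_battleground_rlaif.py
#   python train_battleground_rlaif.py --input-mode nl --skip-grpo
#   python train_battleground_rlaif.py --skip-sft --grpo-epochs 5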