#!/usr/bin/env python
# train_battleground_rlaif.py
#
# SFT + GRPO (RLAIF style) on synthetic Hearthstone Battlegrounds data.
# Dataset format: RL/datasets/battleground_rlaif_multicandidate.jsonl
# Each row: { game_id, step_id, turn, phase, state, candidates[3], meta }
# candidates: [{role, action{...}, reward}, ...]
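#
# A hypothetical example row (field values illustrative; only the schema in the
# comment above is assumed):
#   {"game_id": "g0001", "step_id": 7, "turn": 3, "phase": "RECRUIT",
#    "state": {...}, "meta": {...},
#    "candidates": [
#      {"role": "expert", "action": {"type": "BUY_FROM_TAVERN", "tavern_index": 0,
#                                    "hand_index": null, "board_index": null,
#                                    "card_name": "Alleycat"}, "reward": 1.0},
#      {"role": "medium", "action": {...}, "reward": 0.5},
#      {"role": "bad",    "action": {...}, "reward": 0.0}]}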

import argparse
import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer, SFTConfig, SFTTrainer

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    dataset_state_to_game_state,
    game_state_to_natural_language,
)

# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"

def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    # Fall back to the Hub checkpoint matching the local path above.
    return "Qwen/Qwen3-4B-Instruct-2507"

DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./battleground_rlaif_qwen"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate_expert1_med0_bad-0_5.jsonl"
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# ================== Config dataclass ==================
@dataclass
class PipelineConfig:
    model_name_or_path: str = DEFAULT_MODEL_ID
    output_dir: str = DEFAULT_OUTPUT_DIR
    data_file: str = DEFAULT_DATA_FILE
    input_mode: str = "json"
    max_seq_length: int = 1024
    sft_epochs: int = 3
    grpo_epochs: int = 3
    bf16: bool = True
    per_device_batch_size: int = 4  # A800 80GB can handle larger batches
    grad_accum_steps: int = 4
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6
    max_completion_length: int = 128
    num_generations: int = 3
    steps_per_generation: int = 1  # kept for symmetry; not directly used in this script
    target_modules: Optional[List[str]] = None
    skip_sft: bool = False
    skip_grpo: bool = False
    train_on_all_data: bool = False

def parse_args() -> PipelineConfig:
    parser = argparse.ArgumentParser(
        description="Run SFT + GRPO (RLAIF) on Battlegrounds synthetic dataset."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help="Model id or local path for the Qwen instruct checkpoint.",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help="Path to JSONL file with multi-candidate Battlegrounds data.",
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Input format for game state: 'json' uses raw JSON, 'nl' uses a natural language description.",
    )
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--sft-epochs", type=int, default=35)
    parser.add_argument("--grpo-epochs", type=int, default=10)
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=4,
        help="Batch size per device (default: 4 for A800 80GB).",
    )
    parser.add_argument("--grad-accum-steps", type=int, default=4)
    parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
    parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=128)
    parser.add_argument("--num-generations", type=int, default=3)
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
    parser.add_argument(
        "--train-on-all-data",
        action="store_true",
        help="Use all rows as training data (no hold-out split); SFT eval runs on the same data.",
    )
    args = parser.parse_args()
    target_modules = (
        [m.strip() for m in args.target_modules.split(",") if m.strip()]
        if args.target_modules
        else None
    )
    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        data_file=args.data_file,
        input_mode=args.input_mode,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
        train_on_all_data=args.train_on_all_data,
    )
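
# Example invocation (paths and hyperparameters illustrative):
#   python train_battleground_rlaif.py \
#       --data-file RL/datasets/battleground_rlaif_multicandidate.jsonl \
#       --input-mode nl \
#       --sft-epochs 3 --grpo-epochs 2 --per-device-batch-size 2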

# ================== Data: Battlegrounds formatting ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON:
"""

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state:
"""

def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """
    Pack the state into a JSON prompt:
    {
        "task": "battlegrounds_policy_v1",
        "phase": ...,
        "turn": ...,
        "state": {...}
    }
    """
    if input_mode == "nl":
        game_state = dataset_state_to_game_state(example)
        nl_state = game_state_to_natural_language(game_state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": example["phase"],
            "turn": example["turn"],
            "state": example["state"],
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text
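
# For input_mode="json", the prompt is INSTRUCTION_PREFIX followed by one
# compact JSON line, e.g. (state contents illustrative):
#   ...Now here is the game state JSON:
#   {"task":"battlegrounds_policy_v1","phase":"RECRUIT","turn":3,"state":{...}}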

def _build_completion_from_action(action: Dict[str, Any]) -> str:
    """
    Pack the action into a JSON completion:
    { "action": { ...action fields... } }
    """
    return json.dumps({"action": action}, separators=(",", ":"), ensure_ascii=False)
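
# For example (field values illustrative):
#   _build_completion_from_action(
#       {"type": "ROLL", "tavern_index": None, "hand_index": None,
#        "board_index": None, "card_name": None}
#   )
#   -> '{"action":{"type":"ROLL","tavern_index":null,"hand_index":null,"board_index":null,"card_name":null}}'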

def load_battleground_rlaif(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    train_on_all_data: bool = False,
    input_mode: str = "json",
):
    """
    Read the JSONL data and build:
    - SFT dataset: prompt + completion (expert action only)
    - RL dataset: prompt + candidates (multiple candidates, for the reward_fn)
    """
    raw = load_dataset(
        "json",
        data_files={"train": data_file},
    )["train"]
    # Split into train / eval (row-level split over states)
    if train_on_all_data:
        raw_train = raw
        raw_eval = raw
    else:
        split = raw.train_test_split(test_size=test_size, seed=seed)
        raw_train = split["train"]
        raw_eval = split["test"]

    def to_sft(example):
        # Pick the expert candidate; if there is no explicit expert,
        # take the one with the highest reward.
        candidates = example["candidates"]
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
        prompt = _build_prompt(example, input_mode=input_mode)
        completion = _build_completion_from_action(expert["action"])
        return {
            "prompt": prompt,
            "completion": completion,
        }

    def to_rl(example):
        prompt = _build_prompt(example, input_mode=input_mode)
        # Keep the candidates column for the reward_fn.
        return {
            "prompt": prompt,
            "candidates": example["candidates"],
        }

    sft_train = raw_train.map(to_sft, remove_columns=raw_train.column_names)
    sft_eval = raw_eval.map(to_sft, remove_columns=raw_eval.column_names)
    rl_train = raw_train.map(to_rl, remove_columns=raw_train.column_names)
    return sft_train, sft_eval, rl_train
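
# A minimal usage sketch (path illustrative):
#   sft_train, sft_eval, rl_train = load_battleground_rlaif(
#       "RL/datasets/battleground_rlaif_multicandidate.jsonl", input_mode="json"
#   )
#   sft_train[0]  # -> {"prompt": "...", "completion": '{"action":{...}}'}
#   rl_train[0]   # -> {"prompt": "...", "candidates": [...]}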

# ================== Reward function for GRPO (RLAIF style) ==================
def _parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
    """
    Try to parse the model's completion as a JSON action.
    Expected formats:
        {"action": {...}} or {...}
    """
    text = text.strip()
    try:
        obj = json.loads(text)
    except Exception:
        return None
    if isinstance(obj, dict):
        if "action" in obj and isinstance(obj["action"], dict):
            return obj["action"]
        return obj
    return None
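
# Parser behavior, by example:
#   '{"action":{"type":"ROLL"}}' -> {"type": "ROLL"}   (unwraps "action")
#   '{"type":"ROLL"}'            -> {"type": "ROLL"}   (bare action dict accepted)
#   'not json' / '[1,2]'         -> None               (unparseable, or not a dict)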

def _actions_equal(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
    """
    Simple dict equality:
    - Assumes the two actions share the same field set.
    - For a more robust check, compare only
      type/tavern_index/hand_index/board_index/card_name
      (see the sketch below).
    """
    return a == b
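
# A minimal sketch of the more robust comparison mentioned in the docstring
# above (hypothetical helper, not wired into the reward function): compare only
# the five canonical action fields, treating a missing key as null.
_ACTION_FIELDS = ("type", "tavern_index", "hand_index", "board_index", "card_name")

def _actions_equal_loose(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
    return all(a.get(f) == b.get(f) for f in _ACTION_FIELDS)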

def battleground_rlaif_reward(
    completions: List[str],
    candidates: List[List[Dict[str, Any]]],
    **kwargs,
) -> List[float]:
    """
    RLAIF-style reward function for GRPOTrainer.

    For each completion (an action JSON string):
    1. Parse it into an action dict.
    2. Compare it against the actions in this sample's candidates.
    3. If it exactly matches some candidate.action, return that candidate's
       reward (1.0 / 0.5 / 0.0).
    4. Otherwise reward = 0.0 (i.e. "not any action we labeled").

    Note: TRL automatically repeats the dataset's candidates column across the
    generation batch, so candidates has the same length as completions and the
    two line up one-to-one.
    """
    rewards: List[float] = []
    for comp_text, cand_list in zip(completions, candidates):
        act = _parse_action_from_completion(comp_text)
        if act is None:
            rewards.append(0.0)
            continue
        best_reward = 0.0
        for cand in cand_list:
            cand_action = cand.get("action", {})
            if _actions_equal(act, cand_action):
                r = float(cand.get("reward", 0.0))
                if r > best_reward:
                    best_reward = r
        rewards.append(best_reward)
    return rewards
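
# Illustrative call (candidate values hypothetical): a completion whose parsed
# action exactly equals a candidate's action earns that candidate's reward.
#   battleground_rlaif_reward(
#       completions=['{"action":{"type":"END_TURN","tavern_index":null,'
#                    '"hand_index":null,"board_index":null,"card_name":null}}'],
#       candidates=[[{"role": "expert",
#                     "action": {"type": "END_TURN", "tavern_index": None,
#                                "hand_index": None, "board_index": None,
#                                "card_name": None},
#                     "reward": 1.0}]],
#   )  # -> [1.0]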

# ================== SFT phase ==================
def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
    """Run a short supervised fine-tuning pass with LoRA adapters (prompt→action JSON)."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    sft_config = SFTConfig(
        output_dir=os.path.join(cfg.output_dir, "sft"),
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        learning_rate=cfg.sft_learning_rate,
        num_train_epochs=cfg.sft_epochs,
        logging_steps=10,
        save_steps=200,
        eval_steps=200,
        eval_strategy="steps",
        save_total_limit=2,
        max_length=cfg.max_seq_length,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        report_to=["none"],
    )
    trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # model id / path; SFTTrainer loads it itself
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()
    save_path = os.path.join(cfg.output_dir, "sft_model")
    trainer.save_model(save_path)
    return trainer.model  # PEFT-wrapped model instance
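
# The saved directory holds the LoRA adapter weights. A sketch for reloading
# them later (assumes the standard peft loading API; not called in this script):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(DEFAULT_MODEL_ID)
#   model = PeftModel.from_pretrained(base, "./battleground_rlaif_qwen/sft_model")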

# ================== GRPO phase ==================
def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
    """Run a GRPO RLAIF loop on top of the (optionally) SFT-initialized model."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    if hasattr(base_model, "peft_config"):
        # Already PEFT-wrapped (e.g. fresh from the SFT phase): reuse its adapters.
        peft_config = None
    else:
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            target_modules=target_modules,
            task_type="CAUSAL_LM",
        )
    generation_batch_size = cfg.per_device_batch_size * cfg.num_generations
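    # With the defaults above (per_device_batch_size=4, num_generations=3) this
    # gives generation_batch_size=12: each generation step samples 3 completions
    # per prompt for 4 prompts on each device, and GRPO computes advantages
    # within each group of 3 completions that share a prompt.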
    grpo_config = GRPOConfig(
        output_dir=os.path.join(cfg.output_dir, "grpo"),
        num_train_epochs=cfg.grpo_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=cfg.grpo_epochs,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        learning_rate=cfg.grpo_learning_rate,
        max_prompt_length=cfg.max_seq_length,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        generation_batch_size=generation_batch_size,
        report_to=["none"],
    )
    if peft_config is not None:
        trainer = GRPOTrainer(
            model=base_model,
            args=grpo_config,
            processing_class=tokenizer,
            reward_funcs=battleground_rlaif_reward,
            train_dataset=rl_dataset,
            peft_config=peft_config,
        )
    else:
        trainer = GRPOTrainer(
            model=base_model,
            args=grpo_config,
            processing_class=tokenizer,
            reward_funcs=battleground_rlaif_reward,
            train_dataset=rl_dataset,
        )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))

# ================== Main ==================
def main():
    cfg = parse_args()
    os.makedirs(cfg.output_dir, exist_ok=True)
    print(f"Using model: {cfg.model_name_or_path}")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # For GRPO, we want left padding
    tokenizer.padding_side = "left"

    print(f"Loading Battlegrounds dataset from: {cfg.data_file}")
    sft_train, sft_eval, rl_train = load_battleground_rlaif(
        cfg.data_file,
        train_on_all_data=cfg.train_on_all_data,
        input_mode=cfg.input_mode,
    )

    # ----- SFT -----
    if cfg.skip_sft:
        print("Skipping SFT phase; loading base model directly.")
        dtype = (
            torch.bfloat16
            if cfg.bf16 and torch.cuda.is_available()
            else (torch.float16 if torch.cuda.is_available() else torch.float32)
        )
        model_kwargs = {
            "torch_dtype": dtype,
            "trust_remote_code": True,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        base_model = AutoModelForCausalLM.from_pretrained(
            cfg.model_name_or_path, **model_kwargs
        )
    else:
        print("Running SFT phase...")
        base_model = run_sft(sft_train, sft_eval, tokenizer, cfg)

    # ----- GRPO -----
    if cfg.skip_grpo:
        print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
    else:
        print("Running GRPO (RLAIF) phase...")
        run_grpo(rl_train, base_model, tokenizer, cfg)
    print("All done. Check outputs under:", cfg.output_dir)

if __name__ == "__main__":
    main()