#!/usr/bin/env python
# train_battleground_rlaif.py
#
# SFT + GRPO (RLAIF style) on synthetic Hearthstone Battlegrounds data.
# Dataset format: RL/datasets/battleground_rlaif_multicandidate.jsonl
# Each row: { game_id, step_id, turn, phase, state, candidates[3], meta }
# candidates: [{role, action{...}, reward}, ...]
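#
# Illustrative row shape (values below are hypothetical; only the "expert" role
# and the 1.0 / 0.5 / 0.0 reward levels are assumed by this script):
#   {"game_id": "g001", "step_id": 3, "turn": 2, "phase": "RECRUIT",
#    "state": {...},
#    "candidates": [
#      {"role": "expert", "action": {"type": "BUY_FROM_TAVERN", ...}, "reward": 1.0},
#      {"role": "...",    "action": {...},                            "reward": 0.5},
#      {"role": "...",    "action": {...},                            "reward": 0.0}
#    ],
#    "meta": {...}}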
import argparse
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Optional, Dict, Any
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
sys.path.append(_SCRIPT_DIR)
from battleground_nl_utils import (
dataset_state_to_game_state,
game_state_to_natural_language,
)
# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"
def _resolve_default_model_id() -> str:
env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
if env_override:
return env_override
if os.path.isdir(LOCAL_INSTRUCT_PATH):
return LOCAL_INSTRUCT_PATH
return "Qwen/Qwen3-4B-Instruct"
DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./battleground_rlaif_qwen"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate_expert1_med0_bad-0_5.jsonl"
DEFAULT_TARGET_MODULES = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
# ================== Config dataclass ==================
@dataclass
class PipelineConfig:
model_name_or_path: str = DEFAULT_MODEL_ID
output_dir: str = DEFAULT_OUTPUT_DIR
data_file: str = DEFAULT_DATA_FILE
input_mode: str = "json"
max_seq_length: int = 1024
    sft_epochs: int = 35  # kept in sync with the --sft-epochs CLI default
    grpo_epochs: int = 10  # kept in sync with the --grpo-epochs CLI default
bf16: bool = True
per_device_batch_size: int = 4 # A800 80GB can handle larger batches
grad_accum_steps: int = 4
sft_learning_rate: float = 1e-5
grpo_learning_rate: float = 5e-6
max_completion_length: int = 128
num_generations: int = 3
steps_per_generation: int = 1 # kept for symmetry; not directly used in this script
target_modules: Optional[List[str]] = None
skip_sft: bool = False
skip_grpo: bool = False
train_on_all_data: bool = False
def parse_args() -> PipelineConfig:
parser = argparse.ArgumentParser(
description="Run SFT + GRPO (RLAIF) on Battlegrounds synthetic dataset."
)
parser.add_argument(
"--model",
default=DEFAULT_MODEL_ID,
help="Model id or local path for the Qwen instruct checkpoint.",
)
parser.add_argument(
"--output-dir",
default=DEFAULT_OUTPUT_DIR,
help="Directory for checkpoints and logs.",
)
parser.add_argument(
"--data-file",
default=DEFAULT_DATA_FILE,
help="Path to JSONL file with multi-candidate Battlegrounds data.",
)
parser.add_argument(
"--input-mode",
choices=["json", "nl"],
default="json",
help="Input format for game state: 'json' uses raw JSON, 'nl' uses natural language description.",
)
parser.add_argument("--max-seq-length", type=int, default=1024)
parser.add_argument("--sft-epochs", type=int, default=35)
parser.add_argument("--grpo-epochs", type=int, default=10)
parser.add_argument("--per-device-batch-size", type=int, default=4, help="Batch size per device (default: 4 for A800 80GB)")
parser.add_argument("--grad-accum-steps", type=int, default=4)
parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
parser.add_argument("--max-completion-length", type=int, default=128)
parser.add_argument("--num-generations", type=int, default=3)
parser.add_argument(
"--target-modules",
default=None,
help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
)
parser.add_argument(
"--disable-bf16",
action="store_true",
help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
)
parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
parser.add_argument(
"--train-on-all-data",
action="store_true",
help="Use all rows as training data (no hold-out split); SFT eval runs on the same data.",
)
args = parser.parse_args()
target_modules = (
[m.strip() for m in args.target_modules.split(",") if m.strip()]
if args.target_modules
else None
)
return PipelineConfig(
model_name_or_path=args.model,
output_dir=args.output_dir,
data_file=args.data_file,
input_mode=args.input_mode,
max_seq_length=args.max_seq_length,
sft_epochs=args.sft_epochs,
grpo_epochs=args.grpo_epochs,
bf16=not args.disable_bf16,
per_device_batch_size=args.per_device_batch_size,
grad_accum_steps=args.grad_accum_steps,
sft_learning_rate=args.sft_learning_rate,
grpo_learning_rate=args.grpo_learning_rate,
max_completion_length=args.max_completion_length,
num_generations=args.num_generations,
target_modules=target_modules,
skip_sft=args.skip_sft,
skip_grpo=args.skip_grpo,
train_on_all_data=args.train_on_all_data,
)
# ================== Data: Battlegrounds formatting ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON:
"""
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state:
"""
def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
"""
把 state 打包成一个 JSON prompt:
{
"task": "battlegrounds_policy_v1",
"phase": ...,
"turn": ...,
"state": {...}
}
"""
if input_mode == "nl":
game_state = dataset_state_to_game_state(example)
nl_state = game_state_to_natural_language(game_state)
prefix = INSTRUCTION_PREFIX_NL
state_text = nl_state
else:
obj = {
"task": "battlegrounds_policy_v1",
"phase": example["phase"],
"turn": example["turn"],
"state": example["state"],
}
state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
prefix = INSTRUCTION_PREFIX
return prefix + "\n" + state_text
def _build_completion_from_action(action: Dict[str, Any]) -> str:
"""
把 action 也打成 JSON completion:
{ "action": { ...action fields... } }
"""
return json.dumps({"action": action}, separators=(",", ":"), ensure_ascii=False)
def load_battleground_rlaif(
data_file: str,
test_size: float = 0.1,
seed: int = 42,
train_on_all_data: bool = False,
input_mode: str = "json",
):
"""
从 JSONL 读取数据,构造:
- SFT dataset: prompt + completion(只用 expert action)
- RL dataset: prompt + candidates(多候选,给 reward_fn 用)
"""
raw = load_dataset(
"json",
data_files={"train": data_file},
)["train"]
    # Split into train / eval (the split is over states)
if train_on_all_data:
raw_train = raw
raw_eval = raw
else:
split = raw.train_test_split(test_size=test_size, seed=seed)
raw_train = split["train"]
raw_eval = split["test"]
def to_sft(example):
        # Pick the expert candidate; if none is explicitly tagged "expert",
        # fall back to the highest-reward candidate.
candidates = example["candidates"]
expert = None
for c in candidates:
if c.get("role") == "expert":
expert = c
break
if expert is None:
expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
prompt = _build_prompt(example, input_mode=input_mode)
completion = _build_completion_from_action(expert["action"])
return {
"prompt": prompt,
"completion": completion,
}
def to_rl(example):
prompt = _build_prompt(example, input_mode=input_mode)
        # candidates are passed through unchanged for the reward_fn
return {
"prompt": prompt,
"candidates": example["candidates"],
}
sft_train = raw_train.map(to_sft, remove_columns=raw_train.column_names)
sft_eval = raw_eval.map(to_sft, remove_columns=raw_eval.column_names)
rl_train = raw_train.map(to_rl, remove_columns=raw_train.column_names)
return sft_train, sft_eval, rl_train
# ================== Reward function for GRPO (RLAIF style) ==================
def _parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
"""
尝试把 model 的 completion 解析为 JSON action:
- 期望格式:
{"action": {...}} or {...}
"""
text = text.strip()
try:
obj = json.loads(text)
except Exception:
return None
if isinstance(obj, dict):
if "action" in obj and isinstance(obj["action"], dict):
return obj["action"]
return obj
return None
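# Examples:
#   _parse_action_from_completion('{"action":{"type":"ROLL"}}')  -> {"type": "ROLL"}
#   _parse_action_from_completion('{"type": "ROLL"}')            -> {"type": "ROLL"}
#   _parse_action_from_completion("not json")                    -> None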
def _actions_equal(a: Dict[str, Any], b: Dict[str, Any]) -> bool:
"""
简单 dict 相等比较:
- 假设字段集合一致即可。
- 如果你之后想更 robust,可以只比较 type/tavern_index/hand_index/board_index/card_name。
"""
return a == b
def battleground_rlaif_reward(
completions: List[str],
candidates: List[List[Dict[str, Any]]],
**kwargs,
) -> List[float]:
"""
RLAIF-style reward function for GRPOTrainer.
对每个 completion(一个 action JSON 文本):
1. 解析为 action dict
2. 与该样本的 candidates 中的 action 比较
3. 如果完全匹配某个 candidate.action,则得到对应 reward (1.0 / 0.5 / 0.0)
4. 否则 reward = 0.0 (可以理解为“不是我们标记的任何动作”)
注意:TRL 会自动把 dataset 的 candidates 列展开复制到 batch 中,
所以这里 candidates 的长度与 completions 相同,一一对应。
"""
rewards: List[float] = []
for comp_text, cand_list in zip(completions, candidates):
act = _parse_action_from_completion(comp_text)
if act is None:
rewards.append(0.0)
continue
best_reward = 0.0
for cand in cand_list:
cand_action = cand.get("action", {})
if _actions_equal(act, cand_action):
r = float(cand.get("reward", 0.0))
if r > best_reward:
best_reward = r
rewards.append(best_reward)
return rewards
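# Worked example (hypothetical actions A and B): given candidates
#   [{"role": "expert", "action": A, "reward": 1.0},
#    {"role": "...",    "action": B, "reward": 0.5}]
# a completion '{"action": A}' scores 1.0, '{"action": B}' scores 0.5, and any
# unparseable or unmatched completion scores 0.0.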
# ================== SFT phase ==================
def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
"""Run a short supervised fine-tuning pass with LoRA adapters (prompt→action JSON)."""
target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
target_modules=target_modules,
task_type="CAUSAL_LM",
)
sft_config = SFTConfig(
output_dir=os.path.join(cfg.output_dir, "sft"),
per_device_train_batch_size=cfg.per_device_batch_size,
per_device_eval_batch_size=cfg.per_device_batch_size,
gradient_accumulation_steps=cfg.grad_accum_steps,
learning_rate=cfg.sft_learning_rate,
num_train_epochs=cfg.sft_epochs,
logging_steps=10,
save_steps=200,
eval_steps=200,
eval_strategy="steps",
save_total_limit=2,
max_length=cfg.max_seq_length,
bf16=cfg.bf16,
fp16=not cfg.bf16,
report_to=["none"],
)
trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # model id / path; SFTTrainer loads it itself
args=sft_config,
train_dataset=train_ds,
eval_dataset=eval_ds,
processing_class=tokenizer,
peft_config=peft_config,
)
trainer.train()
save_path = os.path.join(cfg.output_dir, "sft_model")
trainer.save_model(save_path)
return trainer.model # PEFT-wrapped model instance
# ================== GRPO phase ==================
def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
"""Run a GRPO RLAIF loop on top of the (optionally) SFT-initialized model."""
target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
if hasattr(base_model, "peft_config"):
peft_config = None
else:
peft_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
target_modules=target_modules,
task_type="CAUSAL_LM",
)
generation_batch_size = cfg.per_device_batch_size * cfg.num_generations
grpo_config = GRPOConfig(
output_dir=os.path.join(cfg.output_dir, "grpo"),
num_train_epochs=cfg.grpo_epochs,
per_device_train_batch_size=cfg.per_device_batch_size,
gradient_accumulation_steps=cfg.grad_accum_steps,
logging_steps=10,
save_strategy="epoch",
save_total_limit=cfg.grpo_epochs,
bf16=cfg.bf16,
fp16=not cfg.bf16,
learning_rate=cfg.grpo_learning_rate,
max_prompt_length=cfg.max_seq_length,
max_completion_length=cfg.max_completion_length,
num_generations=cfg.num_generations,
generation_batch_size=generation_batch_size,
report_to=["none"],
)
    trainer_kwargs: Dict[str, Any] = dict(
        model=base_model,
        args=grpo_config,
        processing_class=tokenizer,
        reward_funcs=battleground_rlaif_reward,
        train_dataset=rl_dataset,
    )
    # Only attach a fresh LoRA config when the model is not already PEFT-wrapped.
    if peft_config is not None:
        trainer_kwargs["peft_config"] = peft_config
    trainer = GRPOTrainer(**trainer_kwargs)
trainer.train()
trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))
# ================== Main ==================
def main():
cfg = parse_args()
os.makedirs(cfg.output_dir, exist_ok=True)
print(f"Using model: {cfg.model_name_or_path}")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
cfg.model_name_or_path,
use_fast=True,
trust_remote_code=True,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# For GRPO, we want left padding
tokenizer.padding_side = "left"
print(f"Loading Battlegrounds dataset from: {cfg.data_file}")
sft_train, sft_eval, rl_train = load_battleground_rlaif(
cfg.data_file,
train_on_all_data=cfg.train_on_all_data,
input_mode=cfg.input_mode,
)
# ----- SFT -----
if cfg.skip_sft:
print("Skipping SFT phase; loading base model directly.")
dtype = (
torch.bfloat16
if cfg.bf16 and torch.cuda.is_available()
else (torch.float16 if torch.cuda.is_available() else torch.float32)
)
model_kwargs = {
"torch_dtype": dtype,
"trust_remote_code": True,
}
if torch.cuda.is_available():
model_kwargs["device_map"] = "auto"
base_model = AutoModelForCausalLM.from_pretrained(
cfg.model_name_or_path, **model_kwargs
)
else:
print("Running SFT phase...")
base_model = run_sft(sft_train, sft_eval, tokenizer, cfg)
# ----- GRPO -----
if cfg.skip_grpo:
print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
else:
print("Running GRPO (RLAIF) phase...")
run_grpo(rl_train, base_model, tokenizer, cfg)
print("All done. Check outputs under:", cfg.output_dir)
if __name__ == "__main__":
main()
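# Example invocation (a sketch; dataset path per the header comment above):
#   python RL/train_battleground_rlaif.py \
#       --data-file RL/datasets/battleground_rlaif_multicandidate.jsonl \
#       --input-mode nl \
#       --sft-epochs 5 --skip-grpo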