#!/usr/bin/env python
# train_battleground_rlaif_gamehistory.py
#
# SFT + GRPO (RLAIF style) on Hearthstone Battlegrounds "game_history" data.
#
# Expected data format per JSON file (per game):
# {
#   "game_metadata": {...},
#   "turns": [
#     {
#       "turn_number": 0,
#       "phase": "PlayerTurn",
#       "state": {            # nested game_state / player_hero / resources / board_state
#         "game_state": {...},
#         "player_hero": {...},
#         "resources": {...},
#         "board_state": {...}
#       },
#       "candidates": [       # RLAIF annotations you add
#         {"role": "expert", "actions": [{...}, {...}], "reward": 1.0},
#         {"role": "medium", "actions": [{...}], "reward": 0.0},
#         {"role": "bad", "actions": [{...}], "reward": -0.5}
#       ],
#       ... other fields like battle_result, reward, etc. ...
#     },
#     ...
#   ]
# }
#
# Each candidate's "actions" field is a SEQUENCE (list) of atomic Battlegrounds
# actions, where each atomic action dict uses the schema from the original
# RLAIF pipeline:
# {
#   "type": "BUY_FROM_TAVERN" | "PLAY_FROM_HAND" | "SELL_FROM_BOARD" |
#           "HERO_POWER" | "ROLL" | "UPGRADE_TAVERN" | "FREEZE" | "END_TURN",
#   "tavern_index": int or null,
#   "hand_index": int or null,
#   "board_index": int or null,
#   "card_name": string or null
# }
#
# The loader flattens all labeled turns (those with "candidates") into per-step
# records while preserving the nested "state" structure.
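#
# Example invocation (paths illustrative; all flags are defined in parse_args below):
#   python train_battleground_rlaif_gamehistory.py \
#       --data-file RL/datasets/game_history_fixed.json \
#       --input-mode nl \
#       --sft-epochs 3 --grpo-epochs 3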

import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Dict, Any

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, GRPOTrainer, GRPOConfig

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    game_state_to_natural_language,
)

# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    return "Qwen/Qwen3-4B-Instruct-2507"


DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_OUTPUT_DIR = "./battleground_rlaif_qwen_gamehistory"
# By default, point to a single game-history style file. You can override
# with a directory containing many such JSONs.
DEFAULT_DATA_FILE = "RL/datasets/game_history_fixed.json"
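# LoRA target modules: the attention projections (q/k/v/o_proj) and the
# feed-forward projections (gate/up/down_proj) of Qwen-family decoder blocks.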
DEFAULT_TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# ================== Config dataclass ==================
@dataclass
class PipelineConfig:
    model_name_or_path: str = DEFAULT_MODEL_ID
    output_dir: str = DEFAULT_OUTPUT_DIR
    data_file: str = DEFAULT_DATA_FILE
    input_mode: str = "json"  # "json" uses nested game_history state; "nl" uses natural language
    max_seq_length: int = 1024
    sft_epochs: int = 3
    grpo_epochs: int = 3
    bf16: bool = True
    per_device_batch_size: int = 4
    grad_accum_steps: int = 4
    sft_learning_rate: float = 1e-5
    grpo_learning_rate: float = 5e-6
    max_completion_length: int = 128
    num_generations: int = 3
    steps_per_generation: int = 1
    target_modules: Optional[List[str]] = None
    skip_sft: bool = False
    skip_grpo: bool = False
    train_on_all_data: bool = False

def parse_args() -> PipelineConfig:
    parser = argparse.ArgumentParser(
        description="Run SFT + GRPO (RLAIF) on Battlegrounds game_history dataset."
    )
    parser.add_argument(
        "--model",
        default=DEFAULT_MODEL_ID,
        help="Model id or local path for the Qwen instruct checkpoint.",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for checkpoints and logs.",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help=(
            "Path to a game_history-style JSON file or a directory of such files. "
            "Each file should have {game_metadata, turns[...]} and each labeled turn "
            "must contain a 'candidates' list."
        ),
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help=(
            "Input format for game state: 'json' uses nested game_history JSON; "
            "'nl' converts the nested state to natural language."
        ),
    )
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--sft-epochs", type=int, default=20)
    parser.add_argument("--grpo-epochs", type=int, default=3)
    parser.add_argument(
        "--per-device-batch-size",
        type=int,
        default=4,
        help="Batch size per device (default: 4 for A800 80GB)",
    )
    parser.add_argument("--grad-accum-steps", type=int, default=4)
    parser.add_argument("--sft-learning-rate", type=float, default=1e-5)
    parser.add_argument("--grpo-learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-completion-length", type=int, default=128)
    parser.add_argument("--num-generations", type=int, default=3)
    parser.add_argument(
        "--target-modules",
        default=None,
        help="Comma-separated list of module names for LoRA (defaults to Qwen attn/FFN blocks).",
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Force fp16/fp32 training if bf16 is not desired or unsupported.",
    )
    parser.add_argument("--skip-sft", action="store_true", help="Skip the SFT phase.")
    parser.add_argument("--skip-grpo", action="store_true", help="Skip the GRPO phase.")
    parser.add_argument(
        "--train-on-all-data",
        action="store_true",
        help="Use all rows as training data (no hold-out split); SFT eval runs on the same data.",
    )
    args = parser.parse_args()
    target_modules = (
        [m.strip() for m in args.target_modules.split(",") if m.strip()]
        if args.target_modules
        else None
    )
    return PipelineConfig(
        model_name_or_path=args.model,
        output_dir=args.output_dir,
        data_file=args.data_file,
        input_mode=args.input_mode,
        max_seq_length=args.max_seq_length,
        sft_epochs=args.sft_epochs,
        grpo_epochs=args.grpo_epochs,
        bf16=not args.disable_bf16,
        per_device_batch_size=args.per_device_batch_size,
        grad_accum_steps=args.grad_accum_steps,
        sft_learning_rate=args.sft_learning_rate,
        grpo_learning_rate=args.grpo_learning_rate,
        max_completion_length=args.max_completion_length,
        num_generations=args.num_generations,
        target_modules=target_modules,
        skip_sft=args.skip_sft,
        skip_grpo=args.skip_grpo,
        train_on_all_data=args.train_on_all_data,
    )

# ================== Data: Battlegrounds formatting ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""

def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build a prompt from a flattened game_history example.

    The example has:
      - phase: string (e.g., "PlayerTurn")
      - turn: int
      - state: nested dict with keys: game_state, player_hero, resources, board_state
    """
    state = example.get("state", {})
    if input_mode == "nl":
        # state is already in the game_state / player_hero / resources / board_state shape
        nl_state = game_state_to_natural_language(state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        # JSON mode: wrap the nested state in a small task object.
        gs = state.get("game_state", {}) or {}
        phase = example.get("phase", gs.get("phase", "Unknown"))
        turn = example.get("turn", gs.get("turn_number", 0))
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": phase,
            "turn": turn,
            "state": state,
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text
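
# In "json" mode the prompt therefore ends with a compact task object such as
# (field values illustrative):
#   {"task":"battlegrounds_policy_v1","phase":"PlayerTurn","turn":5,"state":{...}}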

def _build_completion_from_actions(actions: List[Dict[str, Any]]) -> str:
    """Pack a sequence of atomic actions into the expected JSON completion.

    {"actions": [ {...}, {...}, ... ]}
    """
    return json.dumps({"actions": actions}, separators=(",", ":"), ensure_ascii=False)
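
# For example (hypothetical single-step turn):
#   _build_completion_from_actions([{"type": "END_TURN", "tavern_index": None,
#       "hand_index": None, "board_index": None, "card_name": None}])
# returns:
#   '{"actions":[{"type":"END_TURN","tavern_index":null,"hand_index":null,"board_index":null,"card_name":null}]}'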

def load_gamehistory_rlaif(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    train_on_all_data: bool = False,
    input_mode: str = "json",
):
    """Load game_history-style JSON data and build SFT & RL datasets.

    - data_file can be:
        * a single JSON file with {game_metadata, turns: [...]} structure;
        * a JSON file containing a list of such game objects;
        * a directory containing multiple .json files in either of the above forms.
    - Each labeled turn must contain a "candidates" list; turns without candidates
      are skipped.
    """
    path = Path(data_file)
    if not path.exists():
        raise FileNotFoundError(f"Data file or directory not found: {data_file}")
    rows: List[Dict[str, Any]] = []

    def _consume_game_obj(game_obj: Dict[str, Any], game_id_hint: str) -> None:
        meta = game_obj.get("game_metadata", {}) or {}
        turns = game_obj.get("turns", []) or []
        for t in turns:
            state = t.get("state", {}) or {}
            candidates = t.get("candidates")
            if not candidates:
                # Skip unlabeled turns (no RLAIF annotations yet)
                continue
            gs = state.get("game_state", {}) or {}
            phase = t.get("phase") or gs.get("phase", "PlayerTurn")
            turn = gs.get("turn_number", t.get("turn_number", 0))
            row_meta = {
                "game_metadata": meta,
                "battle_result": t.get("battle_result"),
                "health_before_battle": t.get("health_before_battle"),
                "health_after_battle": t.get("health_after_battle"),
                "health_change": t.get("health_change"),
                "action_taken": t.get("action_taken"),
            }
            rows.append(
                {
                    "game_id": meta.get("game_id") or game_id_hint,
                    "step_id": t.get("turn_number", turn),
                    "turn": turn,
                    "phase": phase,
                    "state": state,
                    "candidates": candidates,
                    "meta": row_meta,
                }
            )

    def _load_one_json_file(p: Path) -> None:
        with p.open("r", encoding="utf-8") as f:
            data = json.load(f)
        # Case 1: single game_history object with turns
        if isinstance(data, dict) and "turns" in data:
            _consume_game_obj(data, game_id_hint=p.stem)
        # Case 2: already-flattened per-turn rows in a list
        elif isinstance(data, list) and data and isinstance(data[0], dict) and "state" in data[0]:
            for idx, row in enumerate(data):
                if not isinstance(row, dict):
                    raise ValueError(
                        f"Unsupported JSON row at index {idx} in file {p}: expected dict with 'state'."
                    )
                candidates = row.get("candidates")
                if not candidates:
                    # Skip unlabeled rows (no RLAIF annotations yet)
                    continue
                state = row.get("state", {}) or {}
                gs = state.get("game_state", {}) or {}
                if "phase" not in row:
                    row["phase"] = gs.get("phase", "PlayerTurn")
                if "turn" not in row:
                    row["turn"] = gs.get("turn_number", row.get("step_id", 0))
                # Ensure at least the keys expected downstream are present; keep any
                # extra metadata fields as-is.
                rows.append(row)
        # Case 3: list of game_history objects with turns
        elif isinstance(data, list):
            for idx, item in enumerate(data):
                if isinstance(item, dict) and "turns" in item:
                    game_id_hint = item.get("game_metadata", {}).get("game_id") or f"{p.stem}_{idx}"
                    _consume_game_obj(item, game_id_hint=game_id_hint)
                else:
                    raise ValueError(
                        f"Unsupported JSON object at index {idx} in file {p}: expected game_history with 'turns' or flat rows with 'state'."
                    )
        else:
            raise ValueError(
                f"Unsupported JSON structure in file {p}: expected dict with 'turns', list of such dicts, or list of flat rows with 'state'."
            )

    if path.is_dir():
        json_files = sorted(path.glob("*.json"))
        if not json_files:
            raise ValueError(f"No .json files found in directory: {data_file}")
        for p in json_files:
            _load_one_json_file(p)
    else:
        _load_one_json_file(path)
    if not rows:
        raise ValueError(
            "No labeled turns (with 'candidates') were found in the provided data. "
            "Make sure each turn you want to train on has a non-empty 'candidates' list."
        )
    raw = Dataset.from_list(rows)
    # Train / eval split
    if train_on_all_data:
        raw_train = raw
        raw_eval = raw
    else:
        split = raw.train_test_split(test_size=test_size, seed=seed)
        raw_train = split["train"]
        raw_eval = split["test"]

    def to_sft(example: Dict[str, Any]) -> Dict[str, Any]:
        # Pick the expert candidate; if not present, fall back to max reward.
        candidates = example["candidates"]
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
        prompt = _build_prompt(example, input_mode=input_mode)
        # In the game_history pipeline, each candidate carries a SEQUENCE of
        # atomic actions under the "actions" key.
        completion = _build_completion_from_actions(expert["actions"])
        return {
            "prompt": prompt,
            "completion": completion,
        }

    def to_rl(example: Dict[str, Any]) -> Dict[str, Any]:
        prompt = _build_prompt(example, input_mode=input_mode)
        return {
            "prompt": prompt,
            "candidates": example["candidates"],
        }

    sft_train = raw_train.map(to_sft, remove_columns=raw_train.column_names)
    sft_eval = raw_eval.map(to_sft, remove_columns=raw_eval.column_names)
    rl_train = raw_train.map(to_rl, remove_columns=raw_train.column_names)
    return sft_train, sft_eval, rl_train
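
# The returned datasets share the same prompts. An SFT row looks like
# (values illustrative):
#   {"prompt": "You are a Hearthstone Battlegrounds AI. ...", "completion": '{"actions":[...]}'}
# while an RL row keeps the raw candidates so the reward function below can
# score sampled completions against them:
#   {"prompt": "...", "candidates": [{"role": "expert", "actions": [...], "reward": 1.0}, ...]}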

# ================== Reward function for GRPO (RLAIF style) ==================
def _parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a sequence of atomic action dicts.

    Expected formats:
      - {"actions": [ {...}, {...}, ... ]}
      - {"action": [ {...}, {...}, ... ]}   # tolerated fallback
    """
    text = text.strip()
    # Try to locate a JSON object within the text (in case of extra chatter
    # before/after the JSON), similar to the eval-time parser.
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]
    try:
        obj = json.loads(json_str)
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None
    seq = None
    # Preferred key from the instruction
    if "actions" in obj:
        if isinstance(obj["actions"], list):
            seq = obj["actions"]
        elif isinstance(obj["actions"], dict):
            # Tolerate a single dict instead of a list
            seq = [obj["actions"]]
    # Fallback key for older/variant outputs
    elif "action" in obj:
        if isinstance(obj["action"], list):
            seq = obj["action"]
        elif isinstance(obj["action"], dict):
            seq = [obj["action"]]
    if seq is None:
        return None
    # Ensure we have a list of dicts.
    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions
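
# Parsing examples (illustrative):
#   'Sure! {"actions":[{"type":"ROLL"}]}'  -> [{"type": "ROLL"}]
#   '{"action":{"type":"END_TURN"}}'       -> [{"type": "END_TURN"}]
#   'no json here'                         -> None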

def _action_sequences_equal(
    a: List[Dict[str, Any]], b: List[Dict[str, Any]]
) -> bool:
    """Strict equality for sequences of atomic actions.

    Both length and each per-step dict must match exactly. This relies on a
    canonical action representation in the data and model outputs.
    """
    if len(a) != len(b):
        return False
    for s1, s2 in zip(a, b):
        if s1 != s2:
            return False
    return True

def battleground_rlaif_reward(
    completions: List[str],
    candidates: List[List[Dict[str, Any]]],
    **kwargs,
) -> List[float]:
    """RLAIF-style reward function for GRPOTrainer.

    For each completion (one JSON text):
      1. Parse into a sequence of atomic actions.
      2. Compare with the example's candidates[i].actions.
      3. If it exactly matches a candidate.actions sequence, return that
         candidate's reward.
      4. Otherwise reward = 0.0.
    """
    rewards: List[float] = []
    for comp_text, cand_list in zip(completions, candidates):
        seq = _parse_actions_from_completion(comp_text)
        if seq is None:
            rewards.append(0.0)
            continue
        best_reward = 0.0
        for cand in cand_list:
            cand_actions = cand.get("actions")
            if not isinstance(cand_actions, list):
                continue
            if _action_sequences_equal(seq, cand_actions):
                r = float(cand.get("reward", 0.0))
                if r > best_reward:
                    best_reward = r
        rewards.append(best_reward)
    return rewards
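
# Scoring example (illustrative): a completion that exactly reproduces the
# "expert" candidate's action sequence earns that candidate's reward (e.g. 1.0);
# a parseable but unmatched sequence earns 0.0, as does unparseable output.
# Note that a completion matching the "bad" candidate (reward -0.5) still scores
# 0.0 here, because best_reward starts at 0.0 and only ever increases.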

# ================== SFT phase ==================
def run_sft(train_ds, eval_ds, tokenizer, cfg: PipelineConfig):
    """Run a supervised fine-tuning pass with LoRA adapters (prompt→action JSON)."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=target_modules,
        task_type="CAUSAL_LM",
    )
    sft_config = SFTConfig(
        output_dir=os.path.join(cfg.output_dir, "sft"),
        per_device_train_batch_size=cfg.per_device_batch_size,
        per_device_eval_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        learning_rate=cfg.sft_learning_rate,
        num_train_epochs=cfg.sft_epochs,
        logging_steps=10,
        save_steps=200,
        eval_steps=200,
        eval_strategy="steps",
        save_total_limit=2,
        max_length=cfg.max_seq_length,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        report_to=["none"],
    )
    trainer = SFTTrainer(
        model=cfg.model_name_or_path,  # model id / path; SFTTrainer loads it
        args=sft_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        peft_config=peft_config,
    )
    trainer.train()
    save_path = os.path.join(cfg.output_dir, "sft_model")
    trainer.save_model(save_path)
    return trainer.model  # PEFT-wrapped model instance
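
# The saved LoRA adapter can be reloaded for inference later; a minimal sketch,
# assuming the same base checkpoint is used (paths illustrative):
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path)
#   model = PeftModel.from_pretrained(base, os.path.join(cfg.output_dir, "sft_model"))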

# ================== GRPO phase ==================
def run_grpo(rl_dataset, base_model, tokenizer, cfg: PipelineConfig):
    """Run a GRPO RLAIF loop on top of the (optionally) SFT-initialized model."""
    target_modules = cfg.target_modules or DEFAULT_TARGET_MODULES
    if hasattr(base_model, "peft_config"):
        peft_config = None
    else:
        peft_config = LoraConfig(
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            target_modules=target_modules,
            task_type="CAUSAL_LM",
        )
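
    # GRPO samples `num_generations` completions per prompt, so a single
    # generation pass over `per_device_batch_size` prompts yields this many
    # sequences (4 prompts x 3 generations = 12 with the defaults).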
    generation_batch_size = cfg.per_device_batch_size * cfg.num_generations
    grpo_config = GRPOConfig(
        output_dir=os.path.join(cfg.output_dir, "grpo"),
        num_train_epochs=cfg.grpo_epochs,
        per_device_train_batch_size=cfg.per_device_batch_size,
        gradient_accumulation_steps=cfg.grad_accum_steps,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=cfg.grpo_epochs,
        bf16=cfg.bf16,
        fp16=not cfg.bf16,
        learning_rate=cfg.grpo_learning_rate,
        max_prompt_length=cfg.max_seq_length,
        max_completion_length=cfg.max_completion_length,
        num_generations=cfg.num_generations,
        generation_batch_size=generation_batch_size,
        report_to=["none"],
    )
    # GRPOTrainer accepts peft_config=None (the model is then used as-is, e.g.
    # when it is already PEFT-wrapped from SFT), so one constructor call covers
    # both cases.
    trainer = GRPOTrainer(
        model=base_model,
        args=grpo_config,
        processing_class=tokenizer,
        reward_funcs=battleground_rlaif_reward,
        train_dataset=rl_dataset,
        peft_config=peft_config,
    )
    trainer.train()
    trainer.save_model(os.path.join(cfg.output_dir, "grpo_model"))

# ================== Main ==================
def main():
    cfg = parse_args()
    os.makedirs(cfg.output_dir, exist_ok=True)
    print(f"Using model: {cfg.model_name_or_path}")
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_name_or_path,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # For GRPO, we want left padding
    tokenizer.padding_side = "left"
    print(f"Loading Battlegrounds game_history dataset from: {cfg.data_file}")
    sft_train, sft_eval, rl_train = load_gamehistory_rlaif(
        cfg.data_file,
        train_on_all_data=cfg.train_on_all_data,
        input_mode=cfg.input_mode,
    )

    # ----- SFT -----
    if cfg.skip_sft:
        print("Skipping SFT phase; loading base model directly.")
        dtype = (
            torch.bfloat16
            if cfg.bf16 and torch.cuda.is_available()
            else (torch.float16 if torch.cuda.is_available() else torch.float32)
        )
        model_kwargs: Dict[str, Any] = {
            "torch_dtype": dtype,
            "trust_remote_code": True,
        }
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        base_model = AutoModelForCausalLM.from_pretrained(
            cfg.model_name_or_path, **model_kwargs
        )
    else:
        print("Running SFT phase...")
        base_model = run_sft(sft_train, sft_eval, tokenizer, cfg)

    # ----- GRPO -----
    if cfg.skip_grpo:
        print("Skipping GRPO phase; only SFT outputs (if any) were produced.")
    else:
        print("Running GRPO (RLAIF) phase...")
        run_grpo(rl_train, base_model, tokenizer, cfg)
    print("All done. Check outputs under:", cfg.output_dir)


if __name__ == "__main__":
    main()