#!/usr/bin/env python
# eval_battleground_rlaif.py
#
# Evaluation script for Battlegrounds RLAIF models: No FT, SFT, and SFT+GRPO.
# Measures action prediction accuracy against expert/labeled actions.

import argparse
import json
import os
import sys
from typing import Optional, Dict, Any, List

from tqdm import tqdm
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Make sibling modules importable regardless of the caller's working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    dataset_state_to_game_state,
    game_state_to_natural_language,
)

# ================== Constants ==================

LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate.jsonl"


def _resolve_default_model_id() -> str:
    """Resolve the default base model id.

    Priority: QWEN_INSTRUCT_MODEL env var > local checkpoint directory >
    the Hugging Face hub id.
    """
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    return "Qwen/Qwen3-4B-Instruct"


DEFAULT_MODEL_ID = _resolve_default_model_id()

# ================== Data loading ==================

INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI. Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"","tavern_index":,"hand_index":,"board_index":,"card_name":}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON: """

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI. Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"","tavern_index":,"hand_index":,"board_index":,"card_name":}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state: """


def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build the evaluation prompt from a dataset example.

    Must mirror the prompt format used during training.

    Args:
        example: A dataset row with "phase", "turn", and "state" keys.
        input_mode: "nl" renders the state as natural language via the
            battleground_nl_utils helpers; anything else serializes the
            raw state as compact JSON.

    Returns:
        Instruction prefix + newline + rendered game state.
    """
    if input_mode == "nl":
        game_state = dataset_state_to_game_state(example)
        nl_state = game_state_to_natural_language(game_state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": example["phase"],
            "turn": example["turn"],
            "state": example["state"],
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text


def load_eval_dataset(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    limit: Optional[int] = None,
    input_mode: str = "json",
):
    """Load the held-out evaluation split from a multi-candidate JSONL file.

    Uses the same train/test split parameters as training (same test_size
    and seed) so the returned rows are exactly the training-time test set.

    Args:
        data_file: Path to the JSONL data file.
        test_size: Fraction held out for test (must match training).
        seed: Split seed (must match training).
        limit: Optional cap on the number of evaluation samples.
        input_mode: Prompt rendering mode, forwarded to _build_prompt.

    Returns:
        A datasets.Dataset with prompt, expert_action, candidates, and
        bookkeeping columns (game_id, step_id, turn, phase).
    """
    raw = load_dataset("json", data_files={"train": data_file})["train"]

    # Same split as training so we evaluate on truly held-out rows.
    split = raw.train_test_split(test_size=test_size, seed=seed)
    test_ds = split["test"]

    def format_example(example):
        prompt = _build_prompt(example, input_mode=input_mode)
        candidates = example["candidates"]
        # Prefer the explicitly labeled expert candidate; fall back to the
        # highest-reward candidate when none is marked.
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
        return {
            "prompt": prompt,
            "expert_action": expert["action"],
            "candidates": candidates,
            "game_id": example.get("game_id", ""),
            "step_id": example.get("step_id", 0),
            "turn": example["turn"],
            "phase": example["phase"],
        }

    test_ds = test_ds.map(format_example, remove_columns=raw.column_names)
    if limit is not None:
        test_ds = test_ds.select(range(min(limit, len(test_ds))))
    return test_ds


# ================== Action parsing & comparison ==================

def parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
    """Parse a model completion and extract the action dict.

    Expected format from training: {"action": {...}}. Tolerates extra text
    around the JSON and a couple of common truncation failures.

    Returns:
        The inner action dict, or None if no action could be parsed.
    """
    text = text.strip()

    # Locate the first JSON object; models sometimes emit surrounding prose.
    start_idx = text.find("{")
    if start_idx == -1:
        return None

    # Scan forward for the matching closing brace of the first object.
    brace_count = 0
    end_idx = -1
    for i, c in enumerate(text[start_idx:], start=start_idx):
        if c == "{":
            brace_count += 1
        elif c == "}":
            brace_count -= 1
            if brace_count == 0:
                end_idx = i + 1
                break
    if end_idx == -1:
        # No matching brace; fall back to the last closing brace in the text.
        end_idx = text.rfind("}") + 1
        if end_idx == 0:
            return None

    json_str = text[start_idx:end_idx]
    try:
        obj = json.loads(json_str)
    except Exception:
        # Repair heuristic: completions truncated at max_new_tokens often
        # just miss one or two closing braces.
        try:
            obj = json.loads(json_str + "}")
        except Exception:
            try:
                obj = json.loads(json_str + "}}")
            except Exception:
                return None

    if isinstance(obj, dict):
        # Training format: {"action": {...}}
        if "action" in obj and isinstance(obj["action"], dict):
            return obj["action"]
        # The model may also emit the action dict directly.
        if "type" in obj:
            return obj
    return None


def actions_match(pred: Dict[str, Any], gold: Dict[str, Any], strict: bool = True) -> bool:
    """Compare a predicted action with a gold action.

    Args:
        pred: Predicted action dict.
        gold: Gold/expert action dict.
        strict: If True, require full dict equality. If False, compare only
            the key action fields, treating None and a missing key as equal.

    Returns:
        True when the actions match under the chosen mode.
    """
    if strict:
        return pred == gold

    # Relaxed matching: compare only essential fields.
    key_fields = ["type", "tavern_index", "hand_index", "board_index", "card_name"]
    for field in key_fields:
        pred_val = pred.get(field)
        gold_val = gold.get(field)
        # Treat None and missing as equivalent.
        if pred_val is None and gold_val is None:
            continue
        if pred_val != gold_val:
            return False
    return True


def get_action_reward(pred: Dict[str, Any], candidates: List[Dict[str, Any]]) -> float:
    """Return the labeled reward of the candidate matching the prediction.

    Falls back to 0.0 when the predicted action matches no candidate.
    """
    for cand in candidates:
        cand_action = cand.get("action", {})
        if actions_match(pred, cand_action, strict=False):
            return float(cand.get("reward", 0.0))
    return 0.0


# ================== Model loading ==================

def load_base_model(model_path: str, bf16: bool = True):
    """Load the base causal LM and tokenizer without any adapters.

    Args:
        model_path: Local path or hub id of the checkpoint.
        bf16: Prefer bfloat16 on CUDA; otherwise float16 is used.
            NOTE(review): on a CPU-only host this still selects float16,
            which many ops do not support — confirm CUDA is available.

    Returns:
        (model, tokenizer) with left padding configured for batched generate.
    """
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps the generated continuation contiguous per sample.
    tokenizer.padding_side = "left"
    return model, tokenizer


def load_peft_model(base_model_path: str, adapter_path: str, bf16: bool = True):
    """Load the base model with a PEFT adapter merged in.

    The adapter is merged into the base weights (merge_and_unload) so
    inference runs at plain-model speed.

    Args:
        base_model_path: Local path or hub id of the base checkpoint.
        adapter_path: Path to the saved PEFT adapter.
        bf16: Prefer bfloat16 on CUDA; otherwise float16 is used.

    Returns:
        (model, tokenizer) with left padding configured for batched generate.
    """
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, **model_kwargs)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()  # Merge for faster inference
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        use_fast=True,
        trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer


# ================== Evaluation ==================

@torch.no_grad()
def evaluate_model(
    model,
    tokenizer,
    test_ds,
    max_new_tokens: int = 128,
    batch_size: int = 8,
    verbose: bool = False,
):
    """Evaluate a model on the Battlegrounds test set with greedy decoding.

    Args:
        model: Causal LM to evaluate (set to eval mode here).
        tokenizer: Matching tokenizer, left-padded for batched generate.
        test_ds: Dataset produced by load_eval_dataset.
        max_new_tokens: Generation budget per sample.
        batch_size: Inference batch size.
        verbose: Print every relaxed-match failure for inspection.

    Returns:
        Dict with exact_match_acc, relaxed_match_acc, avg_reward,
        parse_failure_rate, total_samples, and per-sample results.
    """
    model.eval()
    device = next(model.parameters()).device

    exact_correct = 0
    relaxed_correct = 0
    total_reward = 0.0
    total = 0
    parse_failures = 0
    results = []

    for i in tqdm(range(0, len(test_ds), batch_size), desc="Evaluating"):
        batch = test_ds[i : i + batch_size]
        # Dataset slicing yields dict-of-lists; normalize single rows to lists.
        prompts = batch["prompt"] if isinstance(batch["prompt"], list) else [batch["prompt"]]
        expert_actions = batch["expert_action"] if isinstance(batch["expert_action"], list) else [batch["expert_action"]]
        candidates_list = batch["candidates"] if isinstance(batch["candidates"], list) else [batch["candidates"]]

        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Decode and evaluate each sample.
        for j, (output, prompt, expert_action, candidates) in enumerate(
            zip(outputs, prompts, expert_actions, candidates_list)
        ):
            # generate() returns prompt + continuation; strip the (padded)
            # prompt prefix to keep only the generated text.
            input_len = inputs["input_ids"][j].shape[0]
            generated = tokenizer.decode(output[input_len:], skip_special_tokens=True)

            pred_action = parse_action_from_completion(generated)

            is_exact_match = False
            is_relaxed_match = False
            reward = 0.0

            if pred_action is None:
                parse_failures += 1
            else:
                is_exact_match = actions_match(pred_action, expert_action, strict=True)
                is_relaxed_match = actions_match(pred_action, expert_action, strict=False)
                reward = get_action_reward(pred_action, candidates)

            if is_exact_match:
                exact_correct += 1
            if is_relaxed_match:
                relaxed_correct += 1
            total_reward += reward
            total += 1

            result = {
                "game_id": batch["game_id"][j] if isinstance(batch["game_id"], list) else batch["game_id"],
                "step_id": batch["step_id"][j] if isinstance(batch["step_id"], list) else batch["step_id"],
                "turn": batch["turn"][j] if isinstance(batch["turn"], list) else batch["turn"],
                "phase": batch["phase"][j] if isinstance(batch["phase"], list) else batch["phase"],
                "expert_action": expert_action,
                "predicted_action": pred_action,
                "generated_text": generated.strip()[:200],  # Truncate for readability
                "exact_match": is_exact_match,
                "relaxed_match": is_relaxed_match,
                "reward": reward,
            }
            results.append(result)

            if verbose and not is_relaxed_match:
                print(f"\n[WRONG] Game: {result['game_id']}, Step: {result['step_id']}")
                print(f"  Expert: {expert_action}")
                print(f"  Pred:   {pred_action}")
                print(f"  Gen:    {generated[:150]}")

    exact_match_acc = exact_correct / total if total > 0 else 0.0
    relaxed_match_acc = relaxed_correct / total if total > 0 else 0.0
    avg_reward = total_reward / total if total > 0 else 0.0

    return {
        "exact_match_acc": exact_match_acc,
        "relaxed_match_acc": relaxed_match_acc,
        "avg_reward": avg_reward,
        "parse_failure_rate": parse_failures / total if total > 0 else 0.0,
        "total_samples": total,
        "results": results,
    }


# ================== Main ==================

def _checkpoint_sort_key(name: str) -> int:
    """Sort key for 'checkpoint-<step>' directory names.

    Plain lexicographic sort orders checkpoint-1000 before checkpoint-200;
    sorting on the numeric suffix evaluates checkpoints in training order.
    Names without a numeric suffix sort first.
    """
    suffix = name.rsplit("-", 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


def main():
    parser = argparse.ArgumentParser(description="Evaluate Battlegrounds RLAIF models: No FT, SFT, SFT+GRPO")
    parser.add_argument(
        "--base-model",
        default=DEFAULT_MODEL_ID,
        help="Base model path (Qwen instruct checkpoint).",
    )
    parser.add_argument(
        "--output-dir",
        default="./battleground_rlaif_qwen",
        help="Directory containing SFT and GRPO checkpoints.",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help="Path to JSONL file with multi-candidate Battlegrounds data.",
    )
    parser.add_argument(
        "--sft-adapter",
        default=None,
        help="Path to SFT adapter (default: /sft_model).",
    )
    parser.add_argument(
        "--grpo-adapter",
        default=None,
        help="Path to GRPO adapter (default: /grpo_model).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=50,
        help="Number of test samples to evaluate (default: 50 for quick testing, use -1 for full set).",
    )
    parser.add_argument("--batch-size", type=int, default=8, help="Batch size for inference (default: 8 for A800).")
    parser.add_argument("--max-new-tokens", type=int, default=128, help="Max tokens to generate.")
    parser.add_argument("--disable-bf16", action="store_true", help="Use fp16 instead of bf16.")
    parser.add_argument("--verbose", action="store_true", help="Print wrong predictions.")
    parser.add_argument(
        "--eval-no-ft", action="store_true", help="Evaluate base model (no fine-tuning)."
    )
    parser.add_argument("--eval-sft", action="store_true", help="Evaluate SFT model.")
    parser.add_argument("--eval-grpo", action="store_true", help="Evaluate SFT+GRPO model.")
    parser.add_argument(
        "--save-results",
        default=None,
        help="Path to save detailed results as JSON.",
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Input format for game state: 'json' uses raw JSON, 'nl' uses natural language description.",
    )
    args = parser.parse_args()

    bf16 = not args.disable_bf16

    # Default: evaluate all if none specified.
    eval_all = not (args.eval_no_ft or args.eval_sft or args.eval_grpo)
    if eval_all:
        args.eval_no_ft = True
        args.eval_sft = True
        args.eval_grpo = True

    # Resolve adapter paths.
    sft_adapter = args.sft_adapter or os.path.join(args.output_dir, "sft_model")
    grpo_adapter = args.grpo_adapter or os.path.join(args.output_dir, "grpo_model")

    # Handle eval_samples=-1 as full set.
    eval_samples = None if args.eval_samples == -1 else args.eval_samples

    # Load test data.
    print("Loading Battlegrounds test set...")
    if not os.path.exists(args.data_file):
        print(f"ERROR: Data file not found: {args.data_file}")
        return
    test_ds = load_eval_dataset(
        args.data_file,
        limit=eval_samples,
        input_mode=args.input_mode,
    )
    print(f"Test samples: {len(test_ds)}")

    all_results = {}

    # ===== Evaluate No FT (base model) =====
    if args.eval_no_ft:
        print("\n" + "=" * 60)
        print("Evaluating: No Fine-Tuning (Base Model)")
        print("=" * 60)
        model, tokenizer = load_base_model(args.base_model, bf16=bf16)
        metrics = evaluate_model(
            model,
            tokenizer,
            test_ds,
            max_new_tokens=args.max_new_tokens,
            batch_size=args.batch_size,
            verbose=args.verbose,
        )
        print(f"[No FT] Exact Match:   {metrics['exact_match_acc']:.4f}")
        print(f"[No FT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
        print(f"[No FT] Avg Reward:    {metrics['avg_reward']:.4f}")
        print(f"[No FT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
        all_results["no_ft"] = metrics
        del model
        torch.cuda.empty_cache()

    # ===== Evaluate SFT =====
    if args.eval_sft:
        print("\n" + "=" * 60)
        print("Evaluating: SFT Fine-Tuned Model")
        print("=" * 60)
        if not os.path.exists(sft_adapter):
            print(f"[SKIP] SFT adapter not found at: {sft_adapter}")
        else:
            model, tokenizer = load_peft_model(args.base_model, sft_adapter, bf16=bf16)
            metrics = evaluate_model(
                model,
                tokenizer,
                test_ds,
                max_new_tokens=args.max_new_tokens,
                batch_size=args.batch_size,
                verbose=args.verbose,
            )
            print(f"[SFT] Exact Match:   {metrics['exact_match_acc']:.4f}")
            print(f"[SFT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
            print(f"[SFT] Avg Reward:    {metrics['avg_reward']:.4f}")
            print(f"[SFT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
            all_results["sft"] = metrics
            del model
            torch.cuda.empty_cache()

    # ===== Evaluate SFT + GRPO =====
    if args.eval_grpo:
        print("\n" + "=" * 60)
        print("Evaluating: SFT + GRPO Fine-Tuned Model")
        print("=" * 60)

        grpo_epoch_dir = os.path.join(args.output_dir, "grpo")
        adapters_to_eval: List[tuple[str, str]] = []

        # If user did not override --grpo-adapter and epoch checkpoints exist,
        # evaluate all checkpoint-* directories under output_dir/grpo plus
        # the final grpo_model.
        default_grpo_adapter = os.path.join(args.output_dir, "grpo_model")
        using_default_adapter = (args.grpo_adapter is None) or (
            grpo_adapter == default_grpo_adapter
        )

        if using_default_adapter and os.path.isdir(grpo_epoch_dir):
            checkpoint_names = [
                d
                for d in os.listdir(grpo_epoch_dir)
                if d.startswith("checkpoint")
                and os.path.isdir(os.path.join(grpo_epoch_dir, d))
            ]
            # Numeric sort: lexicographic order would put checkpoint-1000
            # before checkpoint-200.
            checkpoint_names.sort(key=_checkpoint_sort_key)
            for name in checkpoint_names:
                path = os.path.join(grpo_epoch_dir, name)
                label = f"sft_grpo_{name}"
                adapters_to_eval.append((label, path))
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo_final", grpo_adapter))
        else:
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo", grpo_adapter))

        if not adapters_to_eval:
            print(f"[SKIP] No GRPO adapters found. Expected at: {grpo_adapter} or under {grpo_epoch_dir}")
        else:
            for label, adapter_path in adapters_to_eval:
                print("\n" + "-" * 60)
                print(f"Evaluating GRPO adapter: {label}")
                print(f"Path: {adapter_path}")
                model, tokenizer = load_peft_model(
                    args.base_model, adapter_path, bf16=bf16
                )
                metrics = evaluate_model(
                    model,
                    tokenizer,
                    test_ds,
                    max_new_tokens=args.max_new_tokens,
                    batch_size=args.batch_size,
                    verbose=args.verbose,
                )
                print(f"[{label}] Exact Match:   {metrics['exact_match_acc']:.4f}")
                print(f"[{label}] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
                print(f"[{label}] Avg Reward:    {metrics['avg_reward']:.4f}")
                print(f"[{label}] Parse Failures: {metrics['parse_failure_rate']:.2%}")
                all_results[label] = metrics
                del model
                torch.cuda.empty_cache()

    # ===== Summary =====
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"{'Model':<12} {'Exact':<10} {'Relaxed':<10} {'Reward':<10} {'Parse Fail':<10}")
    print("-" * 52)
    for name, data in all_results.items():
        if "results" in data:  # Has actual results
            print(f"{name:<12} {data['exact_match_acc']:<10.4f} {data['relaxed_match_acc']:<10.4f} {data['avg_reward']:<10.4f} {data['parse_failure_rate']:<10.2%}")

    # Save results.
    if args.save_results:
        save_data = {
            name: {
                "exact_match_acc": data["exact_match_acc"],
                "relaxed_match_acc": data["relaxed_match_acc"],
                "avg_reward": data["avg_reward"],
                "parse_failure_rate": data["parse_failure_rate"],
                "total_samples": data["total_samples"],
                "sample_predictions": data["results"][:10],  # First 10 for inspection
            }
            for name, data in all_results.items()
            if "results" in data
        }
        with open(args.save_results, "w") as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to: {args.save_results}")


if __name__ == "__main__":
    main()