#!/usr/bin/env python
# eval_battleground_rlaif.py
#
# Evaluation script for Battlegrounds RLAIF models: No FT, SFT, and SFT+GRPO.
# Measures action prediction accuracy against expert/labeled actions.
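#
# Example invocation (illustrative; adjust paths and sample counts to your setup):
#   python eval_battleground_rlaif.py \
#       --data-file RL/datasets/battleground_rlaif_multicandidate.jsonl \
#       --output-dir ./battleground_rlaif_qwen \
#       --eval-samples 200 --batch-size 8 --save-results eval_results.json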
import argparse
import json
import os
import sys
from typing import Optional, Dict, Any, List
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
sys.path.append(_SCRIPT_DIR)
from battleground_nl_utils import (
dataset_state_to_game_state,
game_state_to_natural_language,
)
# ================== Constants ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"
DEFAULT_DATA_FILE = "RL/datasets/battleground_rlaif_multicandidate.jsonl"
def _resolve_default_model_id() -> str:
env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
if env_override:
return env_override
if os.path.isdir(LOCAL_INSTRUCT_PATH):
return LOCAL_INSTRUCT_PATH
return "Qwen/Qwen3-4B-Instruct"
DEFAULT_MODEL_ID = _resolve_default_model_id()
# ================== Data loading ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the game state JSON:
"""
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose exactly one best action and respond with a single JSON object in this exact format:
{"action":{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "action".
3. Use 0-based integers for indices or null when not used.
4. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD","HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
5. "card_name" must exactly match a card name from the game state when required, otherwise null.
Now here is the description of the game state:
"""
def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
"""Build prompt from game state (same format as training)."""
if input_mode == "nl":
game_state = dataset_state_to_game_state(example)
nl_state = game_state_to_natural_language(game_state)
prefix = INSTRUCTION_PREFIX_NL
state_text = nl_state
else:
obj = {
"task": "battlegrounds_policy_v1",
"phase": example["phase"],
"turn": example["turn"],
"state": example["state"],
}
state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
prefix = INSTRUCTION_PREFIX
return prefix + "\n" + state_text
def load_eval_dataset(
data_file: str,
test_size: float = 0.1,
seed: int = 42,
limit: Optional[int] = None,
input_mode: str = "json",
):
"""
Load evaluation dataset from JSONL file.
Uses the same train/test split as training to get the held-out test set.
"""
raw = load_dataset("json", data_files={"train": data_file})["train"]
# Same split as training
split = raw.train_test_split(test_size=test_size, seed=seed)
test_ds = split["test"]
def format_example(example):
prompt = _build_prompt(example, input_mode=input_mode)
candidates = example["candidates"]
# Find expert action
expert = None
for c in candidates:
if c.get("role") == "expert":
expert = c
break
if expert is None:
expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
return {
"prompt": prompt,
"expert_action": expert["action"],
"candidates": candidates,
"game_id": example.get("game_id", ""),
"step_id": example.get("step_id", 0),
"turn": example["turn"],
"phase": example["phase"],
}
test_ds = test_ds.map(format_example, remove_columns=raw.column_names)
if limit is not None:
test_ds = test_ds.select(range(min(limit, len(test_ds))))
return test_ds
# ================== Action parsing & comparison ==================
def parse_action_from_completion(text: str) -> Optional[Dict[str, Any]]:
"""
Parse model completion to extract action dict.
Expected format from training: {"action": {...}}
"""
text = text.strip()
    # Try to find a JSON object in the text;
    # sometimes the model emits extra text before/after the JSON.
start_idx = text.find("{")
if start_idx == -1:
return None
# Find matching closing brace
brace_count = 0
end_idx = -1
for i, c in enumerate(text[start_idx:], start=start_idx):
if c == "{":
brace_count += 1
elif c == "}":
brace_count -= 1
if brace_count == 0:
end_idx = i + 1
break
if end_idx == -1:
# No matching brace, try to find any closing brace
end_idx = text.rfind("}") + 1
if end_idx == 0:
return None
json_str = text[start_idx:end_idx]
try:
obj = json.loads(json_str)
    except Exception:
        # Try to repair common issues: the model sometimes emits truncated JSON,
        # so append one or two closing braces before giving up.
        try:
            obj = json.loads(json_str + "}")
        except Exception:
            try:
                obj = json.loads(json_str + "}}")
            except Exception:
                return None
if isinstance(obj, dict):
# Format from training: {"action": {...}}
if "action" in obj and isinstance(obj["action"], dict):
return obj["action"]
# If it's directly an action dict (has "type" field)
if "type" in obj:
return obj
return None
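# Illustrative parse of a noisy completion (hypothetical model output):
#   parse_action_from_completion(
#       'Sure! {"action":{"type":"ROLL","tavern_index":null,"hand_index":null,"board_index":null,"card_name":null}}'
#   )
#   -> {"type": "ROLL", "tavern_index": None, "hand_index": None, "board_index": None, "card_name": None}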
def actions_match(pred: Dict[str, Any], gold: Dict[str, Any], strict: bool = True) -> bool:
"""
Compare predicted action with gold action.
Args:
pred: Predicted action dict
gold: Gold/expert action dict
strict: If True, all fields must match exactly. If False, only compare key fields.
"""
if strict:
return pred == gold
# Relaxed matching: compare only essential fields
key_fields = ["type", "tavern_index", "hand_index", "board_index", "card_name"]
for field in key_fields:
pred_val = pred.get(field)
gold_val = gold.get(field)
# Treat None and missing as equivalent
if pred_val is None and gold_val is None:
continue
if pred_val != gold_val:
return False
return True
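# Strict vs. relaxed comparison in practice (illustrative dicts):
#   actions_match({"type": "END_TURN"}, {"type": "END_TURN", "card_name": None}, strict=False) -> True
#   actions_match({"type": "END_TURN"}, {"type": "END_TURN", "card_name": None}, strict=True)  -> False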
def get_action_reward(pred: Dict[str, Any], candidates: List[Dict[str, Any]]) -> float:
"""Get reward for predicted action by matching against candidates."""
for cand in candidates:
cand_action = cand.get("action", {})
if actions_match(pred, cand_action, strict=False):
return float(cand.get("reward", 0.0))
return 0.0
# ================== Model loading ==================
def load_base_model(model_path: str, bf16: bool = True):
"""Load base model without any adapters."""
dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
model_kwargs = {
"torch_dtype": dtype,
"trust_remote_code": True,
}
if torch.cuda.is_available():
model_kwargs["device_map"] = "auto"
model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(
model_path, use_fast=True, trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
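    # Left padding so that, in batched generation, every prompt ends at the same
    # position and the newly generated tokens can be sliced off uniformly.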
tokenizer.padding_side = "left"
return model, tokenizer
def load_peft_model(base_model_path: str, adapter_path: str, bf16: bool = True):
"""Load base model with PEFT adapter."""
dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
model_kwargs = {
"torch_dtype": dtype,
"trust_remote_code": True,
}
if torch.cuda.is_available():
model_kwargs["device_map"] = "auto"
base_model = AutoModelForCausalLM.from_pretrained(base_model_path, **model_kwargs)
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload() # Merge for faster inference
tokenizer = AutoTokenizer.from_pretrained(
base_model_path, use_fast=True, trust_remote_code=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
return model, tokenizer
# ================== Evaluation ==================
@torch.no_grad()
def evaluate_model(
model,
tokenizer,
test_ds,
max_new_tokens: int = 128,
batch_size: int = 8,
verbose: bool = False,
):
"""
Evaluate model on Battlegrounds test set.
Returns:
- exact_match_acc: Accuracy of exact action match
- relaxed_match_acc: Accuracy with relaxed matching (key fields only)
- avg_reward: Average reward of predicted actions
- results: List of per-sample results
"""
model.eval()
device = next(model.parameters()).device
exact_correct = 0
relaxed_correct = 0
total_reward = 0.0
total = 0
parse_failures = 0
results = []
for i in tqdm(range(0, len(test_ds), batch_size), desc="Evaluating"):
batch = test_ds[i : i + batch_size]
prompts = batch["prompt"] if isinstance(batch["prompt"], list) else [batch["prompt"]]
expert_actions = batch["expert_action"] if isinstance(batch["expert_action"], list) else [batch["expert_action"]]
candidates_list = batch["candidates"] if isinstance(batch["candidates"], list) else [batch["candidates"]]
inputs = tokenizer(
prompts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=1024,
).to(device)
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Decode and evaluate each sample
for j, (output, prompt, expert_action, candidates) in enumerate(
zip(outputs, prompts, expert_actions, candidates_list)
):
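            # With left padding, the (padded) prompt occupies the first `input_len`
            # positions of each row, so slicing there keeps only the generated tokens.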
input_len = inputs["input_ids"][j].shape[0]
generated = tokenizer.decode(output[input_len:], skip_special_tokens=True)
pred_action = parse_action_from_completion(generated)
is_exact_match = False
is_relaxed_match = False
reward = 0.0
if pred_action is None:
parse_failures += 1
else:
is_exact_match = actions_match(pred_action, expert_action, strict=True)
is_relaxed_match = actions_match(pred_action, expert_action, strict=False)
reward = get_action_reward(pred_action, candidates)
if is_exact_match:
exact_correct += 1
if is_relaxed_match:
relaxed_correct += 1
total_reward += reward
total += 1
result = {
"game_id": batch["game_id"][j] if isinstance(batch["game_id"], list) else batch["game_id"],
"step_id": batch["step_id"][j] if isinstance(batch["step_id"], list) else batch["step_id"],
"turn": batch["turn"][j] if isinstance(batch["turn"], list) else batch["turn"],
"phase": batch["phase"][j] if isinstance(batch["phase"], list) else batch["phase"],
"expert_action": expert_action,
"predicted_action": pred_action,
"generated_text": generated.strip()[:200], # Truncate for readability
"exact_match": is_exact_match,
"relaxed_match": is_relaxed_match,
"reward": reward,
}
results.append(result)
if verbose and not is_relaxed_match:
print(f"\n[WRONG] Game: {result['game_id']}, Step: {result['step_id']}")
print(f" Expert: {expert_action}")
print(f" Pred: {pred_action}")
print(f" Gen: {generated[:150]}")
exact_match_acc = exact_correct / total if total > 0 else 0.0
relaxed_match_acc = relaxed_correct / total if total > 0 else 0.0
avg_reward = total_reward / total if total > 0 else 0.0
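    # "results" keeps one entry per evaluated sample; main() stores only the first
    # 10 entries per model when --save-results is given, to keep the output small.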
return {
"exact_match_acc": exact_match_acc,
"relaxed_match_acc": relaxed_match_acc,
"avg_reward": avg_reward,
"parse_failure_rate": parse_failures / total if total > 0 else 0.0,
"total_samples": total,
"results": results,
}
# ================== Main ==================
def main():
parser = argparse.ArgumentParser(description="Evaluate Battlegrounds RLAIF models: No FT, SFT, SFT+GRPO")
parser.add_argument(
"--base-model",
default=DEFAULT_MODEL_ID,
help="Base model path (Qwen instruct checkpoint).",
)
parser.add_argument(
"--output-dir",
default="./battleground_rlaif_qwen",
help="Directory containing SFT and GRPO checkpoints.",
)
parser.add_argument(
"--data-file",
default=DEFAULT_DATA_FILE,
help="Path to JSONL file with multi-candidate Battlegrounds data.",
)
parser.add_argument(
"--sft-adapter",
default=None,
help="Path to SFT adapter (default: <output-dir>/sft_model).",
)
parser.add_argument(
"--grpo-adapter",
default=None,
help="Path to GRPO adapter (default: <output-dir>/grpo_model).",
)
parser.add_argument(
"--eval-samples",
type=int,
default=50,
help="Number of test samples to evaluate (default: 50 for quick testing, use -1 for full set).",
)
parser.add_argument("--batch-size", type=int, default=8, help="Batch size for inference (default: 8 for A800).")
parser.add_argument("--max-new-tokens", type=int, default=128, help="Max tokens to generate.")
parser.add_argument("--disable-bf16", action="store_true", help="Use fp16 instead of bf16.")
parser.add_argument("--verbose", action="store_true", help="Print wrong predictions.")
parser.add_argument(
"--eval-no-ft", action="store_true", help="Evaluate base model (no fine-tuning)."
)
parser.add_argument("--eval-sft", action="store_true", help="Evaluate SFT model.")
parser.add_argument("--eval-grpo", action="store_true", help="Evaluate SFT+GRPO model.")
parser.add_argument(
"--save-results",
default=None,
help="Path to save detailed results as JSON.",
)
parser.add_argument(
"--input-mode",
choices=["json", "nl"],
default="json",
help="Input format for game state: 'json' uses raw JSON, 'nl' uses natural language description.",
)
args = parser.parse_args()
bf16 = not args.disable_bf16
# Default: evaluate all if none specified
eval_all = not (args.eval_no_ft or args.eval_sft or args.eval_grpo)
if eval_all:
args.eval_no_ft = True
args.eval_sft = True
args.eval_grpo = True
# Resolve adapter paths
sft_adapter = args.sft_adapter or os.path.join(args.output_dir, "sft_model")
grpo_adapter = args.grpo_adapter or os.path.join(args.output_dir, "grpo_model")
# Handle eval_samples=-1 as full set
eval_samples = None if args.eval_samples == -1 else args.eval_samples
# Load test data
print("Loading Battlegrounds test set...")
if not os.path.exists(args.data_file):
print(f"ERROR: Data file not found: {args.data_file}")
return
test_ds = load_eval_dataset(
args.data_file,
limit=eval_samples,
input_mode=args.input_mode,
)
print(f"Test samples: {len(test_ds)}")
all_results = {}
# ===== Evaluate No FT (base model) =====
if args.eval_no_ft:
print("\n" + "=" * 60)
print("Evaluating: No Fine-Tuning (Base Model)")
print("=" * 60)
model, tokenizer = load_base_model(args.base_model, bf16=bf16)
metrics = evaluate_model(
model, tokenizer, test_ds,
max_new_tokens=args.max_new_tokens,
batch_size=args.batch_size,
verbose=args.verbose,
)
print(f"[No FT] Exact Match: {metrics['exact_match_acc']:.4f}")
print(f"[No FT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
print(f"[No FT] Avg Reward: {metrics['avg_reward']:.4f}")
print(f"[No FT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
all_results["no_ft"] = metrics
del model
torch.cuda.empty_cache()
# ===== Evaluate SFT =====
if args.eval_sft:
print("\n" + "=" * 60)
print("Evaluating: SFT Fine-Tuned Model")
print("=" * 60)
if not os.path.exists(sft_adapter):
print(f"[SKIP] SFT adapter not found at: {sft_adapter}")
else:
model, tokenizer = load_peft_model(args.base_model, sft_adapter, bf16=bf16)
metrics = evaluate_model(
model, tokenizer, test_ds,
max_new_tokens=args.max_new_tokens,
batch_size=args.batch_size,
verbose=args.verbose,
)
print(f"[SFT] Exact Match: {metrics['exact_match_acc']:.4f}")
print(f"[SFT] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
print(f"[SFT] Avg Reward: {metrics['avg_reward']:.4f}")
print(f"[SFT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
all_results["sft"] = metrics
del model
torch.cuda.empty_cache()
# ===== Evaluate SFT + GRPO =====
if args.eval_grpo:
print("\n" + "=" * 60)
print("Evaluating: SFT + GRPO Fine-Tuned Model")
print("=" * 60)
grpo_epoch_dir = os.path.join(args.output_dir, "grpo")
adapters_to_eval: List[tuple[str, str]] = []
# If user did not override --grpo-adapter and epoch checkpoints exist,
# evaluate all checkpoint-* directories under output_dir/grpo plus final grpo_model.
default_grpo_adapter = os.path.join(args.output_dir, "grpo_model")
using_default_adapter = (args.grpo_adapter is None) or (
grpo_adapter == default_grpo_adapter
)
if using_default_adapter and os.path.isdir(grpo_epoch_dir):
checkpoint_names = [
d
for d in os.listdir(grpo_epoch_dir)
if d.startswith("checkpoint")
and os.path.isdir(os.path.join(grpo_epoch_dir, d))
]
            # Sort numerically by training step (checkpoint-100 before checkpoint-1000);
            # a plain lexicographic sort would misorder them.
            checkpoint_names.sort(
                key=lambda n: int(n.rsplit("-", 1)[-1]) if n.rsplit("-", 1)[-1].isdigit() else -1
            )
for name in checkpoint_names:
path = os.path.join(grpo_epoch_dir, name)
label = f"sft_grpo_{name}"
adapters_to_eval.append((label, path))
if os.path.exists(grpo_adapter):
adapters_to_eval.append(("sft_grpo_final", grpo_adapter))
else:
if os.path.exists(grpo_adapter):
adapters_to_eval.append(("sft_grpo", grpo_adapter))
if not adapters_to_eval:
print(f"[SKIP] No GRPO adapters found. Expected at: {grpo_adapter} or under {grpo_epoch_dir}")
else:
for label, adapter_path in adapters_to_eval:
print("\n" + "-" * 60)
print(f"Evaluating GRPO adapter: {label}")
print(f"Path: {adapter_path}")
model, tokenizer = load_peft_model(
args.base_model, adapter_path, bf16=bf16
)
metrics = evaluate_model(
model,
tokenizer,
test_ds,
max_new_tokens=args.max_new_tokens,
batch_size=args.batch_size,
verbose=args.verbose,
)
print(f"[{label}] Exact Match: {metrics['exact_match_acc']:.4f}")
print(f"[{label}] Relaxed Match: {metrics['relaxed_match_acc']:.4f}")
print(f"[{label}] Avg Reward: {metrics['avg_reward']:.4f}")
print(f"[{label}] Parse Failures: {metrics['parse_failure_rate']:.2%}")
all_results[label] = metrics
del model
torch.cuda.empty_cache()
# ===== Summary =====
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"{'Model':<12} {'Exact':<10} {'Relaxed':<10} {'Reward':<10} {'Parse Fail':<10}")
print("-" * 52)
for name, data in all_results.items():
if "results" in data: # Has actual results
print(f"{name:<12} {data['exact_match_acc']:<10.4f} {data['relaxed_match_acc']:<10.4f} {data['avg_reward']:<10.4f} {data['parse_failure_rate']:<10.2%}")
# Save results
if args.save_results:
save_data = {
name: {
"exact_match_acc": data["exact_match_acc"],
"relaxed_match_acc": data["relaxed_match_acc"],
"avg_reward": data["avg_reward"],
"parse_failure_rate": data["parse_failure_rate"],
"total_samples": data["total_samples"],
"sample_predictions": data["results"][:10], # First 10 for inspection
}
for name, data in all_results.items()
if "results" in data
}
with open(args.save_results, "w") as f:
json.dump(save_data, f, indent=2, ensure_ascii=False)
print(f"\nResults saved to: {args.save_results}")
if __name__ == "__main__":
main()