#!/usr/bin/env python
# eval_battleground_rlaif_gamehistory.py
#
# Evaluation script for Battlegrounds RLAIF models trained on game_history data.
# Mirrors eval_battleground_rlaif.py but uses the game_history-specific
# prompt format and sequence-of-actions outputs used in
# train_battleground_rlaif_gamehistory.py.
import argparse
import json
import os
import sys
from typing import Optional, Dict, Any, List, Tuple

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Make sibling modules importable when this file is run as a script.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
if _SCRIPT_DIR not in sys.path:
    sys.path.append(_SCRIPT_DIR)

from battleground_nl_utils import (
    game_state_to_natural_language,
)

# ================== Model paths & defaults ==================
LOCAL_INSTRUCT_PATH = "models/qwen3-4b-instruct-2507/Qwen/Qwen3-4B-Instruct-2507"


def _resolve_default_model_id() -> str:
    env_override = os.environ.get("QWEN_INSTRUCT_MODEL")
    if env_override:
        return env_override
    if os.path.isdir(LOCAL_INSTRUCT_PATH):
        return LOCAL_INSTRUCT_PATH
    return "Qwen/Qwen3-4B-Instruct"


DEFAULT_MODEL_ID = _resolve_default_model_id()
DEFAULT_DATA_FILE = "RL/datasets/game_history_flat.json"
DEFAULT_OUTPUT_DIR = "./battleground_rlaif_qwen_gamehistory"
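
# Example: overriding the base model via the environment variable checked
# above (the path is illustrative):
#
#   QWEN_INSTRUCT_MODEL=/path/to/qwen3-4b python eval_battleground_rlaif_gamehistory.py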

# ================== Prompt construction (mirrors training) ==================
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""


def _build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build a prompt from a flattened game_history example.

    The example has:
    - phase: string (e.g., "PlayerTurn")
    - turn: int
    - state: nested dict with keys: game_state, player_hero, resources, board_state

    This mirrors _build_prompt in train_battleground_rlaif_gamehistory.py.
    """
    state = example.get("state", {})
    if input_mode == "nl":
        nl_state = game_state_to_natural_language(state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        gs = state.get("game_state", {}) or {}
        phase = example.get("phase", gs.get("phase", "PlayerTurn"))
        turn = example.get("turn", gs.get("turn_number", 0))
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": phase,
            "turn": turn,
            "state": state,
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX
    return prefix + "\n" + state_text
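
# Minimal sketch of the flattened row this function expects (field values are
# illustrative); in JSON mode the wrapped object is serialized compactly:
#
#   >>> example = {"phase": "PlayerTurn", "turn": 3,
#   ...            "state": {"game_state": {}, "player_hero": {},
#   ...                      "resources": {}, "board_state": {}}}
#   >>> _build_prompt(example).endswith(
#   ...     '"state":{"game_state":{},"player_hero":{},"resources":{},"board_state":{}}}'
#   ... )
#   True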


# ================== Data loading for evaluation ==================
def _load_flat_gamehistory(data_file: str) -> Dataset:
    """Load a flattened game_history dataset.

    Expected format:
    - JSON file containing a list of per-turn rows, each with at least:
      {"game_id", "step_id", "turn", "phase", "state", "candidates", ...}

    This is the same flattened structure that train_battleground_rlaif_gamehistory.py
    consumes (e.g., RL/datasets/game_history_flat.json).
    """
    if not os.path.exists(data_file):
        raise FileNotFoundError(f"Data file not found: {data_file}")
    with open(data_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"Expected {data_file} to contain a JSON array of rows; got {type(data)} instead."
        )
    rows: List[Dict[str, Any]] = []
    for idx, row in enumerate(data):
        if not isinstance(row, dict) or "state" not in row:
            raise ValueError(
                f"Row {idx} in {data_file} is not a valid example with 'state' key."
            )
        candidates = row.get("candidates")
        if not candidates:
            # Skip unlabeled rows
            continue
        state = row.get("state", {}) or {}
        gs = state.get("game_state", {}) or {}
        if "phase" not in row:
            row["phase"] = gs.get("phase", "PlayerTurn")
        if "turn" not in row:
            row["turn"] = gs.get("turn_number", row.get("step_id", 0))
        rows.append(row)
    if not rows:
        raise ValueError(
            "No labeled rows with 'candidates' found in the flattened game_history file."
        )
    return Dataset.from_list(rows)
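
# Sketch of one flattened row (abridged; the game_id and values are
# illustrative, but the keys match what the loader and formatter read):
#
#   {"game_id": "g_0001", "step_id": 12, "turn": 4, "phase": "PlayerTurn",
#    "state": {"game_state": {...}, "player_hero": {...},
#              "resources": {...}, "board_state": {...}},
#    "candidates": [{"role": "expert", "actions": [...], "reward": 1.0}, ...]}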


def load_eval_dataset(
    data_file: str,
    test_size: float = 0.1,
    seed: int = 42,
    limit: Optional[int] = None,
    input_mode: str = "json",
    use_all_data: bool = False,
):
    """Load evaluation dataset from flattened game_history JSON.

    By default mirrors the splitting logic in load_gamehistory_rlaif: we take a
    train/test split and use the held-out test set. If ``use_all_data`` is
    True, we instead evaluate on *all* rows in the file (optionally
    subsampled by ``limit``).
    """
    raw = _load_flat_gamehistory(data_file)
    if use_all_data:
        test_ds = raw
    else:
        split = raw.train_test_split(test_size=test_size, seed=seed)
        test_ds = split["test"]

    def format_example(example: Dict[str, Any]) -> Dict[str, Any]:
        prompt = _build_prompt(example, input_mode=input_mode)
        candidates = example["candidates"]
        # Find the expert candidate; fall back to the max-reward one if missing.
        expert = None
        for c in candidates:
            if c.get("role") == "expert":
                expert = c
                break
        if expert is None:
            expert = max(candidates, key=lambda x: float(x.get("reward", 0.0)))
        return {
            "prompt": prompt,
            # expert_actions is a SEQUENCE of atomic actions
            "expert_actions": expert["actions"],
            "candidates": candidates,
            "game_id": example.get("game_id", ""),
            "step_id": example.get("step_id", 0),
            "turn": example.get("turn", 0),
            "phase": example.get("phase", "PlayerTurn"),
        }

    test_ds = test_ds.map(format_example, remove_columns=raw.column_names)
    if limit is not None:
        test_ds = test_ds.select(range(min(limit, len(test_ds))))
    return test_ds
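
# Example usage (the path is illustrative and must point at an existing
# flattened game_history file):
#
#   ds = load_eval_dataset("RL/datasets/game_history_flat.json", limit=5)
#   print(len(ds), ds[0]["phase"], len(ds[0]["expert_actions"]))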


# ================== Action parsing & comparison ==================
def _parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a sequence of atomic action dicts.

    Expected formats (same as the training reward function):
    - {"actions": [ {...}, {...}, ... ]}
    - {"action": [ {...}, {...}, ... ]}  # tolerated fallback

    This is stricter than the training-time JSON-mode parser, which attempts to
    clean arbitrary text; for evaluation we assume the model largely follows
    the instruction and returns JSON.
    """
    text = text.strip()
    # Try to locate a JSON object within the text (in case of extra chatter)
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    # Heuristic: take until the last closing brace
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]
    try:
        obj = json.loads(json_str)
    except Exception:
        return None
    if not isinstance(obj, dict):
        return None
    seq = None
    if "actions" in obj and isinstance(obj["actions"], list):
        seq = obj["actions"]
    elif "action" in obj and isinstance(obj["action"], list):
        seq = obj["action"]
    if seq is None:
        return None
    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions
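
# Doctest-style example: the parser tolerates surrounding chatter as long as a
# JSON object can be carved out between the first "{" and the last "}":
#
#   >>> _parse_actions_from_completion(
#   ...     'Sure! {"actions":[{"type":"ROLL","tavern_index":null,'
#   ...     '"hand_index":null,"board_index":null,"card_name":null}]}'
#   ... )
#   [{'type': 'ROLL', 'tavern_index': None, 'hand_index': None, 'board_index': None, 'card_name': None}]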


def _action_sequences_equal(a: List[Dict[str, Any]], b: List[Dict[str, Any]]) -> bool:
    """Strict equality for sequences of atomic actions.

    Both the length and each per-step dict must match exactly.
    """
    if len(a) != len(b):
        return False
    for s1, s2 in zip(a, b):
        if s1 != s2:
            return False
    return True


def _sequence_reward(pred: List[Dict[str, Any]], candidates: List[Dict[str, Any]]) -> float:
    """Return the reward by matching a predicted sequence to candidate sequences.

    This mirrors battleground_rlaif_reward: if we find a candidate whose
    ``actions`` exactly equal ``pred``, return that candidate's reward, else 0.0.
    """
    best_reward = 0.0
    for cand in candidates:
        cand_actions = cand.get("actions")
        if not isinstance(cand_actions, list):
            continue
        if _action_sequences_equal(pred, cand_actions):
            r = float(cand.get("reward", 0.0))
            if r > best_reward:
                best_reward = r
    return best_reward
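
# Hypothetical example (action dicts trimmed to "type" for brevity; real
# candidates carry the full atomic-action schema). Only an exact sequence
# match earns a candidate's reward:
#
#   >>> cands = [{"actions": [{"type": "END_TURN"}], "reward": 0.3},
#   ...          {"actions": [{"type": "ROLL"}], "reward": 0.9}]
#   >>> _sequence_reward([{"type": "END_TURN"}], cands)
#   0.3
#   >>> _sequence_reward([{"type": "FREEZE"}], cands)
#   0.0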


# ================== Model loading ==================
def load_base_model(model_path: str, bf16: bool = True):
    """Load base model without any adapters."""
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs: Dict[str, Any] = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps prompts right-aligned so generated tokens can be
    # sliced off by prompt length during batched generation.
    tokenizer.padding_side = "left"
    return model, tokenizer


def load_peft_model(base_model_path: str, adapter_path: str, bf16: bool = True):
    """Load base model with PEFT adapter and merge for inference."""
    dtype = torch.bfloat16 if bf16 and torch.cuda.is_available() else torch.float16
    model_kwargs: Dict[str, Any] = {
        "torch_dtype": dtype,
        "trust_remote_code": True,
    }
    if torch.cuda.is_available():
        model_kwargs["device_map"] = "auto"
    base_model = AutoModelForCausalLM.from_pretrained(base_model_path, **model_kwargs)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    model = model.merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path, use_fast=True, trust_remote_code=True
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    return model, tokenizer
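
# Design note: merge_and_unload() folds the LoRA weights into the base model,
# so generation runs at plain-model speed with no adapter indirection; the
# trade-off is that the merged model can no longer swap adapters afterwards.
# Example usage (the adapter path is illustrative):
#
#   model, tok = load_peft_model(DEFAULT_MODEL_ID, "./battleground_rlaif_qwen_gamehistory/sft_model")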


# ================== Evaluation loop ==================
def evaluate_model(
    model,
    tokenizer,
    test_ds,
    max_new_tokens: int = 128,
    batch_size: int = 8,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Evaluate model on the game_history Battlegrounds test set.

    Metrics:
    - exact_match_acc: predicted sequence == expert sequence (strict)
    - avg_reward: average candidate reward of predicted sequences
    - parse_failure_rate: fraction of samples whose output could not be parsed
    """
    model.eval()
    device = next(model.parameters()).device
    exact_correct = 0
    total_reward = 0.0
    total = 0
    parse_failures = 0
    results: List[Dict[str, Any]] = []
    for i in range(0, len(test_ds), batch_size):
        batch = test_ds[i : i + batch_size]
        prompts = batch["prompt"] if isinstance(batch["prompt"], list) else [batch["prompt"]]
        expert_seqs = (
            batch["expert_actions"]
            if isinstance(batch["expert_actions"], list)
            else [batch["expert_actions"]]
        )
        candidates_list = (
            batch["candidates"]
            if isinstance(batch["candidates"], list)
            else [batch["candidates"]]
        )
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        ).to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        for j, (output, prompt, expert_actions, candidates) in enumerate(
            zip(outputs, prompts, expert_seqs, candidates_list)
        ):
            # With left padding, the (padded) prompt occupies the first
            # input_len positions of the output; everything after is generated.
            input_len = inputs["input_ids"][j].shape[0]
            generated = tokenizer.decode(output[input_len:], skip_special_tokens=True)
            pred_seq = _parse_actions_from_completion(generated)
            is_exact_match = False
            reward = 0.0
            if pred_seq is None:
                parse_failures += 1
            else:
                is_exact_match = _action_sequences_equal(pred_seq, expert_actions)
                reward = _sequence_reward(pred_seq, candidates)
            if is_exact_match:
                exact_correct += 1
            total_reward += reward
            total += 1
            result = {
                "game_id": batch["game_id"][j]
                if isinstance(batch["game_id"], list)
                else batch["game_id"],
                "step_id": batch["step_id"][j]
                if isinstance(batch["step_id"], list)
                else batch["step_id"],
                "turn": batch["turn"][j]
                if isinstance(batch["turn"], list)
                else batch["turn"],
                "phase": batch["phase"][j]
                if isinstance(batch["phase"], list)
                else batch["phase"],
                "expert_actions": expert_actions,
                "predicted_actions": pred_seq,
                "generated_text": generated.strip()[:200],
                "exact_match": is_exact_match,
                "reward": reward,
            }
            results.append(result)
            if verbose and not is_exact_match:
                print(f"\n[WRONG] Game: {result['game_id']}, Step: {result['step_id']}")
                print(f"  Expert: {expert_actions}")
                print(f"  Pred:   {pred_seq}")
                print(f"  Gen:    {generated[:150]}")
    exact_match_acc = exact_correct / total if total > 0 else 0.0
    avg_reward = total_reward / total if total > 0 else 0.0
    return {
        "exact_match_acc": exact_match_acc,
        "avg_reward": avg_reward,
        "parse_failure_rate": parse_failures / total if total > 0 else 0.0,
        "total_samples": total,
        "results": results,
    }
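
# Shape of the returned metrics dict (values are illustrative):
#
#   {"exact_match_acc": 0.42, "avg_reward": 0.31, "parse_failure_rate": 0.05,
#    "total_samples": 50, "results": [...]}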


# ================== Main CLI ==================
def main():
    parser = argparse.ArgumentParser(
        description="Evaluate Battlegrounds RLAIF models (game_history pipeline): No FT, SFT, SFT+GRPO"
    )
    parser.add_argument(
        "--base-model",
        default=DEFAULT_MODEL_ID,
        help="Base model path (Qwen instruct checkpoint).",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="Directory containing SFT and GRPO checkpoints (same as training).",
    )
    parser.add_argument(
        "--data-file",
        default=DEFAULT_DATA_FILE,
        help=(
            "Path to flattened game_history JSON (e.g., RL/datasets/game_history_flat.json) "
            "with per-turn rows and multi-candidate annotations."
        ),
    )
    parser.add_argument(
        "--sft-adapter",
        default=None,
        help="Path to SFT adapter (default: <output-dir>/sft_model).",
    )
    parser.add_argument(
        "--grpo-adapter",
        default=None,
        help="Path to GRPO adapter (default: <output-dir>/grpo_model).",
    )
    parser.add_argument(
        "--eval-samples",
        type=int,
        default=50,
        help="Number of test samples to evaluate (default: 50; use -1 for the full set).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=8,
        help="Batch size for inference.",
    )
    parser.add_argument(
        "--max-new-tokens", type=int, default=512, help="Max tokens to generate."
    )
    parser.add_argument(
        "--disable-bf16",
        action="store_true",
        help="Use fp16 instead of bf16.",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Print mismatched predictions."
    )
    parser.add_argument(
        "--eval-no-ft",
        action="store_true",
        help="Evaluate base model (no fine-tuning).",
    )
    parser.add_argument("--eval-sft", action="store_true", help="Evaluate SFT model.")
    parser.add_argument(
        "--eval-grpo", action="store_true", help="Evaluate SFT+GRPO model."
    )
    parser.add_argument(
        "--save-results",
        default=None,
        help="Path to save detailed results as JSON.",
    )
    parser.add_argument(
        "--eval-all-data",
        action="store_true",
        help=(
            "If set, do not create a train/test split; evaluate on all rows "
            "from the flattened game_history file (optionally subsampled by "
            "--eval-samples)."
        ),
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help=(
            "Input format for game state: 'json' uses nested JSON; 'nl' uses a natural "
            "language description built from the nested state."
        ),
    )
    args = parser.parse_args()
    bf16 = not args.disable_bf16

    # Default: evaluate all model variants if none specified
    eval_all = not (args.eval_no_ft or args.eval_sft or args.eval_grpo)
    if eval_all:
        args.eval_no_ft = True
        args.eval_sft = True
        args.eval_grpo = True

    # Resolve adapter paths
    sft_adapter = args.sft_adapter or os.path.join(args.output_dir, "sft_model")
    grpo_adapter = args.grpo_adapter or os.path.join(args.output_dir, "grpo_model")

    # Handle eval_samples=-1 as the full set
    eval_samples = None if args.eval_samples == -1 else args.eval_samples

    # Load test data
    print("Loading game_history Battlegrounds test set...")
    if not os.path.exists(args.data_file):
        print(f"ERROR: Data file not found: {args.data_file}")
        return
    test_ds = load_eval_dataset(
        args.data_file,
        limit=eval_samples,
        input_mode=args.input_mode,
        use_all_data=args.eval_all_data,
    )
    print(f"Test samples: {len(test_ds)}")
    all_results: Dict[str, Any] = {}

    # ===== Evaluate No FT (base model) =====
    if args.eval_no_ft:
        print("\n" + "=" * 60)
        print("Evaluating: No Fine-Tuning (Base Model)")
        print("=" * 60)
        model, tokenizer = load_base_model(args.base_model, bf16=bf16)
        metrics = evaluate_model(
            model,
            tokenizer,
            test_ds,
            max_new_tokens=args.max_new_tokens,
            batch_size=args.batch_size,
            verbose=args.verbose,
        )
        print(f"[No FT] Exact Seq Match: {metrics['exact_match_acc']:.4f}")
        print(f"[No FT] Avg Reward: {metrics['avg_reward']:.4f}")
        print(f"[No FT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
        all_results["no_ft"] = metrics
        del model
        torch.cuda.empty_cache()

    # ===== Evaluate SFT =====
    if args.eval_sft:
        print("\n" + "=" * 60)
        print("Evaluating: SFT Fine-Tuned Model")
        print("=" * 60)
        if not os.path.exists(sft_adapter):
            print(f"[SKIP] SFT adapter not found at: {sft_adapter}")
        else:
            model, tokenizer = load_peft_model(args.base_model, sft_adapter, bf16=bf16)
            metrics = evaluate_model(
                model,
                tokenizer,
                test_ds,
                max_new_tokens=args.max_new_tokens,
                batch_size=args.batch_size,
                verbose=args.verbose,
            )
            print(f"[SFT] Exact Seq Match: {metrics['exact_match_acc']:.4f}")
            print(f"[SFT] Avg Reward: {metrics['avg_reward']:.4f}")
            print(f"[SFT] Parse Failures: {metrics['parse_failure_rate']:.2%}")
            all_results["sft"] = metrics
            del model
            torch.cuda.empty_cache()

    # ===== Evaluate SFT + GRPO =====
    if args.eval_grpo:
        print("\n" + "=" * 60)
        print("Evaluating: SFT + GRPO Fine-Tuned Model")
        print("=" * 60)
        grpo_epoch_dir = os.path.join(args.output_dir, "grpo")
        adapters_to_eval: List[Tuple[str, str]] = []
        # If the user did not override --grpo-adapter and epoch checkpoints exist,
        # evaluate all checkpoint-* directories under output_dir/grpo plus the final grpo_model.
        default_grpo_adapter = os.path.join(args.output_dir, "grpo_model")
        using_default_adapter = (args.grpo_adapter is None) or (
            grpo_adapter == default_grpo_adapter
        )
        if using_default_adapter and os.path.isdir(grpo_epoch_dir):
            checkpoint_names = [
                d
                for d in os.listdir(grpo_epoch_dir)
                if d.startswith("checkpoint")
                and os.path.isdir(os.path.join(grpo_epoch_dir, d))
            ]
            checkpoint_names.sort()
            for name in checkpoint_names:
                path = os.path.join(grpo_epoch_dir, name)
                label = f"sft_grpo_{name}"
                adapters_to_eval.append((label, path))
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo_final", grpo_adapter))
        else:
            if os.path.exists(grpo_adapter):
                adapters_to_eval.append(("sft_grpo", grpo_adapter))
        if not adapters_to_eval:
            print(
                f"[SKIP] No GRPO adapters found. Expected at: {grpo_adapter} or under {grpo_epoch_dir}"
            )
        else:
            for label, adapter_path in adapters_to_eval:
                print("\n" + "-" * 60)
                print(f"Evaluating GRPO adapter: {label}")
                print(f"Path: {adapter_path}")
                model, tokenizer = load_peft_model(
                    args.base_model, adapter_path, bf16=bf16
                )
                metrics = evaluate_model(
                    model,
                    tokenizer,
                    test_ds,
                    max_new_tokens=args.max_new_tokens,
                    batch_size=args.batch_size,
                    verbose=args.verbose,
                )
                print(f"[{label}] Exact Seq Match: {metrics['exact_match_acc']:.4f}")
                print(f"[{label}] Avg Reward: {metrics['avg_reward']:.4f}")
                print(f"[{label}] Parse Failures: {metrics['parse_failure_rate']:.2%}")
                all_results[label] = metrics
                del model
                torch.cuda.empty_cache()

    # ===== Summary =====
    print("\n" + "=" * 60)
    print("SUMMARY (game_history pipeline)")
    print("=" * 60)
    print(f"{'Model':<16} {'ExactSeq':<10} {'Reward':<10} {'ParseFail':<10}")
    print("-" * 52)
    for name, data in all_results.items():
        if "results" in data:
            print(
                f"{name:<16} {data['exact_match_acc']:<10.4f} "
                f"{data['avg_reward']:<10.4f} {data['parse_failure_rate']:<10.2%}"
            )

    # Save results
    if args.save_results:
        save_data = {
            name: {
                "exact_match_acc": data["exact_match_acc"],
                "avg_reward": data["avg_reward"],
                "parse_failure_rate": data["parse_failure_rate"],
                "total_samples": data["total_samples"],
                "sample_predictions": data["results"][:10],
            }
            for name, data in all_results.items()
            if "results" in data
        }
        with open(args.save_results, "w", encoding="utf-8") as f:
            json.dump(save_data, f, indent=2, ensure_ascii=False)
        print(f"\nResults saved to: {args.save_results}")


if __name__ == "__main__":
    main()
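
# Example invocations (all flags are defined above; paths are illustrative):
#
#   python eval_battleground_rlaif_gamehistory.py --eval-sft --eval-samples 100
#   python eval_battleground_rlaif_gamehistory.py --eval-all-data --input-mode nl \
#       --save-results gamehistory_eval_results.json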