from __future__ import annotations import argparse import hashlib import inspect import json import os import sys import time from dataclasses import dataclass from typing import Any, Dict, List import torch from datasets import Dataset from peft import LoraConfig, PeftModel, get_peft_model from safetensors.torch import load_file as load_safetensors_file from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback, set_seed CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) PARENT_DIR = os.path.dirname(CURRENT_DIR) if PARENT_DIR not in sys.path: sys.path.insert(0, PARENT_DIR) from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row from checkpoint_utils import ensure_final_checkpoint_dir, save_model_artifacts from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt from multi_output_cell_policy.rewards import score_prediction_text from multi_output_cell_policy.shared_multi_output_policy import make_solved_grid_from_row try: import wandb except Exception: wandb = None @dataclass class Args: model_name: str train_jsonl: str eval_jsonl: str output_dir: str cache_dir: str init_adapter_dir: str seed: int gpu_id: int stage_i: int total_empties_hint: int per_device_train_batch_size: int gradient_accumulation_steps: int num_train_epochs: float learning_rate: float logging_steps: int save_steps: int eval_steps: int eval_rows: int num_generations: int max_prompt_length: int max_completion_length: int beta: float lora_r: int lora_alpha: int lora_dropout: float enable_gradient_checkpointing: bool use_wandb: bool wandb_entity: str wandb_project: str wandb_run_name: str wandb_mode: str wandb_group: str wandb_run_id: str debug_print_limit: int limit_train_rows: int reward_good_value: float penalty_bad_value: float penalty_malformed: float penalty_empty: float penalty_singleton: float penalty_missing: float exact_match_bonus: float cardinality_mismatch_penalty: float eval_value_precision_stop: float eval_value_recall_stop: float eval_solve_rate_stop: float min_steps_before_stop: int max_wall_clock_seconds: int max_steps: int resume_from_checkpoint: str def configure_hf_cache(cache_dir: str) -> str: cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) hub_dir = os.path.join(cache_dir, "hub") transformers_dir = os.path.join(cache_dir, "transformers") os.makedirs(hub_dir, exist_ok=True) os.makedirs(transformers_dir, exist_ok=True) os.environ["HF_HOME"] = cache_dir os.environ["HF_HUB_CACHE"] = hub_dir os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir os.environ["TRANSFORMERS_CACHE"] = transformers_dir os.environ.setdefault("HF_HUB_DISABLE_XET", "1") return cache_dir def configure_wandb_dirs(output_dir: str) -> None: wandb_dir = os.path.join(output_dir, "wandb_runtime") os.makedirs(wandb_dir, exist_ok=True) os.environ.setdefault("WANDB_DIR", wandb_dir) os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) def pick_dtype() -> torch.dtype: if torch.cuda.is_available(): try: if torch.cuda.is_bf16_supported(): return torch.bfloat16 except Exception: pass return torch.float16 def ensure_trl_fsdp_compat() -> None: try: import torch.distributed.fsdp as fsdp if not hasattr(fsdp, "FSDPModule") and hasattr(fsdp, "FullyShardedDataParallel"): fsdp.FSDPModule = fsdp.FullyShardedDataParallel except Exception: pass def load_trainable_adapter(base_model: torch.nn.Module, adapter_dir: str) -> torch.nn.Module: try: return PeftModel.from_pretrained(base_model, adapter_dir, is_trainable=True) except Exception: config_path = os.path.join(adapter_dir, "adapter_config.json") model_path = os.path.join(adapter_dir, "adapter_model.safetensors") if not (os.path.exists(config_path) and os.path.exists(model_path)): raise with open(config_path, "r", encoding="utf-8") as f: cfg = json.load(f) lora = LoraConfig( r=int(cfg["r"]), lora_alpha=int(cfg["lora_alpha"]), lora_dropout=float(cfg["lora_dropout"]), bias=str(cfg.get("bias", "none")), task_type=str(cfg.get("task_type", "CAUSAL_LM")), target_modules=list(cfg["target_modules"]), ) model = get_peft_model(base_model, lora) state = load_safetensors_file(model_path) remapped: Dict[str, torch.Tensor] = {} for key, value in state.items(): new_key = key.replace(".lora_A.weight", ".lora_A.default.weight") new_key = new_key.replace(".lora_B.weight", ".lora_B.default.weight") remapped[new_key] = value model.load_state_dict(remapped, strict=False) return model def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue rows.append(json.loads(line)) if limit_rows > 0 and len(rows) >= limit_rows: break return rows def build_grpo_records( rows: List[Dict[str, Any]], *, tokenizer: Any, stage_i: int, total_empties_hint: int, progress_every_rows: int = 10, progress_callback: Any = None, ) -> List[Dict[str, Any]]: records: List[Dict[str, Any]] = [] for row_idx, row in enumerate(rows, start=1): solved = make_solved_grid_from_row(row) for ex in build_cell_examples_from_row(row): prompt = build_multi_output_cell_prompt( ex.grid, target_cell=ex.target_cell, stage_i=stage_i, tokenizer=tokenizer, turn_idx=ex.turn_idx, total_turns=ex.total_turns, prev_output_flag=None, total_empties_hint=total_empties_hint, ) records.append( { "prompt": prompt, "grid_json": json.dumps(ex.grid.tolist(), separators=(",", ":")), "solved_json": json.dumps(solved.tolist(), separators=(",", ":")), "target_row": int(ex.target_cell[0]), "target_col": int(ex.target_cell[1]), "stage_i": int(stage_i), } ) if progress_callback is not None and ( row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 ): progress_callback(row_idx, len(rows), len(records)) return records def _prepared_data_dir(args: Args) -> str: path = os.path.join(PARENT_DIR, "_prepared_data", "multi_output_cell_policy") os.makedirs(path, exist_ok=True) return path def _prepared_grpo_cache_path(args: Args) -> str: payload = json.dumps( { "train_jsonl": os.path.abspath(args.train_jsonl), "stage_i": int(args.stage_i), "total_empties_hint": int(args.total_empties_hint), "limit_train_rows": int(args.limit_train_rows), "model_name": str(args.model_name), }, sort_keys=True, ).encode("utf-8") digest = hashlib.sha1(payload).hexdigest()[:20] return os.path.join(_prepared_data_dir(args), f"grpo_stage{int(args.stage_i):02d}_{digest}.jsonl") def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: tmp_path = f"{path}.tmp" with open(tmp_path, "w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row, separators=(",", ":")) + "\n") os.replace(tmp_path, path) def _read_jsonl(path: str) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: rows.append(json.loads(line)) return rows def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: start = time.time() while not os.path.exists(path): if time.time() - start > timeout_s: raise TimeoutError(f"Timed out waiting for prepared cache: {path}") time.sleep(2.0) def load_or_build_grpo_records( args: Args, *, rows: List[Dict[str, Any]], tokenizer: Any, rank: int, world_size: int, progress_callback: Any = None, ) -> List[Dict[str, Any]]: cache_path = _prepared_grpo_cache_path(args) if os.path.exists(cache_path): return _read_jsonl(cache_path) if rank == 0: print(f"[dataset build][grpo stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) records = build_grpo_records( rows, tokenizer=tokenizer, stage_i=args.stage_i, total_empties_hint=args.total_empties_hint, progress_every_rows=10, progress_callback=progress_callback, ) _write_jsonl(cache_path, records) return records _wait_for_cache(cache_path) return _read_jsonl(cache_path) def make_reward_func(args: Args): def reward_func(completions, grid_json, solved_json, target_row, target_col, stage_i, **kwargs): rewards: List[float] = [] for completion, grid_s, solved_s, rr, cc, stage_val in zip( completions, grid_json, solved_json, target_row, target_col, stage_i ): info = score_prediction_text( text=str(completion), grid=torch.tensor(json.loads(grid_s), dtype=torch.long).numpy(), solved=torch.tensor(json.loads(solved_s), dtype=torch.long).numpy(), target_cell=(int(rr), int(cc)), stage_i=int(stage_val), reward_good_value=args.reward_good_value, penalty_bad_value=args.penalty_bad_value, penalty_malformed=args.penalty_malformed, penalty_empty=args.penalty_empty, penalty_singleton=args.penalty_singleton, penalty_missing=args.penalty_missing, exact_match_bonus=args.exact_match_bonus, cardinality_mismatch_penalty=args.cardinality_mismatch_penalty, ) rewards.append(float(info["reward"])) return rewards return reward_func @torch.no_grad() def run_eval( *, args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device, ) -> Dict[str, float]: model.eval() total_cells = 0 parse_ok = 0.0 canonical_ok = 0.0 exact_set_match = 0.0 includes_gt = 0.0 precision_sum = 0.0 recall_sum = 0.0 predicted_size_sum = 0.0 good_count_sum = 0.0 bad_count_sum = 0.0 solve_ok = 0 printed = 0 for row in rows: solved = make_solved_grid_from_row(row) row_all_exact = True for ex in build_cell_examples_from_row(row): prompt = build_multi_output_cell_prompt( ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i, tokenizer=tokenizer, turn_idx=ex.turn_idx, total_turns=ex.total_turns, prev_output_flag=None, total_empties_hint=args.total_empties_hint, ) enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) enc = {k: v.to(device) for k, v in enc.items()} out = model.generate( **enc, max_new_tokens=max(1, int(args.max_completion_length)), do_sample=False, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() info = score_prediction_text( text=pred_text, grid=ex.grid, solved=solved, target_cell=ex.target_cell, stage_i=args.stage_i, reward_good_value=args.reward_good_value, penalty_bad_value=args.penalty_bad_value, penalty_malformed=args.penalty_malformed, penalty_empty=args.penalty_empty, penalty_singleton=args.penalty_singleton, penalty_missing=args.penalty_missing, exact_match_bonus=args.exact_match_bonus, cardinality_mismatch_penalty=args.cardinality_mismatch_penalty, ) total_cells += 1 parse_ok += float(info["parse_ok"]) canonical_ok += float(info["strict_canonical"]) exact_set_match += float(info["exact_set_match"]) includes_gt += float(info["includes_ground_truth"]) precision_sum += float(info["value_precision"]) recall_sum += float(info["value_recall"]) predicted_size_sum += float(info["num_predicted_values"]) good_count_sum += float(info["num_i_consistent_values"]) bad_count_sum += float(info["num_non_i_consistent_values"]) if float(info["exact_set_match"]) < 0.5: row_all_exact = False if printed < int(args.debug_print_limit): rr, cc = ex.target_cell print(f"[baseline grpo eval debug] target=({rr+1},{cc+1}) output={pred_text!r}", flush=True) print( f"[baseline grpo eval debug] target_values={info['target_values']} predicted_values={info['predicted_values']}", flush=True, ) printed += 1 solve_ok += int(row_all_exact) return { "parse_rate": float(parse_ok / max(1, total_cells)), "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), "value_precision": float(precision_sum / max(1, total_cells)), "value_recall": float(recall_sum / max(1, total_cells)), "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), "solve_rate": float(solve_ok / max(1, len(rows))), "eval_cells": float(total_cells), } def unwrap_training_model(model: Any) -> Any: current = model while hasattr(current, "module"): current = current.module return current class CustomEvalCallback(TrainerCallback): def __init__( self, args: Args, eval_rows: List[Dict[str, Any]], tokenizer: Any, device: torch.device, wb_run: Any, is_main_process: bool, ): self.args = args self.eval_rows = eval_rows self.tokenizer = tokenizer self.device = device self.wb_run = wb_run self.is_main_process = is_main_process self.last_logged_step = -1 def on_step_end(self, args, state, control, **kwargs): step = int(state.global_step) eval_every = int(self.args.eval_steps) if step <= 0 or step % eval_every != 0: return control world_size = int(os.environ.get("WORLD_SIZE", "1")) use_dist = world_size > 1 and torch.distributed.is_available() and torch.distributed.is_initialized() stop_tensor = torch.zeros(1, dtype=torch.int32, device=self.device) if self.is_main_process: if step != self.last_logged_step: model = kwargs.get("model") if model is not None: metrics = run_eval( args=self.args, rows=self.eval_rows, model=unwrap_training_model(model), tokenizer=self.tokenizer, device=self.device, ) self.last_logged_step = step print( f"[baseline grpo custom eval step {step}] parse={metrics['parse_rate']:.3f} " f"solve={metrics['solve_rate']:.3f} " f"avg_set_size={metrics['avg_predicted_set_size']:.3f} " f"good={metrics['avg_num_i_consistent_values']:.3f} " f"bad={metrics['avg_num_non_i_consistent_values']:.3f}", flush=True, ) if self.args.use_wandb and self.wb_run is not None: payload = {f"custom_eval/{k}": float(v) for k, v in metrics.items()} payload["custom_eval/global_step"] = float(step) wandb.log(payload) if ( float(self.args.eval_value_precision_stop) > 0.0 and float(self.args.eval_value_recall_stop) > 0.0 and step >= int(self.args.min_steps_before_stop) and float(metrics["value_precision"]) >= float(self.args.eval_value_precision_stop) and float(metrics["value_recall"]) >= float(self.args.eval_value_recall_stop) ): print( f"[baseline grpo custom eval step {step}] stopping early: " f"value_precision={metrics['value_precision']:.3f} >= {float(self.args.eval_value_precision_stop):.3f} " f"and value_recall={metrics['value_recall']:.3f} >= {float(self.args.eval_value_recall_stop):.3f}", flush=True, ) stop_tensor[0] = 1 if ( int(stop_tensor.item()) == 0 and float(self.args.eval_solve_rate_stop) > 0.0 and step >= int(self.args.min_steps_before_stop) and float(metrics["solve_rate"]) >= float(self.args.eval_solve_rate_stop) ): print( f"[baseline grpo custom eval step {step}] stopping early: " f"solve_rate={metrics['solve_rate']:.3f} >= {float(self.args.eval_solve_rate_stop):.3f}", flush=True, ) stop_tensor[0] = 1 if use_dist: torch.distributed.broadcast(stop_tensor, src=0) if int(stop_tensor.item()) != 0: control.should_training_stop = True return control class FinalCheckpointCallback(TrainerCallback): def __init__(self, output_dir: str, tokenizer: Any, is_main_process: bool): self.output_dir = output_dir self.tokenizer = tokenizer self.is_main_process = is_main_process def _save(self, model: Any) -> None: if self.is_main_process: save_model_artifacts(unwrap_training_model(model), self.tokenizer, ensure_final_checkpoint_dir(self.output_dir)) def on_save(self, args, state, control, **kwargs): model = kwargs.get("model") if model is not None: self._save(model) return control def on_train_end(self, args, state, control, **kwargs): model = kwargs.get("model") if model is not None: self._save(model) return control class WallClockStopCallback(TrainerCallback): def __init__(self, max_wall_clock_seconds: int): self.max_wall_clock_seconds = int(max_wall_clock_seconds) self.start_time = time.time() def on_step_end(self, args, state, control, **kwargs): if self.max_wall_clock_seconds > 0 and (time.time() - self.start_time) >= float(self.max_wall_clock_seconds): control.should_training_stop = True return control def parse_args() -> Args: p = argparse.ArgumentParser() p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") p.add_argument("--train_jsonl", type=str, required=True) p.add_argument("--eval_jsonl", type=str, default="") p.add_argument("--output_dir", type=str, required=True) p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache") p.add_argument("--init_adapter_dir", type=str, required=True) p.add_argument("--seed", type=int, default=0) p.add_argument("--gpu_id", type=int, default=0) p.add_argument("--stage_i", type=int, default=1) p.add_argument("--total_empties_hint", type=int, default=10) p.add_argument("--per_device_train_batch_size", type=int, default=2) p.add_argument("--gradient_accumulation_steps", type=int, default=4) p.add_argument("--num_train_epochs", type=float, default=0.5) p.add_argument("--learning_rate", type=float, default=1e-6) p.add_argument("--logging_steps", type=int, default=5) p.add_argument("--save_steps", type=int, default=25) p.add_argument("--eval_steps", type=int, default=25) p.add_argument("--eval_rows", type=int, default=20) p.add_argument("--num_generations", type=int, default=2) p.add_argument("--max_prompt_length", type=int, default=1024) p.add_argument("--max_completion_length", type=int, default=24) p.add_argument("--beta", type=float, default=0.0) p.add_argument("--lora_r", type=int, default=8) p.add_argument("--lora_alpha", type=int, default=16) p.add_argument("--lora_dropout", type=float, default=0.05) p.add_argument("--enable_gradient_checkpointing", action="store_true") p.add_argument("--use_wandb", action="store_true") p.add_argument("--wandb_entity", type=str, default="") p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-grpo") p.add_argument("--wandb_run_name", type=str, default="") p.add_argument("--wandb_mode", type=str, default="online") p.add_argument("--wandb_group", type=str, default="") p.add_argument("--wandb_run_id", type=str, default="") p.add_argument("--debug_print_limit", type=int, default=3) p.add_argument("--limit_train_rows", type=int, default=0) p.add_argument("--reward_good_value", type=float, default=1.0) p.add_argument("--penalty_bad_value", type=float, default=1.75) p.add_argument("--penalty_malformed", type=float, default=4.0) p.add_argument("--penalty_empty", type=float, default=0.5) p.add_argument("--penalty_singleton", type=float, default=1.5) p.add_argument( "--penalty_missing", type=float, default=0.0, help="Per-missing-value penalty: reward -= penalty_missing * |target_set \\ predicted_set|. " "Defaults to 0 (legacy); set ~0.75 at stage>=2 to push recall up.", ) p.add_argument( "--exact_match_bonus", type=float, default=0.0, help="Bonus added only when set(predicted_values) == set(target_values) and prediction is non-empty. " "Defaults to 0; set ~2.0 to strictly dominate partial supersets.", ) p.add_argument( "--cardinality_mismatch_penalty", type=float, default=0.0, help="Penalty when len(predicted_values) < len(target_values) for multi-value targets " "(stage-agnostic). Defaults to 0; set ~1.0 at stage>=2 to deter under-prediction.", ) p.add_argument("--eval_value_precision_stop", type=float, default=0.0) p.add_argument("--eval_value_recall_stop", type=float, default=0.0) p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) p.add_argument("--min_steps_before_stop", type=int, default=0) p.add_argument("--max_wall_clock_seconds", type=int, default=0) p.add_argument("--max_steps", type=int, default=0) p.add_argument("--resume_from_checkpoint", type=str, default="") ns = p.parse_args() return Args(**vars(ns)) def main() -> None: args = parse_args() preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) is_main_process = rank == 0 if preset_visible_devices: if is_main_process: print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) elif int(args.gpu_id) >= 0: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) set_seed(args.seed + rank) os.makedirs(args.output_dir, exist_ok=True) ensure_final_checkpoint_dir(args.output_dir) cache_dir = configure_hf_cache(args.cache_dir) configure_wandb_dirs(args.output_dir) if is_main_process: print(f"Using Hugging Face cache dir: {cache_dir}", flush=True) wb_run = None if is_main_process and args.use_wandb and wandb is not None: init_kwargs = { "project": args.wandb_project, "name": args.wandb_run_name or None, "mode": args.wandb_mode, "group": args.wandb_group or None, "id": args.wandb_run_id or None, } if str(args.wandb_entity).strip(): init_kwargs["entity"] = args.wandb_entity wb_run = wandb.init(**init_kwargs) print(f"W&B run id: {wb_run.id}", flush=True) print(f"W&B run URL: {wb_run.url}", flush=True) wandb.log({"prep/rows_done": 0.0, "prep/records_built": 0.0, "prep/cache_hit": 0.0}) rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else args.train_jsonl eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows))) tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") if is_main_process: print(f"Using device: {device}", flush=True) base = AutoModelForCausalLM.from_pretrained( args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True, ) model = load_trainable_adapter(base, args.init_adapter_dir) if is_main_process: print(f"Loaded init adapter: {args.init_adapter_dir}", flush=True) if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() if hasattr(model, "config"): model.config.use_cache = False if world_size <= 1: model.to(device) model.train() def on_prep_progress(rows_done: int, total_rows: int, records_built: int) -> None: if is_main_process: print( f"[dataset build][grpo stage {args.stage_i}] rows={rows_done}/{total_rows} records={records_built}", flush=True, ) if wb_run is not None: wandb.log({"prep/rows_done": float(rows_done), "prep/records_built": float(records_built)}) train_records = load_or_build_grpo_records( args, rows=rows, tokenizer=tokenizer, rank=rank, world_size=world_size, progress_callback=on_prep_progress, ) if is_main_process and wb_run is not None: wandb.log( { "prep/cache_hit": float(os.path.exists(_prepared_grpo_cache_path(args))), "prep/records_final": float(len(train_records)), } ) train_dataset = Dataset.from_list(train_records) reward_func = make_reward_func(args) if int(args.limit_train_rows) > 0 and int(args.max_steps) <= 0: args.max_steps = 1 ensure_trl_fsdp_compat() from trl import GRPOConfig, GRPOTrainer config_kwargs = { "output_dir": args.output_dir, "per_device_train_batch_size": args.per_device_train_batch_size, "gradient_accumulation_steps": args.gradient_accumulation_steps, "num_train_epochs": args.num_train_epochs, "learning_rate": args.learning_rate, "logging_steps": args.logging_steps, "save_steps": args.save_steps, "eval_strategy": "steps", "eval_steps": args.eval_steps, "max_prompt_length": args.max_prompt_length, "max_completion_length": args.max_completion_length, "num_generations": args.num_generations, "beta": args.beta, "bf16": (pick_dtype() == torch.bfloat16), "report_to": ["wandb"] if args.use_wandb and is_main_process else [], "remove_unused_columns": False, } if int(args.max_steps) > 0: config_kwargs["max_steps"] = int(args.max_steps) grpo_config_params = inspect.signature(GRPOConfig.__init__).parameters unsupported_keys = sorted(key for key in config_kwargs if key not in grpo_config_params) for key in unsupported_keys: config_kwargs.pop(key, None) if is_main_process and unsupported_keys: print(f"Skipping unsupported GRPOConfig args: {', '.join(unsupported_keys)}", flush=True) config = GRPOConfig(**config_kwargs) trainer = GRPOTrainer( model=model, processing_class=tokenizer, reward_funcs=[reward_func], args=config, train_dataset=train_dataset, eval_dataset=train_dataset.select(range(min(len(train_dataset), max(1, int(args.eval_rows))))), ) trainer.add_callback(CustomEvalCallback(args, eval_rows, tokenizer, device, wb_run, is_main_process)) trainer.add_callback(FinalCheckpointCallback(args.output_dir, tokenizer, is_main_process)) trainer.add_callback(WallClockStopCallback(args.max_wall_clock_seconds)) trainer.train(resume_from_checkpoint=args.resume_from_checkpoint or None) final_model = unwrap_training_model(trainer.model) if is_main_process: eval_metrics = run_eval(args=args, rows=eval_rows, model=final_model, tokenizer=tokenizer, device=device) print( f"[baseline grpo final eval] parse={eval_metrics['parse_rate']:.3f} " f"canonical={eval_metrics['strict_canonical_rate']:.3f} " f"exact={eval_metrics['exact_set_match_rate']:.3f} precision={eval_metrics['value_precision']:.3f} " f"recall={eval_metrics['value_recall']:.3f} solve={eval_metrics['solve_rate']:.3f}", flush=True, ) if wb_run is not None: wandb.log({f"custom_eval/{k}": float(v) for k, v in eval_metrics.items()}) trainer.save_model(args.output_dir) save_model_artifacts(final_model, tokenizer, ensure_final_checkpoint_dir(args.output_dir)) if wb_run is not None: wb_run.finish() if __name__ == "__main__": main()