#!/usr/bin/env python3
"""OCC GRPO Training Demo — Minimal end-to-end GRPO with OCC reward hook.

Trains Qwen2.5-0.5B-Instruct with GRPO using a cost-adjusted marginal impact
reward. The reward combines correctness, format, cost penalty,
confident-wrong penalty, and abstention bonus.

Dataset: trl-lib/DeepMath-103K (column: "solution")
"""

import json
import re

# Literal prefix that opens a LaTeX boxed answer.
_BOXED_OPEN = "\\boxed{"

# Prose fallback ("answer is X", "result = X", ...). It is applied to the
# lowercased completion, so the captured answer is lowercase as well.
_ANSWER_RE = re.compile(r"(?:answer|result|solution)\s*(?:is|=)\s*([^\s,.]+)")

# Format marker. The earlier pattern r"|think" had an empty alternative, which
# matches every string and therefore granted the format bonus unconditionally.
_FORMAT_RE = re.compile(r"think", re.IGNORECASE)

_CONFIDENT_MARKERS = ("definitely", "certainly", "obviously", "clearly")
_ABSTAIN_MARKERS = ("don't know", "cannot determine", "uncertain", "not sure")


def _completion_text(completion):
    """Return the text of one completion (plain string or list of messages)."""
    if isinstance(completion, list):
        if not completion:
            # An empty message list has no text; treat it as an empty answer
            # instead of failing on completion[0].
            return ""
        return completion[0].get("content") or ""
    return str(completion)


def _extract_boxed(text):
    """Return the contents of the first balanced ``\\boxed{...}`` or None.

    Braces are counted so that nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned whole rather than cut at the first
    closing brace.
    """
    start = text.find(_BOXED_OPEN)
    while start != -1:
        begin = start + len(_BOXED_OPEN)
        depth = 1
        for pos in range(begin, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[begin:pos]
        # This occurrence never closed; try the next one, if any.
        start = text.find(_BOXED_OPEN, begin)
    return None


def _sign(value):
    """Map a number to 1.0, -1.0 or 0.0 according to its sign."""
    if value > 0:
        return 1.0
    if value < 0:
        return -1.0
    return 0.0


def occ_reward(completions, solution, completion_ids=None, prompts=None, **kwargs):
    """
    OCC cost-adjusted reward for GRPO.

    Dataset must have 'solution' column (DeepMath-103K format).

    Reward = correctness + format(+0.1) + cost(-0.001/tok)
             + confident_wrong(-0.5) + abstention(+0.3)

    Components, per completion:
      * correctness: +1.0 if the extracted answer equals the reference,
        -1.0 if an answer was extracted but differs, -0.5 if no answer could
        be extracted.
      * format: +0.1 if the completion contains "think" (case-insensitive).
      * cost: -0.001 per token; token count comes from ``completion_ids`` when
        available, otherwise from whitespace-split words.
      * confident_wrong: -0.5 if correctness is negative and the text contains
        a confidence marker ("definitely", "certainly", ...).
      * abstention: +0.3 if the text contains an abstention marker
        ("don't know", "not sure", ...).

    Args:
        completions: One entry per sample; either a string or a list of
            message dicts whose first element carries a "content" key.
        solution: Reference answers, index-aligned with ``completions``.
            Missing entries are treated as an empty reference.
        completion_ids: Optional per-sample token id sequences used only for
            their length.
        prompts: Accepted for signature compatibility; unused.
        **kwargs: May contain ``log_extra``, a callable invoked as
            ``log_extra("correctness", values)`` with one value per sample.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    correctness_terms = []

    for i, comp in enumerate(completions):
        content = _completion_text(comp)
        lowered = content.lower()
        target = str(solution[i]).strip() if i < len(solution) else ""

        # Extract final answer from \boxed{} or "answer is X".
        boxed = _extract_boxed(content)
        if boxed is not None:
            final = boxed.strip()
        else:
            match = _ANSWER_RE.search(lowered)
            final = match.group(1).strip() if match else None
            # The prose answer was taken from lowercased text, so the
            # reference has to be lowercased too for the comparison to be
            # meaningful (otherwise "Yes" could never match).
            target = target.lower()

        if final and final == target:
            correctness = 1.0
        elif not final:
            correctness = -0.5
        else:
            correctness = -1.0

        format_r = 0.1 if _FORMAT_RE.search(content) else 0.0

        if completion_ids and i < len(completion_ids):
            n_tok = len(completion_ids[i])
        else:
            n_tok = len(content.split())
        cost_p = -0.001 * n_tok

        is_conf = any(marker in lowered for marker in _CONFIDENT_MARKERS)
        cw_p = -0.5 if (is_conf and correctness < 0) else 0.0

        is_abst = any(marker in lowered for marker in _ABSTAIN_MARKERS)
        abst_b = 0.3 if is_abst else 0.0

        correctness_terms.append(correctness)
        rewards.append(correctness + format_r + cost_p + cw_p + abst_b)

    log_extra = kwargs.get("log_extra")
    if log_extra:
        # Log the sign of the correctness term itself. Using the sign of the
        # total reward would mislabel correct answers whose cost penalty
        # outweighs the +1.0.
        log_extra("correctness", [_sign(term) for term in correctness_terms])
    return rewards


# ═══ Main ═══
def main():
    """Run the 30-step GRPO demo on a 100-example subset and print a summary."""
    # Imported here so that occ_reward can be imported (e.g. for unit tests)
    # without the training stack installed.
    import torch
    from datasets import load_dataset
    from trl import GRPOTrainer, GRPOConfig

    print("[OCC-GRPO] Loading dataset...", flush=True)
    dataset = load_dataset("trl-lib/DeepMath-103K", split="train").select(range(100))
    print(f"[OCC-GRPO] Loaded {len(dataset)} examples", flush=True)

    # Only request mixed precision when a CUDA device is actually present;
    # otherwise fp16 would be turned on for a CPU-only run.
    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    use_fp16 = has_cuda and not use_bf16

    args = GRPOConfig(
        output_dir="./occ_grpo_output",
        per_device_train_batch_size=2,
        max_steps=30,
        logging_steps=5,
        save_steps=30,
        learning_rate=1e-6,
        bf16=use_bf16,
        fp16=use_fp16,
        gradient_checkpointing=True,
        gradient_accumulation_steps=2,
        max_completion_length=256,
        num_generations=4,
        generation_batch_size=4,
        report_to="none",
        disable_tqdm=True,
        logging_strategy="steps",
        logging_first_step=True,
        remove_unused_columns=False,
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        args=args,
        reward_funcs=occ_reward,
        train_dataset=dataset,
    )

    print("[OCC-GRPO] Training...", flush=True)
    result = trainer.train()
    print(
        f"[OCC-GRPO] Done. Global step: {result.global_step}, loss: {result.training_loss:.4f}",
        flush=True,
    )

    trainer.save_model("./occ_grpo_output/final")
    print("[OCC-GRPO] Model saved to ./occ_grpo_output/final", flush=True)

    summary = {
        "method": "GRPO with OCC cost-adjusted reward",
        "model": "Qwen/Qwen2.5-0.5B-Instruct",
        "dataset": "trl-lib/DeepMath-103K (100 subset)",
        "steps": result.global_step,
        "training_loss": result.training_loss,
    }
    print(json.dumps(summary, indent=2), flush=True)


if __name__ == "__main__":
    main()