"""OCC GRPO Training Demo — minimal end-to-end GRPO with an OCC reward hook.

Trains Qwen2.5-0.5B-Instruct with GRPO using a cost-adjusted marginal-impact reward.
The reward combines correctness, format, cost penalty, confident-wrong penalty, and abstention bonus.

Dataset: trl-lib/DeepMath-103K (column: "solution")
"""
|
|
| import re, json, torch |
| from datasets import load_dataset |
| from trl import GRPOTrainer, GRPOConfig |
|
|
# Patterns are compiled once at import time rather than per completion.
_BOXED_OPEN_RE = re.compile(r"\\boxed\{")
_ANSWER_PHRASE_RE = re.compile(
    r"(?:answer|result|solution)\s*(?:is|=)\s*([^\s,]+)", re.IGNORECASE
)
# Format-bonus trigger: an opening or closing <think> tag.
# NOTE(review): presumed intent of the format term — confirm the expected tag format.
_THINK_TAG_RE = re.compile(r"</?think>", re.IGNORECASE)
_CONFIDENCE_MARKERS = ("definitely", "certainly", "obviously", "clearly")
_ABSTENTION_MARKERS = ("don't know", "cannot determine", "uncertain", "not sure")


def _completion_text(comp):
    """Return the raw text of one completion (plain string or chat-message list)."""
    if isinstance(comp, list):
        # Chat format: the first message dict carries the text under "content".
        if not comp:
            return ""
        return str(comp[0].get("content") or "")
    return str(comp)


def _extract_boxed(text):
    r"""Return the contents of the last brace-balanced ``\boxed{...}`` in *text*.

    Nested braces (e.g. ``\boxed{\frac{1}{2}}``) are kept intact. An occurrence
    whose braces never close (e.g. a truncated completion) is skipped. Returns
    None when no balanced occurrence exists.
    """
    for opener in reversed(list(_BOXED_OPEN_RE.finditer(text))):
        start = opener.end()
        depth = 1
        for pos in range(start, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[start:pos].strip()
    return None


def _extract_final_answer(text):
    """Return the completion's final answer string, or None if none is found."""
    boxed = _extract_boxed(text)
    if boxed:
        return boxed
    # Searched case-insensitively on the original text so the captured answer keeps its case.
    phrase = _ANSWER_PHRASE_RE.search(text)
    if phrase:
        # Drop a sentence-ending period without truncating decimals such as "3.5".
        return phrase.group(1).rstrip(".") or None
    return None


def _normalize_answer(answer):
    r"""Canonicalize an answer for equality: trim, drop trailing '.', surrounding '$',
    all whitespace, and map ``\dfrac``/``\tfrac`` to ``\frac``."""
    answer = answer.strip().rstrip(".").strip().strip("$")
    answer = answer.replace("\\dfrac", "\\frac").replace("\\tfrac", "\\frac")
    return re.sub(r"\s+", "", answer)


def occ_reward(completions, solution, completion_ids=None, prompts=None, **kwargs):
    r"""OCC cost-adjusted reward for GRPO.

    The dataset must provide a ``solution`` column (DeepMath-103K format). Each
    completion is scored as the sum of:

    * correctness: +1.0 if the extracted final answer matches ``solution[i]``,
      -1.0 if it does not, -0.5 if no final answer could be extracted;
    * format: +0.1 if the completion contains a ``<think>`` or ``</think>`` tag;
    * cost: -0.001 per completion token (whitespace-separated words when no
      token ids are available for that completion);
    * confident-wrong: -0.5 when a confidence marker ("definitely", ...)
      appears and the correctness term is negative;
    * abstention: +0.3 when an abstention phrase ("don't know", ...) appears
      AND no final answer was given.

    The final answer is the content of the last brace-balanced ``\boxed{...}``.
    If there is none, it falls back to the first "answer/result/solution is|= X"
    phrase. Answers are compared after light normalization (whitespace,
    surrounding ``$``, trailing period, ``\dfrac``/``\tfrac`` -> ``\frac``).

    Args:
        completions: Sampled completions, either plain strings or chat-message
            lists whose first message carries the text under ``"content"``.
        solution: Reference answers aligned with ``completions``. Missing
            entries act as an empty reference, which never counts as correct.
        completion_ids: Optional per-completion token-id sequences used for the
            cost term.
        prompts: Unused here.
        **kwargs: If ``log_extra`` is present, it is called once as
            ``log_extra("correctness", per_completion_correctness_terms)``.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    correctness_terms = []
    for i, comp in enumerate(completions):
        content = _completion_text(comp)
        lowered = content.lower()
        gt = str(solution[i]) if i < len(solution) else ""

        final = _extract_final_answer(content)
        if final is None:
            correctness = -0.5
        else:
            norm_gt = _normalize_answer(gt)
            matched = bool(norm_gt) and _normalize_answer(final) == norm_gt
            correctness = 1.0 if matched else -1.0

        format_r = 0.1 if _THINK_TAG_RE.search(content) else 0.0

        if completion_ids and i < len(completion_ids):
            n_tok = len(completion_ids[i])
        else:
            n_tok = len(content.split())
        cost_p = -0.001 * n_tok

        is_conf = any(marker in lowered for marker in _CONFIDENCE_MARKERS)
        cw_p = -0.5 if (is_conf and correctness < 0) else 0.0

        # Only reward abstention when the model actually declined to answer;
        # otherwise a hedge phrase would add a free +0.3 to committed answers.
        is_abst = final is None and any(m in lowered for m in _ABSTENTION_MARKERS)
        abst_b = 0.3 if is_abst else 0.0

        rewards.append(correctness + format_r + cost_p + cw_p + abst_b)
        correctness_terms.append(correctness)

    log_extra = kwargs.get("log_extra")
    if log_extra:
        # Log the correctness term itself, not the sign of the total reward.
        log_extra("correctness", correctness_terms)
    return rewards
|
|
| |
# Pull the DeepMath-103K train split and keep only its first 100 rows so the
# demo run stays short.
print("[OCC-GRPO] Loading dataset...", flush=True)
dataset = load_dataset(
    "trl-lib/DeepMath-103K",
    split="train",
)
dataset = dataset.select(range(100))
print(f"[OCC-GRPO] Loaded {len(dataset)} examples", flush=True)
|
|
# Mixed-precision selection. Without CUDA, request neither flag. The previous
# `fp16=not torch.cuda.is_bf16_supported()` turned fp16 on for CPU-only hosts,
# which TrainingArguments rejects in many transformers versions. Older torch
# releases also raise from is_bf16_supported() when CUDA is unavailable.
# The capability is queried once and at most one of the two flags is enabled.
_use_cuda = torch.cuda.is_available()
_use_bf16 = _use_cuda and torch.cuda.is_bf16_supported()

args = GRPOConfig(
    output_dir="./occ_grpo_output",
    per_device_train_batch_size=2,
    max_steps=30,
    logging_steps=5,
    save_steps=30,
    learning_rate=1e-6,
    bf16=_use_bf16,
    fp16=_use_cuda and not _use_bf16,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2,
    max_completion_length=256,
    num_generations=4,
    # NOTE(review): assumed to equal per_device_train_batch_size (2) x
    # gradient_accumulation_steps (2), i.e. one prompt x num_generations (4);
    # confirm against the TRL GRPOConfig docs if these values change.
    generation_batch_size=4,
    report_to="none",
    disable_tqdm=True,
    logging_strategy="steps",
    logging_first_step=True,
    # Keep extra dataset columns (e.g. "solution") so they can reach occ_reward.
    remove_unused_columns=False,
)
|
|
# Assemble the GRPO trainer. The policy model is passed as a Hub id string,
# and occ_reward is the single reward function used to score completions.
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=occ_reward,
    args=args,
    train_dataset=dataset,
)
|
|
# Run training and report the final step/loss. Then save the trained model to a
# "final" sub-directory of the trainer's configured output directory, so the
# save location cannot drift from GRPOConfig.output_dir.
print("[OCC-GRPO] Training...", flush=True)
result = trainer.train()
print(f"[OCC-GRPO] Done. Global step: {result.global_step}, loss: {result.training_loss:.4f}", flush=True)
final_dir = f"{trainer.args.output_dir}/final"
trainer.save_model(final_dir)
print(f"[OCC-GRPO] Model saved to {final_dir}", flush=True)
|
|
# Machine-readable run summary, printed as indented JSON. The subset size is read
# from the loaded dataset so it cannot drift from the slice actually trained on.
summary = {
    "method": "GRPO with OCC cost-adjusted reward",
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "dataset": f"trl-lib/DeepMath-103K ({len(dataset)} subset)",
    "steps": result.global_step,
    "training_loss": result.training_loss,
}
print(json.dumps(summary, indent=2), flush=True)
|
|