occ-stack / scripts /grpo_train_occ.py
narcolepticchicken's picture
Upload scripts/grpo_train_occ.py
be7781b verified
#!/usr/bin/env python3
"""OCC GRPO Training Demo — Minimal end-to-end GRPO with OCC reward hook.
Trains Qwen2.5-0.5B-Instruct with GRPO using a cost-adjusted marginal impact reward.
The reward combines correctness, format, cost penalty, confident-wrong penalty, and abstention bonus.
Dataset: trl-lib/DeepMath-103K (column: "solution")
"""
import re, json, torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
# Substrings (matched against the lowercased completion) that flag an
# over-confident tone or an explicit abstention.
_CONFIDENT_MARKERS = ("definitely", "certainly", "obviously", "clearly")
_ABSTAIN_MARKERS = ("don't know", "cannot determine", "uncertain", "not sure")

_BOXED_OPEN = re.compile(r"\\boxed\{")
_ANSWER_PHRASE = re.compile(r"(?:answer|result|solution)\s*(?:is|=)\s*([^\s,.]+)")


def _completion_text(comp):
    """Return the text of one completion (plain string or list of chat-message dicts)."""
    if isinstance(comp, list):
        if not comp:
            # An empty message list carries no text; score it like an empty answer.
            return ""
        return str(comp[0].get("content") or "")
    return str(comp)


def _extract_boxed(text):
    r"""Return the stripped body of the first brace-balanced ``\boxed{...}`` in *text*, else None."""
    for opener in _BOXED_OPEN.finditer(text):
        start = opener.end()
        depth = 1
        # Count braces so nested groups such as \frac{1}{2} stay inside the box.
        for pos in range(start, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[start:pos].strip()
    return None


def occ_reward(completions, solution, completion_ids=None, prompts=None, **kwargs):
    r"""
    OCC cost-adjusted reward for GRPO.

    Reward = correctness + format + cost + confident_wrong + abstention, where
      * correctness: +1.0 if the extracted answer equals the reference,
        -1.0 if an answer was extracted but differs, -0.5 if none was found;
      * format: +0.1 if the completion contains "think" (case-insensitive);
      * cost: -0.001 per token (token ids if given, else whitespace words);
      * confident_wrong: -0.5 if a confidence marker appears and correctness < 0;
      * abstention: +0.3 if an abstention marker appears.

    Args:
        completions: One entry per sample; either a string or a list of
            message dicts whose first element holds the text under "content".
        solution: Reference answers, indexed in step with ``completions``
            (missing entries are treated as an empty reference).
        completion_ids: Optional per-sample token-id sequences used only for
            their length.
        prompts: Unused; accepted so the trainer can pass it by keyword.
        **kwargs: If it contains a callable ``log_extra``, it is invoked as
            ``log_extra("correctness", values)`` with +1.0 / -1.0 per sample.

    Returns:
        A list of float rewards, one per completion.
    """
    rewards = []
    correctness_flags = []
    for i, comp in enumerate(completions):
        content = _completion_text(comp)
        gt = str(solution[i]).strip() if i < len(solution) else ""
        cl = content.lower()

        # Extract final answer from \boxed{} or "answer is X".
        final = _extract_boxed(content)
        reference = gt
        if final is None:
            am = _ANSWER_PHRASE.search(cl)
            if am:
                final = am.group(1).strip()
                # The fallback answer comes from lowercased text, so the
                # reference must be lowercased too or "Yes" could never match.
                reference = gt.lower()

        if not final:
            correctness = -0.5
        elif final == reference:
            correctness = 1.0
        else:
            correctness = -1.0

        # The previous pattern r"|think" had an empty alternative and matched
        # every string, turning the format bonus into a constant.
        format_r = 0.1 if "think" in cl else 0.0

        if completion_ids is not None and i < len(completion_ids):
            n_tok = len(completion_ids[i])
        else:
            n_tok = len(content.split())
        cost_p = -0.001 * n_tok

        is_conf = any(m in cl for m in _CONFIDENT_MARKERS)
        cw_p = -0.5 if (is_conf and correctness < 0) else 0.0

        is_abst = any(m in cl for m in _ABSTAIN_MARKERS)
        abst_b = 0.3 if is_abst else 0.0

        rewards.append(correctness + format_r + cost_p + cw_p + abst_b)
        correctness_flags.append(1.0 if correctness > 0 else -1.0)

    log_extra = kwargs.get("log_extra")
    if log_extra:
        # Log the correctness component itself, not the sign of the total
        # reward, which also mixes in the cost and bonus terms.
        log_extra("correctness", correctness_flags)
    return rewards
# ═══ Main ═══
# End-to-end demo: load a small dataset slice, run a short GRPO training loop
# with occ_reward as the only reward function, save the model, print a summary.
print("[OCC-GRPO] Loading dataset...", flush=True)
# Only the first 100 training rows are used to keep the demo short.
dataset = load_dataset("trl-lib/DeepMath-103K", split="train").select(range(100))
print(f"[OCC-GRPO] Loaded {len(dataset)} examples", flush=True)

# Request mixed precision only when a CUDA device is actually present.
# Previously a CPU-only host ended up with fp16=True (bf16 probe False ->
# fp16 = not False), and on older torch versions the bf16 probe itself can
# fail without CUDA. The probe is also evaluated once instead of twice.
cuda_available = torch.cuda.is_available()
use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
use_fp16 = cuda_available and not use_bf16

args = GRPOConfig(
    output_dir="./occ_grpo_output",
    per_device_train_batch_size=2,
    max_steps=30,
    logging_steps=5,
    save_steps=30,
    learning_rate=1e-6,
    bf16=use_bf16,
    fp16=use_fp16,
    gradient_checkpointing=True,
    gradient_accumulation_steps=2,
    max_completion_length=256,
    num_generations=4,
    generation_batch_size=4,
    report_to="none",
    disable_tqdm=True,
    logging_strategy="steps",
    logging_first_step=True,
    remove_unused_columns=False,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=args,
    reward_funcs=occ_reward,
    train_dataset=dataset,
)
print("[OCC-GRPO] Training...", flush=True)
result = trainer.train()
print(f"[OCC-GRPO] Done. Global step: {result.global_step}, loss: {result.training_loss:.4f}", flush=True)
trainer.save_model("./occ_grpo_output/final")
print("[OCC-GRPO] Model saved to ./occ_grpo_output/final", flush=True)

# Machine-readable recap of the run, printed as JSON on stdout.
summary = {
    "method": "GRPO with OCC cost-adjusted reward",
    "model": "Qwen/Qwen2.5-0.5B-Instruct",
    "dataset": "trl-lib/DeepMath-103K (100 subset)",
    "steps": result.global_step,
    "training_loss": result.training_loss,
}
print(json.dumps(summary, indent=2), flush=True)