File size: 3,635 Bytes
83ce3a2
 
 
be7781b
 
83ce3a2
be7781b
83ce3a2
 
be7781b
 
83ce3a2
 
be7781b
83ce3a2
be7781b
 
 
 
83ce3a2
 
be7781b
 
 
 
 
 
 
 
 
 
83ce3a2
be7781b
 
83ce3a2
be7781b
 
 
 
 
 
 
 
83ce3a2
be7781b
83ce3a2
 
 
 
 
be7781b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#!/usr/bin/env python3
"""OCC GRPO Training Demo — Minimal end-to-end GRPO with OCC reward hook.

Trains Qwen2.5-0.5B-Instruct with GRPO using a cost-adjusted marginal impact reward.
The reward combines correctness, format, cost penalty, confident-wrong penalty, and abstention bonus.

Dataset: trl-lib/DeepMath-103K (column: "solution")
"""

import re, json, torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

# "The answer/result/solution is|= X" fallback. Matched case-insensitively on the
# original text so the captured answer keeps its case; '.' is allowed inside the
# token so decimals like "3.5" survive (a trailing sentence period is stripped later).
_ANSWER_RE = re.compile(r"(?:answer|result|solution)\s*(?:is|=)\s*([^\s,]+)", re.IGNORECASE)
# Format reward requires a complete <think>...</think> block (may span lines).
_THINK_RE = re.compile(r"<think>.*?</think>", re.IGNORECASE | re.DOTALL)
_CONFIDENT_MARKERS = ("definitely", "certainly", "obviously", "clearly")
_ABSTAIN_MARKERS = ("don't know", "cannot determine", "uncertain", "not sure")


def _extract_boxed(text):
    """Return the contents of the last brace-balanced ``\\boxed{...}`` in *text*, or None.

    Braces are counted so ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}`` instead of
    stopping at the first ``}``. The last balanced box is treated as the final answer;
    an unterminated box (e.g. a truncated completion) is ignored.
    """
    marker = "\\boxed{"
    found = None
    pos = text.find(marker)
    while pos != -1:
        start = pos + len(marker)
        depth = 1
        for end in range(start, len(text)):
            ch = text[end]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    found = text[start:end]
                    break
        pos = text.find(marker, start)
    return found


def _extract_final_answer(content):
    """Return the completion's final answer string, or None if none can be parsed.

    Prefers the last balanced ``\\boxed{}``; otherwise falls back to the token after
    "answer/result/solution is|=". An empty box yields "" (treated as no answer).
    """
    boxed = _extract_boxed(content)
    if boxed is not None:
        return boxed.strip()
    m = _ANSWER_RE.search(content)
    if m:
        return m.group(1).rstrip(".")
    return None


def occ_reward(completions, solution, completion_ids=None, prompts=None, **kwargs):
    """
    OCC cost-adjusted reward for GRPO.

    Dataset must have 'solution' column (DeepMath-103K format).
    Reward = correctness(+1.0 right / -1.0 wrong / -0.5 no parseable answer)
             + format(+0.1 for a <think>...</think> block) + cost(-0.001/tok)
             + confident_wrong(-0.5) + abstention(+0.3)

    Args:
        completions: one entry per sampled completion, either a plain string or a
            conversational list whose first message dict holds the text under "content".
        solution: ground-truth answers, aligned by index with ``completions``;
            missing entries are treated as "".
        completion_ids: optional per-completion token ids. When given, their length is
            the token count for the cost penalty; otherwise whitespace-split words are counted.
        prompts: accepted for the trainer's calling convention; unused here.
        **kwargs: remaining trainer/dataset extras. If a callable ``log_extra`` is present
            it is called once as ``log_extra("correctness", [per-completion correctness term])``.

    Returns:
        list[float]: one scalar reward per completion.
    """
    rewards = []
    correctness_terms = []
    for i, comp in enumerate(completions):
        content = comp[0].get("content", "") if isinstance(comp, list) else str(comp)
        gt = str(solution[i]).strip() if i < len(solution) else ""
        cl = content.lower()

        final = _extract_final_answer(content)
        if not final:
            correctness = -0.5  # nothing parseable: milder than a wrong answer
        elif final == gt:
            correctness = 1.0
        else:
            correctness = -1.0

        format_r = 0.1 if _THINK_RE.search(content) else 0.0
        n_tok = len(completion_ids[i]) if completion_ids is not None else len(content.split())
        cost_p = -0.001 * n_tok
        is_conf = any(m in cl for m in _CONFIDENT_MARKERS)
        cw_p = -0.5 if (is_conf and correctness < 0) else 0.0
        is_abst = any(m in cl for m in _ABSTAIN_MARKERS)
        abst_b = 0.3 if is_abst else 0.0

        correctness_terms.append(correctness)
        rewards.append(correctness + format_r + cost_p + cw_p + abst_b)

    log_extra = kwargs.get("log_extra")
    if log_extra:
        # Log the correctness component itself; the sign of the total reward is
        # blurred by the format/cost/abstention terms.
        log_extra("correctness", correctness_terms)
    return rewards

# ═══ Main ═══
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
DATASET_NAME = "trl-lib/DeepMath-103K"
OUTPUT_DIR = "./occ_grpo_output"
N_EXAMPLES = 100


def main():
    """Run the demo end to end.

    Loads a DeepMath subset, trains briefly with GRPO and ``occ_reward``, saves the
    final model, and prints a JSON summary.
    """
    print("[OCC-GRPO] Loading dataset...", flush=True)
    dataset = load_dataset(DATASET_NAME, split="train").select(range(N_EXAMPLES))
    print(f"[OCC-GRPO] Loaded {len(dataset)} examples", flush=True)

    # Request mixed precision only when a CUDA device is present: bf16 where the
    # GPU supports it, otherwise fp16. Without this guard a CPU-only machine would
    # get fp16=True, which the Trainer's argument validation rejects.
    use_cuda = torch.cuda.is_available()
    use_bf16 = use_cuda and torch.cuda.is_bf16_supported()

    args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=2,
        max_steps=30,
        logging_steps=5,
        save_steps=30,
        learning_rate=1e-6,
        bf16=use_bf16,
        fp16=use_cuda and not use_bf16,
        gradient_checkpointing=True,
        gradient_accumulation_steps=2,
        max_completion_length=256,
        num_generations=4,
        generation_batch_size=4,
        report_to="none",
        disable_tqdm=True,
        logging_strategy="steps",
        logging_first_step=True,
        # Keep extra dataset columns (e.g. "solution") so they can be forwarded to occ_reward.
        remove_unused_columns=False,
    )

    trainer = GRPOTrainer(
        model=MODEL_NAME,
        args=args,
        reward_funcs=occ_reward,
        train_dataset=dataset,
    )

    print("[OCC-GRPO] Training...", flush=True)
    result = trainer.train()
    print(f"[OCC-GRPO] Done. Global step: {result.global_step}, loss: {result.training_loss:.4f}", flush=True)

    final_dir = f"{OUTPUT_DIR}/final"
    trainer.save_model(final_dir)
    print(f"[OCC-GRPO] Model saved to {final_dir}", flush=True)

    summary = {
        "method": "GRPO with OCC cost-adjusted reward",
        "model": MODEL_NAME,
        "dataset": f"{DATASET_NAME} ({N_EXAMPLES} subset)",
        "steps": result.global_step,
        "training_loss": result.training_loss,
    }
    print(json.dumps(summary, indent=2), flush=True)


if __name__ == "__main__":
    main()