# Author: Baladithya Balamurugan
# Commit c11cf49 — "Wave 1: fix 8 failing tests + unblock Docker E2E + dep/doc debt"
# (Header lines captured from the hosting page: Raw / History / Blame /
# Contribute / Delete; file size 9.07 kB.)
"""Plain GRPO + verifiable reward on 100 GSM8K rows (Qwen2.5-0.5B-Instruct, CPU).
This is the minimum-viable end-to-end recipe a new user is most likely to want
from a GRPO framework: wire the framework's `ComposerReplicationTrainer` into a
real dataset (GSM8K) with a real verifiable reward (regex-extract `#### NUMBER`
and string-compare against gold) and run a couple of outer steps to verify the
training loop works.
What this script demonstrates:
- `ComposerReplicationTrainer` with `alpha_sdpo=0` and `beta_replay=0` (plain
GRPO — channels 2 and 3 disabled). This is the v0.1 recommended ablation
baseline per `docs/USER_GUIDE.md` §8 Recipe A.
- A regex-based reward that returns 1.0 when the model's `#### NUMBER` line
matches the gold answer, 0.0 otherwise. RLVR-style. No reward model.
- CPU-only execution. Slow, but works without a GPU: one outer step still
takes several minutes because TRL generates `num_generations` rollouts per
prompt, even though we keep them small (4 generations, 64 max completion
tokens).
Usage:
pip install -e ".[train]"
python examples/gsm8k_grpo/run.py
Cross-references:
- `docs/USER_GUIDE.md` §8 — Recipe A: TRL `GRPOTrainer` subclass
- `docs/INTEGRATION_RECIPES.md` Recipe 1 — minimum-viable Python script
- `docs/adrs/ADR-008-drgrpo-sdpo-live-channel.md` — SDPO design (not used here; see
`run_with_sdpo.py` for the SDPO variant)
"""
from __future__ import annotations
import logging
import os
import random
import re
import sys
import time
from pathlib import Path
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from composer_replication import ComposerReplicationTrainer
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Hugging Face Hub id of the policy model loaded in main().
MODEL_REPO: str = "Qwen/Qwen2.5-0.5B-Instruct"
# Rows taken from the head of the GSM8K train split — toy size; see README
# "Production scaling" notes.
N_TRAIN_ROWS: int = 100
# Passed to GRPOConfig as `max_steps` — just enough to verify the loop runs.
N_OUTER_STEPS: int = 2
# Rollouts per prompt (`num_generations`); kept small because this runs on CPU.
NUM_GENERATIONS: int = 4
# Prompt-length budget. NOTE: not passed to GRPOConfig by this script.
MAX_PROMPT_LEN: int = 256
# Passed to GRPOConfig as `max_completion_length`.
MAX_COMPLETION_LEN: int = 64
# Trainer output directory: an `output/` folder beside this script. It is
# created (with any missing parents) as soon as the module is imported.
OUTPUT_DIR: Path = Path(__file__).resolve().with_name("output")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# ---------------------------------------------------------------------------
# Reward function — verifiable (regex extract + match)
# ---------------------------------------------------------------------------
# GSM8K answer format: the gold answer ends with `#### NUMBER`. We require the
# model to emit the same `#### NUMBER` marker. This is the canonical RLVR
# reward used in the GRPO/DeepSeek-R1 literature on math word problems.
_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
def _extract_answer(text: str) -> str | None:
"""Pull the last `#### NUMBER` group out of `text`. Returns the numeric
string (so `'#### 72'` → `'72'`), or None if no marker is found."""
matches = _ANSWER_RE.findall(text or "")
return matches[-1].strip() if matches else None
def gsm8k_reward(completions, **kwargs) -> list[float]:
    """TRL-format reward callable: 1.0 for a correct `#### NUMBER`, else 0.0.

    Args:
        completions: list of generated completions for one batch.
            Either list[str] (text) or list[list[dict]] (conversational); we
            normalize both. TRL passes the rollout completions here.
        kwargs: arbitrary dataset columns. We expect 'gold_answer' (one str per
            completion, or a single str applied to every completion) and
            optionally 'prompts' (TRL passes the input prompts as kwargs).

    Returns:
        list[float] with len == len(completions). 1.0 if the regex-extracted
        answer matches the gold, else 0.0. All zeros when 'gold_answer' is
        not supplied.

    Raises:
        ValueError: if 'gold_answer' is a sequence whose length differs from
            ``len(completions)``. Silently truncating would return fewer
            rewards than completions.
    """
    n_completions = len(completions)
    gold = kwargs.get("gold_answer")
    if gold is None:
        return [0.0] * n_completions

    # A lone string would otherwise be zipped character by character.
    golds = [gold] * n_completions if isinstance(gold, str) else list(gold)
    if len(golds) != n_completions:
        raise ValueError(
            f"gold_answer has {len(golds)} entries but there are "
            f"{n_completions} completions"
        )

    rewards: list[float] = []
    for completion, gold_ans in zip(completions, golds):
        # Conversational completions: list of {"role", "content"} dicts.
        # `or ""` guards against messages whose content is None.
        if isinstance(completion, list):
            text = "\n".join(str(m.get("content") or "") for m in completion)
        else:
            text = str(completion)
        pred = _extract_answer(text)
        is_correct = pred is not None and pred == str(gold_ans).strip()
        rewards.append(1.0 if is_correct else 0.0)
    return rewards
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
# System message placed before every GSM8K question. Its final sentence asks
# for the `#### N` marker that the reward function extracts.
SYSTEM_PROMPT = " ".join(
    (
        "You are a math tutor.",
        "Solve the problem step by step.",
        "End your answer with `#### N` where N is the final numeric answer.",
    )
)
def build_dataset(n_rows: int = N_TRAIN_ROWS):
    """Load GSM8K train rows and shape them for TRL's GRPOTrainer.

    Args:
        n_rows: number of rows taken from the head of the ``train`` split of
            ``openai/gsm8k`` (config ``main``). Defaults to ``N_TRAIN_ROWS``.

    Returns:
        A dataset with exactly two columns: ``prompt`` (a system + user chat
        message list) and ``gold_answer`` (the numeric string extracted from
        the GSM8K ``answer`` field, or ``""`` when no ``#### NUMBER`` marker
        is found; such rows can never earn a reward).

    Raises:
        ValueError: if ``n_rows`` is less than 1 (an empty dataset would only
            fail later, when the caller inspects the first row).
    """
    if n_rows < 1:
        raise ValueError(f"n_rows must be >= 1, got {n_rows}")
    raw = load_dataset("openai/gsm8k", "main", split=f"train[:{n_rows}]")

    def _format(row):
        # TRL GRPOTrainer accepts conversational `prompt` (list[dict]). We
        # pre-extract the gold numeric answer so the reward function can do
        # an exact-match.
        gold = _extract_answer(row["answer"]) or ""
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": row["question"]},
            ],
            "gold_answer": gold,
        }

    # Drop the original `question` / `answer` columns; only the two produced
    # by `_format` remain.
    return raw.map(_format, remove_columns=raw.column_names)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """Run the plain-GRPO GSM8K recipe end to end on CPU.

    The run:
    - loads the model and tokenizer and builds the GSM8K prompt dataset;
    - trains a ``ComposerReplicationTrainer`` with ``alpha_sdpo=0`` and
      ``beta_replay=0`` for ``N_OUTER_STEPS`` steps;
    - saves the final model under ``OUTPUT_DIR / "final"`` and logs a short
      summary.

    Log records go to stdout and to ``run.log`` in the parent directory of
    ``OUTPUT_DIR``.

    Returns:
        Process exit code: 0 once training and saving have completed.
    """
    # Reproducibility
    random.seed(42)
    torch.manual_seed(42)

    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(log_path, mode="w"),
        ],
    )
    log = logging.getLogger("gsm8k_grpo")
    log.info("=" * 64)
    log.info("Plain GRPO + GSM8K + Qwen2.5-0.5B-Instruct (CPU)")
    log.info("=" * 64)

    log.info("[1/4] Loading model + tokenizer ...")
    t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        # No pad token defined: reuse EOS so padded batches can be built.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    log.info(" loaded in %.1fs (%.3fB params)",
             time.perf_counter() - t0,
             sum(p.numel() for p in model.parameters()) / 1e9)

    log.info("[2/4] Loading %d GSM8K rows ...", N_TRAIN_ROWS)
    dataset = build_dataset()
    log.info(" example row: prompt=%s ... gold=%s",
             dataset[0]["prompt"][1]["content"][:80], dataset[0]["gold_answer"])

    log.info("[3/4] Building ComposerReplicationTrainer (alpha_sdpo=0, beta_replay=0) ...")
    # Lazy import: GRPOConfig requires `trl` (in the [train] extra). The
    # framework's __init__ falls back gracefully when TRL is missing, but
    # GRPOConfig does not.
    from trl import GRPOConfig

    config = GRPOConfig(
        output_dir=str(OUTPUT_DIR),
        per_device_train_batch_size=NUM_GENERATIONS,  # 1 prompt × num_generations rollouts
        gradient_accumulation_steps=1,
        num_generations=NUM_GENERATIONS,
        # NOTE: TRL 1.5+ dropped GRPOConfig.max_prompt_length; prompts are
        # tokenized by the rollout pipeline at generation time. Use
        # tokenizer.model_max_length to bound prompts. MAX_PROMPT_LEN is not
        # applied by this script.
        max_completion_length=MAX_COMPLETION_LEN,
        learning_rate=1e-5,
        max_steps=N_OUTER_STEPS,
        logging_steps=1,
        save_strategy="no",
        report_to=[],
        # CPU-only: `use_cpu` disables CUDA/MPS auto-detection. The older
        # `no_cuda` flag is deprecated in transformers in favour of `use_cpu`
        # and scheduled for removal, so it is deliberately not passed.
        use_cpu=True,
        # Plain-GRPO sanity: disable the KL-to-reference penalty (beta=0) so
        # there's no reference-model forward pass on CPU.
        beta=0.0,
        seed=42,
        bf16=False,
        fp16=False,
    )
    trainer = ComposerReplicationTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[gsm8k_reward],
        train_dataset=dataset,
        args=config,
        # Channels 2 (SDPO) + 3 (trace-replay DPO) disabled — pure GRPO.
        alpha_sdpo=0.0,
        beta_replay=0.0,
    )

    log.info("[4/4] Training for %d outer steps ...", N_OUTER_STEPS)
    t0 = time.perf_counter()
    train_result = trainer.train()
    dt = time.perf_counter() - t0
    log.info("Training complete in %.1fs", dt)

    # Persist final state
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    trainer.save_model(str(final_dir))
    log.info("Final model saved to %s", final_dir)

    # Summary
    metrics = train_result.metrics
    # HF Trainer metrics report `train_steps_per_second` but no step count;
    # the number of optimizer steps is `TrainOutput.global_step`.
    steps = getattr(train_result, "global_step", N_OUTER_STEPS)
    log.info("=" * 64)
    log.info("Summary")
    log.info("=" * 64)
    log.info(" steps: %s", steps)
    log.info(" train_loss: %.6f", metrics.get("train_loss", float("nan")))
    log.info(" train_runtime: %.1fs", metrics.get("train_runtime", dt))
    log.info(" log file: %s", log_path)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())