Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Plain GRPO + verifiable reward on 100 GSM8K rows (Qwen2.5-0.5B-Instruct, CPU). | |
| This is the minimum-viable end-to-end recipe a new user is most likely to want | |
| from a GRPO framework: wire the framework's `ComposerReplicationTrainer` into a | |
| real dataset (GSM8K) with a real verifiable reward (regex-extract `#### NUMBER` | |
| and string-compare against gold) and run a couple of outer steps to verify the | |
| training loop works. | |
| What this script demonstrates: | |
| - `ComposerReplicationTrainer` with `alpha_sdpo=0` and `beta_replay=0` (plain | |
| GRPO — channels 2 and 3 disabled). This is the v0.1 recommended ablation | |
| baseline per `docs/USER_GUIDE.md` §8 Recipe A. | |
| - A regex-based reward that returns 1.0 when the model's `#### NUMBER` line | |
| matches the gold answer, 0.0 otherwise. RLVR-style. No reward model. | |
| - CPU-only execution. Slow but works without a GPU; one outer step takes | |
| several minutes because TRL generates `num_generations` rollouts per | |
| prompt and we keep them small (4 generations, 64 max completion tokens). | |
| Usage: | |
| pip install -e ".[train]" | |
| python examples/gsm8k_grpo/run.py | |
| Cross-references: | |
| - `docs/USER_GUIDE.md` §8 — Recipe A: TRL `GRPOTrainer` subclass | |
| - `docs/INTEGRATION_RECIPES.md` Recipe 1 — minimum-viable Python script | |
| - `docs/adrs/ADR-008-drgrpo-sdpo-live-channel.md` — SDPO design (not used here; see | |
| `run_with_sdpo.py` for the SDPO variant) | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import os | |
| import random | |
| import re | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from datasets import load_dataset | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from composer_replication import ComposerReplicationTrainer | |
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Base checkpoint loaded from the Hugging Face Hub.
MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"

# Toy-scale knobs: just enough work to prove the training loop runs end to end.
N_TRAIN_ROWS = 100  # GSM8K train rows to load — see README "Production scaling" notes
N_OUTER_STEPS = 2  # optimizer steps (fed to GRPOConfig.max_steps)
NUM_GENERATIONS = 4  # rollouts sampled per prompt; kept small for CPU

# Token budgets.
MAX_PROMPT_LEN = 256  # NOTE(review): not referenced elsewhere in this script — confirm it is still needed
MAX_COMPLETION_LEN = 64  # fed to GRPOConfig.max_completion_length

# Artifacts go to ./output next to this script; the directory is created
# eagerly, at import time.
OUTPUT_DIR = Path(__file__).resolve().parent.joinpath("output")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
| # --------------------------------------------------------------------------- | |
| # Reward function — verifiable (regex extract + match) | |
| # --------------------------------------------------------------------------- | |
| # GSM8K answer format: the gold answer ends with `#### NUMBER`. We require the | |
| # model to emit the same `#### NUMBER` marker. This is the canonical RLVR | |
| # reward used in the GRPO/DeepSeek-R1 literature on math word problems. | |
| _ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)") | |
| def _extract_answer(text: str) -> str | None: | |
| """Pull the last `#### NUMBER` group out of `text`. Returns the numeric | |
| string (so `'#### 72'` → `'72'`), or None if no marker is found.""" | |
| matches = _ANSWER_RE.findall(text or "") | |
| return matches[-1].strip() if matches else None | |
| def gsm8k_reward(completions, **kwargs): | |
| """TRL-format reward callable. | |
| Args: | |
| completions: list of generated completions for one batch. | |
| Either list[str] (text) or list[list[dict]] (conversational); we | |
| normalize both. TRL passes the rollout completions here. | |
| kwargs: arbitrary dataset columns. We expect 'gold_answer' (str) and | |
| optionally 'prompts' (TRL passes the input prompts as kwargs). | |
| Returns: | |
| list[float] with len == len(completions). 1.0 if the regex-extracted | |
| answer matches the gold, else 0.0. | |
| """ | |
| gold = kwargs.get("gold_answer") | |
| if gold is None: | |
| return [0.0] * len(completions) | |
| rewards: list[float] = [] | |
| for completion, gold_ans in zip(completions, gold, strict=False): | |
| # Conversational completions: list of {"role", "content"} dicts. | |
| if isinstance(completion, list): | |
| text = "\n".join(m.get("content", "") for m in completion) | |
| else: | |
| text = str(completion) | |
| pred = _extract_answer(text) | |
| if pred is not None and pred == str(gold_ans).strip(): | |
| rewards.append(1.0) | |
| else: | |
| rewards.append(0.0) | |
| return rewards | |
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = (
    "You are a math tutor. Solve the problem step by step. "
    + "End your answer with `#### N` where N is the final numeric answer."
)


def build_dataset():
    """Load the first ``N_TRAIN_ROWS`` GSM8K train rows in GRPO prompt format.

    Every output row has exactly two columns:
      - ``prompt``: a conversational prompt (system message ``SYSTEM_PROMPT``
        followed by the GSM8K question as the user message), the list[dict]
        form TRL's GRPOTrainer accepts.
      - ``gold_answer``: the numeric answer pulled from the GSM8K ``answer``
        field via ``_extract_answer`` ("" when no marker is found), so the
        reward function can exact-match against it.
    All original GSM8K columns are removed.
    """
    split_spec = f"train[:{N_TRAIN_ROWS}]"
    gsm8k = load_dataset("openai/gsm8k", "main", split=split_spec)

    def _to_grpo_row(example):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["question"]},
        ]
        return {
            "prompt": messages,
            "gold_answer": _extract_answer(example["answer"]) or "",
        }

    return gsm8k.map(_to_grpo_row, remove_columns=gsm8k.column_names)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """Train plain GRPO on a small GSM8K slice and save the resulting model.

    Pipeline:
    - Seed the Python and torch RNGs.
    - Log to stdout and to ``run.log`` (next to ``OUTPUT_DIR``).
    - Load ``MODEL_REPO`` in float32 on CPU and build the GSM8K prompt dataset.
    - Run ``ComposerReplicationTrainer`` for ``N_OUTER_STEPS`` steps with
      ``alpha_sdpo=0`` / ``beta_replay=0``.
    - Save the model under ``OUTPUT_DIR / "final"`` and log a short summary.

    Returns:
        Process exit code; always 0 (errors propagate as exceptions).
    """
    # Reproducibility (GRPOConfig below also receives seed=42).
    random.seed(42)
    torch.manual_seed(42)

    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[
            logging.StreamHandler(sys.stdout),
            # Explicit UTF-8 so non-ASCII question/completion text can always
            # be written, regardless of the platform's default encoding.
            logging.FileHandler(log_path, mode="w", encoding="utf-8"),
        ],
    )
    log = logging.getLogger("gsm8k_grpo")
    log.info("=" * 64)
    log.info("Plain GRPO + GSM8K + Qwen2.5-0.5B-Instruct (CPU)")
    log.info("=" * 64)

    log.info("[1/4] Loading model + tokenizer ...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        # Fall back to EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    log.info(" loaded in %.1fs (%.3fB params)",
             time.time() - t0,
             sum(p.numel() for p in model.parameters()) / 1e9)

    log.info("[2/4] Loading %d GSM8K rows ...", N_TRAIN_ROWS)
    dataset = build_dataset()
    log.info(" example row: prompt=%s ... gold=%s",
             dataset[0]["prompt"][1]["content"][:80], dataset[0]["gold_answer"])

    log.info("[3/4] Building ComposerReplicationTrainer (alpha_sdpo=0, beta_replay=0) ...")
    # Lazy import: GRPOConfig requires `trl` (in the [train] extra). The
    # framework's __init__ falls back gracefully when TRL is missing, but
    # GRPOConfig does not.
    from trl import GRPOConfig

    config = GRPOConfig(
        output_dir=str(OUTPUT_DIR),
        per_device_train_batch_size=NUM_GENERATIONS,  # 1 prompt × num_generations rollouts
        gradient_accumulation_steps=1,
        num_generations=NUM_GENERATIONS,
        # NOTE: TRL 1.5+ dropped GRPOConfig.max_prompt_length; prompts are
        # tokenized by the rollout pipeline at generation time. Use
        # tokenizer.model_max_length to bound prompts. (MAX_PROMPT_LEN is
        # therefore not passed here.)
        max_completion_length=MAX_COMPLETION_LEN,
        learning_rate=1e-5,
        max_steps=N_OUTER_STEPS,
        logging_steps=1,
        save_strategy="no",
        report_to=[],
        # CPU-only. `use_cpu` is the supported switch in transformers'
        # TrainingArguments. The older `no_cuda` flag is deprecated in its
        # favour (slated for removal in transformers v5.0); passing it adds
        # only a deprecation warning today and would break once it is removed.
        use_cpu=True,
        # Plain-GRPO sanity: disable the KL-to-reference penalty (beta=0) so
        # there's no reference-model forward pass on CPU.
        beta=0.0,
        seed=42,
        bf16=False,
        fp16=False,
    )

    trainer = ComposerReplicationTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[gsm8k_reward],
        train_dataset=dataset,
        args=config,
        # Channels 2 (SDPO) + 3 (trace-replay DPO) disabled — pure GRPO.
        alpha_sdpo=0.0,
        beta_replay=0.0,
    )

    log.info("[4/4] Training for %d outer steps ...", N_OUTER_STEPS)
    t0 = time.time()
    train_result = trainer.train()
    dt = time.time() - t0
    log.info("Training complete in %.1fs", dt)

    # Persist final state
    final_dir = OUTPUT_DIR / "final"
    final_dir.mkdir(exist_ok=True)
    trainer.save_model(str(final_dir))
    log.info("Final model saved to %s", final_dir)

    # Summary. The step count is the number of steps actually taken
    # (`TrainOutput.global_step`). The metrics dict returned by Trainer.train()
    # has no "train_steps" key, so looking the count up there would always
    # just echo the configured N_OUTER_STEPS.
    metrics = train_result.metrics
    log.info("=" * 64)
    log.info("Summary")
    log.info("=" * 64)
    log.info(" steps: %s", train_result.global_step)
    log.info(" train_loss: %.6f", metrics.get("train_loss", float("nan")))
    log.info(" train_runtime: %.1fs", metrics.get("train_runtime", dt))
    log.info(" log file: %s", log_path)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())