```python
"""SDPO wiring smoke through `ClaudeCodeIngester` on Qwen2.5-0.5B-Instruct (CPU, ~30s).

This is the third example in the SDPO progression:

    examples/gsm8k_grpo/              -- plain GRPO, no SDPO (alpha=0)
    examples/gsm8k_grpo_with_sdpo/    -- SDPO on hand-crafted GSM8K problems
    examples/sdpo_with_real_traces/   -- SDPO column fed by `ClaudeCodeIngester`

Where the previous SDPO example fabricated chat messages from 2 GSM8K
prompts, this example uses `ClaudeCodeIngester` to read a Claude Code
session JSONL fixture (the same one pinned by Spike 007's tests) and
build an SDPO-shaped batch from the actual `TraceState` objects the
ingester emits.

NOTE — what this example does and does NOT prove:

  ✅ The `ClaudeCodeIngester` → `compose_loss` plumbing works end-to-end.
  ✅ The SDPO column fires (sdpo_jsd > 0) and gradients flow. The total
     loss is logged at every step; a decrease is not asserted.
  ❌ Production-quality SDPO alignment. The HINT is appended at the end
     of the messages list, so student/teacher contexts diverge in their
     right-edge tokens. The SDPO mask therefore covers DIFFERENT content
     in student vs teacher (see README §"Why the SDPO signal here is
     'wiring proof,' not 'production-quality'"). For real training,
     use `composer_replication.trainer.data_collator.ComposerDataCollator`,
     which builds properly-aligned `ctx_teacher_input_ids` and
     `sdpo_loss_mask` columns.
  ❌ Validation on a real-content session. The default fixture is
     hand-authored (for PII reasons). To run on real content, point
     `FIXTURE_PATH` at one of your own `~/.claude/projects/...` files.

This *partially* closes the V5 gap from `docs/VISION_VALIDATION.md`
§ 4.3: the ingestion pipeline is validated; the real-data run depends
on the user.

Usage:
    pip install -e ".[train]"
    python examples/sdpo_with_real_traces/run.py

Cross-references:
    - composer_replication.ingestion.claude_code.ClaudeCodeIngester
    - composer_replication.trainer.data_collator (production-grade alignment)
    - examples/gsm8k_grpo_with_sdpo/run.py (sibling — synthetic prompts)
    - spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl
      (the fixture Spike 007's tests pin against)
"""
```
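The docstring uses `sdpo_jsd > 0` as the signal that the SDPO column fired. The following is a minimal sketch of a masked per-token Jensen-Shannon divergence between student and teacher next-token distributions. It assumes the symmetric JSD (equal 0.5 mixture weights, natural log) averaged over masked positions; how `compose_loss` actually computes or weights `sdpo_jsd` is not shown in this file and may differ. `masked_jsd` is a hypothetical helper, not part of `composer_replication`.

```python
import math

import torch
import torch.nn.functional as F


def masked_jsd(student_logits: torch.Tensor,
               teacher_logits: torch.Tensor,
               mask: torch.Tensor) -> torch.Tensor:
    """Mean Jensen-Shannon divergence (natural log) over positions where mask == 1.

    student_logits, teacher_logits: (B, T, V) next-token logits.
    mask: (B, T) tensor of 0/1 flags, e.g. an sdpo_loss_mask (hypothetical usage).
    """
    log_p = F.log_softmax(student_logits, dim=-1)
    log_q = F.log_softmax(teacher_logits, dim=-1)
    # log of the mixture m = (p + q) / 2, computed in log space for stability
    log_m = torch.logsumexp(torch.stack([log_p, log_q]), dim=0) - math.log(2.0)
    # KL(p || m) and KL(q || m), each summed over the vocabulary -> (B, T)
    kl_pm = (log_p.exp() * (log_p - log_m)).sum(dim=-1)
    kl_qm = (log_q.exp() * (log_q - log_m)).sum(dim=-1)
    jsd = 0.5 * (kl_pm + kl_qm)
    mask = mask.to(jsd.dtype)
    # Average over masked positions only; the clamp avoids dividing by zero
    return (jsd * mask).sum() / mask.sum().clamp(min=1.0)
```

Under this formulation the divergence is zero only when both distributions match at every masked position, and it is bounded above by ln 2. So `sdpo_jsd > 0` shows that the teacher context changed the predictions, not that the change is meaningful. Combining it as `total = lm_ce + alpha_sdpo * masked_jsd(...)` would be consistent with the script's `total != lm_ce` check, but that weighting is an assumption, not confirmed by this file.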
```python
from __future__ import annotations

import logging
import math
import sys
import time
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from composer_replication import compose_loss
from composer_replication.ingestion.claude_code import ClaudeCodeIngester

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
N_STEPS = 5
T = 64  # sequence length — bumped vs gsm8k example because real
        # Claude Code traces have longer system + tool-result turns
LR = 1e-5
ALPHA_SDPO = 0.5
BETA_REPLAY = 0.0  # DPO column off — focus on SDPO with real traces

OUTPUT_DIR = Path(__file__).resolve().parent / "output"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# The Spike 007 fixture is the only "real" Claude Code session JSONL we
# can ship in the repo (real user sessions contain PII). It's a short
# 8-message session (find Python files > 1MB) hand-authored to match
# the actual Claude Code v2.1.143 wire format.
FIXTURE_PATH = (
    Path(__file__).resolve().parents[2]
    / "spikes" / "007-real-trace-ingestion" / "fixtures" / "synthetic_session.jsonl"
)

# Hint inserted at the END of the trace context (just before the
# assistant generation marker). This produces a non-zero SDPO signal
# but with a methodological caveat documented in the README:
#
#   The SDPO loss-mask region (right-most ALIGN_LEN positions) covers
#   DIFFERENT token content in student vs teacher — student's last
#   tokens are the user's tool-result, teacher's last tokens are the
#   HINT. So the JSD measured here reflects the model's prediction
#   divergence on different inputs, NOT a clean per-position teacher-
#   vs-student divergence on the same content.
#
# This is acceptable for a WIRING SMOKE TEST (proves the channel fires
# on real-trace input, gradients flow, code path doesn't crash) but
# would NOT be acceptable for a production training run. Production
# SDPO requires a `ComposerDataCollator` that emits properly-aligned
# `ctx_teacher_input_ids` and `sdpo_loss_mask` columns — see
# `composer_replication/trainer/data_collator.py:_build_sdpo_fields`.
#
# A more rigorous version of this example would: (a) take the assistant
# turn from the trace as the "target action" the student is predicting,
# (b) align student/teacher contexts so the assistant turn is at the
# same position in both, (c) place the HINT before that turn in the
# teacher only, (d) mask only the assistant-response positions. That's
# what the data_collator does in production. Out of scope for this
# wiring smoke test.
HINT = "Hint: prefer using `rg` over `find` for speed and simpler globs."
```
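Steps (a) to (d) above can be sketched on plain token-id lists. `build_aligned_example` is a hypothetical helper written for illustration. It is not the `ComposerDataCollator` / `_build_sdpo_fields` implementation, whose details are not shown here. The sketch assumes the context, hint, and assistant response are already tokenized.

```python
from typing import List, Tuple


def build_aligned_example(
    context: List[int],
    response: List[int],
    hint: List[int],
    T: int,
    pad_id: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper returning (student_ids, teacher_ids, mask), each of length T.

    The response is right-aligned in both rows, the hint sits just before
    the response in the teacher row only, and the mask marks exactly the
    response positions. Left-truncation drops the oldest context tokens first.
    """
    if len(response) > T:
        raise ValueError("response does not fit in T positions")

    def left_fit(ids: List[int]) -> List[int]:
        ids = ids[-T:]                           # left-truncate
        return [pad_id] * (T - len(ids)) + ids   # left-pad

    student = left_fit(context + response)
    teacher = left_fit(context + hint + response)
    mask = [0] * (T - len(response)) + [1] * len(response)
    return student, teacher, mask


# Example: the masked positions hold the same response tokens in both rows.
s, t, m = build_aligned_example([1, 2, 3, 4, 5], [90, 91], [7, 7, 7], T=6, pad_id=0)
```

Because the response occupies the same right-most positions in both rows, the mask compares student and teacher predictions on identical target tokens. The right-edge mask in this script lacks that property.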
```python
# ---------------------------------------------------------------------------
# Build inputs from a real Claude Code trace
# ---------------------------------------------------------------------------


def build_inputs_from_trace(tokenizer, fixture_path: Path) -> tuple[dict[str, torch.Tensor], int]:
    """Ingest fixture_path → TraceStates → SDPO-shaped batch.

    Each TraceState has a `["messages"]` field (list of OpenAI-style chat
    dicts up to and including the assistant turn the state was sliced
    at). We treat the LAST `B` (= 2) TraceStates as our batch, with:

      - student context = state["messages"] (no hint)
      - teacher context = state["messages"] + [{"role": "system", "content": HINT}]

    Returns (inputs_dict, n_states_total).
    """
    ingester = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=True)
    states = list(ingester.ingest(fixture_path))
    if not states:
        raise RuntimeError(
            f"No TraceState yielded from {fixture_path}; check the fixture is valid."
        )

    # Use the last 2 states (mimicking B=2 from gsm8k_grpo_with_sdpo)
    # for direct comparison. If the trace has fewer than 2 assistant
    # turns, replicate.
    if len(states) >= 2:
        batch_states = states[-2:]
    else:
        batch_states = states + states  # replicate to fill B=2

    student_msgs = [list(s["messages"]) for s in batch_states]
    # Insert HINT as a system turn at the END of the messages list
    # (just before apply_chat_template appends the assistant marker).
    # This guarantees the hint survives left-truncation AND is
    # positionally close enough to the assistant turn to actually
    # perturb the model's logits.
    teacher_msgs = [
        list(s["messages"]) + [{"role": "system", "content": HINT}]
        for s in batch_states
    ]

    student_strs = [
        tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
        for m in student_msgs
    ]
    teacher_strs = [
        tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
        for m in teacher_msgs
    ]

    # LEFT-pad + LEFT-truncate so the right edge (the assistant marker + final
    # user/tool turn) aligns across student and teacher and across the
    # batch. Same alignment discipline as gsm8k_grpo_with_sdpo per Wave 17b.
    original_pad = tokenizer.padding_side
    original_trunc = tokenizer.truncation_side
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        s_tok = tokenizer(
            student_strs, max_length=T, truncation=True,
            padding="max_length", return_tensors="pt",
        )
        t_tok = tokenizer(
            teacher_strs, max_length=T, truncation=True,
            padding="max_length", return_tensors="pt",
        )
    finally:
        # Restore the caller's tokenizer settings even if tokenization fails
        tokenizer.padding_side = original_pad
        tokenizer.truncation_side = original_trunc

    B = s_tok["input_ids"].shape[0]
    ALIGN_LEN = T // 2
    # Both masks select the right-most ALIGN_LEN positions of every row
    # (see the HINT caveat above for what that window holds in each row).
    response_mask = torch.zeros(B, T, dtype=torch.long)
    response_mask[:, -ALIGN_LEN:] = 1
    sdpo_loss_mask = response_mask.clone()

    inputs = {
        "input_ids": s_tok["input_ids"],
        "response_mask": response_mask,
        "ctx_teacher_input_ids": t_tok["input_ids"],
        "sdpo_loss_mask": sdpo_loss_mask,
    }
    return inputs, len(states)
```
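The left-pad/left-truncate step and the right-edge mask can be reproduced without a tokenizer, which makes the caveat in the HINT comment concrete. This is a sketch on toy integer ids. `left_pad_truncate` and `right_edge_mask` are hypothetical helpers that mimic `padding_side="left"`, `truncation_side="left"`, `padding="max_length"` and `response_mask[:, -ALIGN_LEN:] = 1`; they are not functions from this repo or from `transformers`.

```python
import torch


def left_pad_truncate(seqs, T, pad_id):
    """Left-truncate, then left-pad, each token-id list to exactly T ids."""
    rows = []
    for ids in seqs:
        ids = list(ids)[-T:]                            # keep the right edge
        rows.append([pad_id] * (T - len(ids)) + ids)    # pad on the left
    return torch.tensor(rows, dtype=torch.long)


def right_edge_mask(B, T, align_len):
    """(B, T) 0/1 mask selecting the right-most align_len positions of each row."""
    mask = torch.zeros(B, T, dtype=torch.long)
    if align_len > 0:  # x[:, -0:] would select every column
        mask[:, -align_len:] = 1
    return mask


# Toy batch: the teacher is the student context with a 4-token "hint" appended.
student = left_pad_truncate([list(range(1, 11))], T=8, pad_id=0)
teacher = left_pad_truncate([list(range(1, 11)) + [77] * 4], T=8, pad_id=0)
window = right_edge_mask(1, 8, 4).bool()
```

With the hint appended to the teacher, the masked window holds the student's last context tokens in one row and the hint tokens in the other. That is why the script's JSD serves as a wiring signal rather than a per-position comparison on shared content.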
```python
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main() -> int:
    torch.manual_seed(42)

    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(log_path, mode="w")],
    )
    log = logging.getLogger("sdpo_with_real_traces")

    log.info("=" * 64)
    log.info("SDPO + REAL Claude Code traces + Qwen2.5-0.5B-Instruct (CPU)")
    log.info("=" * 64)

    # Verify fixture exists before model load — fast-fail
    if not FIXTURE_PATH.is_file():
        log.error("Fixture not found at %s", FIXTURE_PATH)
        log.error(
            "This example requires the Spike 007 fixture at "
            "spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl. "
            "If you're running outside the repo, set FIXTURE_PATH to any "
            "Claude Code session JSONL file."
        )
        return 1
    log.info("[1/5] Fixture path: %s (size=%d bytes)",
             FIXTURE_PATH, FIXTURE_PATH.stat().st_size)

    log.info("[2/5] Loading model + tokenizer ...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    n_params = sum(p.numel() for p in model.parameters())
    log.info(" loaded in %.1fs (%.3fB params)", time.time() - t0, n_params / 1e9)

    log.info("[3/5] Ingesting trace + building SDPO batch (T=%d) ...", T)
    inputs, n_states = build_inputs_from_trace(tokenizer, FIXTURE_PATH)
    for k, v in inputs.items():
        log.info(" %s: shape=%s dtype=%s", k, tuple(v.shape), v.dtype)
    B = inputs["input_ids"].shape[0]

    log.info("[4/5] Running %d SGD steps with alpha_sdpo=%.2f on B=%d ...",
             N_STEPS, ALPHA_SDPO, B)
    optim = torch.optim.SGD(model.parameters(), lr=LR)
    history: list[dict[str, float]] = []
    model.train()
    t0 = time.time()
    for step in range(N_STEPS):
        optim.zero_grad()
        out = compose_loss(
            model, inputs,
            alpha_sdpo=ALPHA_SDPO, beta_replay=BETA_REPLAY,
        )
        out.total.backward()
        # Sum of |g| over all parameter gradients (an L1 quantity, not the
        # usual L2 grad norm); used only as a non-zero/finite sanity signal.
        gnorm = sum(
            p.grad.abs().sum().item() for p in model.parameters() if p.grad is not None
        )
        optim.step()
        components = out.detached()
        components["grad_norm"] = gnorm
        history.append(components)
        log.info(
            " step %d/%d: total=%.4f lm_ce=%.4f sdpo_jsd=%.4f trace_replay_dpo=%.4f |grad|=%.2e",
            step + 1, N_STEPS,
            components["total"], components["lm_ce"],
            components["sdpo_jsd"], components["trace_replay_dpo"],
            gnorm,
        )
    dt = time.time() - t0
    log.info("Training complete in %.1fs (avg %.1fs/step)", dt, dt / N_STEPS)

    log.info("[5/5] Verifying SDPO column wiring on real trace ...")
    sdpo_values = [h["sdpo_jsd"] for h in history]
    assert all(s > 0.0 for s in sdpo_values), (
        f"SDPO column is identically zero — channel did not fire on real trace. "
        f"sdpo_jsd values: {sdpo_values}"
    )
    log.info(" ✓ sdpo_jsd > 0 at every step (min=%.4f max=%.4f)",
             min(sdpo_values), max(sdpo_values))

    diffs = [abs(h["total"] - h["lm_ce"]) for h in history]
    assert all(d > 1e-6 for d in diffs), (
        f"total ≈ lm_ce — SDPO contribution negligible. abs(total-lm_ce)={diffs}"
    )
    log.info(" ✓ total != lm_ce at every step (min |diff|=%.4f max=%.4f)",
             min(diffs), max(diffs))

    gnorms = [h["grad_norm"] for h in history]
    assert all(g > 0.0 for g in gnorms), f"Some steps had zero |grad|: {gnorms}"
    assert all(math.isfinite(g) for g in gnorms), f"Non-finite |grad|: {gnorms}"
    log.info(" ✓ |grad| > 0 and finite at every step (min=%.2e max=%.2e)",
             min(gnorms), max(gnorms))

    log.info("=" * 64)
    log.info("Summary")
    log.info("=" * 64)
    log.info(" trace fixture: %s", FIXTURE_PATH.name)
    log.info(" ingested states: %d (used last %d for batch)", n_states, B)
    log.info(" steps: %d", N_STEPS)
    log.info(" alpha_sdpo: %.2f", ALPHA_SDPO)
    log.info(" model params: %.3fB", n_params / 1e9)
    log.info(" total step 1: %.4f", history[0]["total"])
    log.info(" total step %d: %.4f", N_STEPS, history[-1]["total"])
    log.info(" wall-clock: %.1fs", dt)
    log.info("=" * 64)
    log.info("✅ SDPO column wiring verified end-to-end on ClaudeCodeIngester output.")
    return 0


if __name__ == "__main__":
    sys.exit(main())
```
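The training loop logs `gnorm` under the label `|grad|`, but the value is the sum of absolute gradient entries (an L1 quantity). "Gradient norm" usually means the global L2 norm. The sketch below contrasts the two on a toy `torch.nn.Linear`. `grad_l1_sum` and `grad_l2_norm` are hypothetical helpers. The cross-check uses `torch.nn.utils.clip_grad_norm_` with `max_norm=float("inf")`, which returns the total L2 norm and leaves the gradients unscaled.

```python
import math

import torch


def grad_l1_sum(model: torch.nn.Module) -> float:
    """Sum of |g| over every parameter gradient (the quantity the script logs)."""
    return sum(p.grad.abs().sum().item() for p in model.parameters() if p.grad is not None)


def grad_l2_norm(model: torch.nn.Module) -> float:
    """Global L2 norm sqrt(sum g^2) across all parameter gradients."""
    sq = sum(p.grad.pow(2).sum().item() for p in model.parameters() if p.grad is not None)
    return math.sqrt(sq)


torch.manual_seed(0)
toy = torch.nn.Linear(4, 3)
toy(torch.randn(5, 4)).pow(2).sum().backward()
l1, l2 = grad_l1_sum(toy), grad_l2_norm(toy)
```

Either quantity is zero exactly when every gradient is zero, and non-finite when any entry is. That makes either one adequate for the script's sanity checks. The L2 value is the one comparable to typical gradient-clipping thresholds.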