Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """SDPO real-trace TRAINING smoke — the missing forward+backward+step link. | |
| Why this exists | |
| --------------- | |
| The framework already proves two halves of the SDPO loop in isolation: | |
| * ``examples/validate_real_trace_alignment/run.py`` proves the REAL-trace | |
| DATA path: ingestion -> adapter -> collator emits a batch whose | |
| ``sdpo_loss_mask`` lands on content tokens at ~100% alignment. | |
| * ``composer_replication/tests/test_gradient_flow.py`` proves | |
| ``compose_loss`` routes finite non-zero gradients through the SDPO | |
| channel — but only on a millisecond ``TinyLM`` stand-in (no HF model). | |
| Nobody has connected them: an actual forward + backward + optimizer step | |
| of ``compose_loss`` on a REAL HuggingFace model fed by the REAL-trace | |
| collator. That is the one unproven edge, and it is exactly the (never | |
| implemented) ``sdpo_with_real_traces_production`` module that the Modal | |
| ``stage_4_sdpo_smoke`` referenced. This script IS that module, made real. | |
| What it asserts (the smoke gates) | |
| --------------------------------- | |
| 1. The collated real-trace batch drives ``compose_loss`` without crashing. | |
| 2. ``total`` loss is finite (not NaN/Inf) across all steps. | |
| 3. The SDPO channel actually FIRES: ``sdpo_jsd`` is strictly > 0 on at | |
| least one step (proves the shape-gate at loss.py:163 passed and the | |
| hint-conditioned teacher forward contributed real signal — not the | |
| silent no-op that the empty-placeholder stage_4 would have produced). | |
| 4. A real parameter MOVED after ``optimizer.step()`` (training happened). | |
| Runs on CPU against Qwen/Qwen2.5-0.5B-Instruct — the correct target for | |
| agent-trace SDPO (a small instruct model with a chat template), NOT the | |
| trained nanochat (which has no tool-use / error-recovery structure and | |
| therefore yields 0% SDPO error sites by construction). ~$0 cost. | |
| Usage | |
| ----- | |
| python examples/sdpo_real_trace_train_smoke/run.py \ | |
| [--projects-dir ~/.claude/projects] \ | |
| [--max-sessions 6] [--model Qwen/Qwen2.5-0.5B-Instruct] \ | |
| [--max-steps 5] [--lr 1e-5] | |
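| Exit 0 = PASS (all gates), 1 = FAIL (a gate failed or the training loop | |
| crashed), 2 = SKIP (projects dir missing, no error-bearing sessions, | |
| tokenizer/model not loadable or lacking a chat template, or the collated | |
| batch has no usable SDPO sites). | |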
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| import traceback | |
| from pathlib import Path | |
def _discover_error_sessions(projects_dir: Path, limit: int) -> list[Path]:
    """Return up to ``limit`` session JSONL files that mention a tool error.

    Walks ``projects_dir`` recursively for ``*.jsonl`` files, skips subagent
    files (names starting with ``agent-``), and keeps every file whose text
    contains ``"is_error":true`` or ``"is_error": true``.  This is a plain
    substring match, not a JSON parse.

    Args:
        projects_dir: Root directory to search.
        limit: Maximum number of paths to return (applied as a slice).

    Returns:
        Matching paths ordered smallest file first; files of equal size are
        ordered by path string so the selection is reproducible.  Files that
        cannot be read or stat'ed are silently skipped.
    """
    markers = ('"is_error":true', '"is_error": true')
    hits: list[tuple[int, str, Path]] = []
    for p in projects_dir.rglob("*.jsonl"):
        if p.name.startswith("agent-"):
            continue
        try:
            # Neither marker contains a newline, so scanning line by line is
            # equivalent to searching the whole text, but it stops at the
            # first hit and never holds a large session in memory.
            with p.open(encoding="utf-8", errors="ignore") as fh:
                flagged = any(
                    marker in line for line in fh for marker in markers
                )
            if not flagged:
                continue
            # stat() is inside the try: the file may disappear or become
            # unreadable between the scan and the size lookup.
            size = p.stat().st_size
        except OSError:
            continue
        hits.append((size, str(p), p))
    # Size first (small sessions are cheaper), then path as a tie-break so
    # the result does not depend on filesystem enumeration order.
    hits.sort(key=lambda t: (t[0], t[1]))
    return [p for _, _, p in hits[:limit]]
def main() -> int:
    """Run the SDPO real-trace training smoke and return a process exit code.

    Steps, in order:
      1. Parse CLI flags and default the HF hub / transformers offline
         environment variables to ``"1"`` (existing values are kept).
      2. Check the cheap filesystem preconditions (projects dir exists, at
         least one error-bearing session) BEFORE importing torch /
         transformers / composer_replication, so the SKIP paths are fast and
         do not require the ML stack.
      3. Load the tokenizer, ingest the sessions, keep examples that have a
         turn with a truthy ``tool_error``, and collate one batch.
      4. Load the causal LM, run ``--max-steps`` iterations of
         ``compose_loss`` -> ``backward`` -> ``SGD.step`` on that batch.
      5. Evaluate the gates: every reported loss component finite on every
         step, ``sdpo_jsd > 0`` on at least one step, and the watched
         parameter changed.

    Returns:
        0 if all gates pass; 1 if a gate fails or the training loop raises;
        2 (skip) if the projects dir is missing, no error-bearing sessions or
        examples are found, the tokenizer/model cannot be loaded, the
        tokenizer has no chat template, or the collated batch has no usable
        SDPO positions.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--projects-dir", default=str(Path.home() / ".claude" / "projects"))
    ap.add_argument("--max-sessions", type=int, default=6)
    ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--max-steps", type=int, default=5)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--alpha-sdpo", type=float, default=1.0,
                    help="SDPO channel weight. Default 1.0 (not the lib default "
                         "0.1) so the smoke exercises the SDPO path strongly.")
    ap.add_argument("--max-seq-len", type=int, default=2048,
                    help="Cap collated seq len to keep the CPU forward cheap.")
    ap.add_argument("--max-examples", type=int, default=4,
                    help="Cap examples per collated batch. Qwen2.5 vocab is "
                         "151936, so fp32 logits are ~1.2GB per (example, "
                         "2048-tok) forward; SDPO does 2 forwards. Lower this "
                         "if the CPU run gets OOM-killed (exit 137).")
    ap.add_argument("--dtype", choices=["bf16", "fp32"], default="bf16",
                    help="Model+activation dtype. Default bf16 halves the giant "
                         "(B, T, 151936) logit tensors — the fp32 forward+backward "
                         "transiently hits ~27GB and trips the gateway cgroup OOM "
                         "killer (exit 137). bf16 keeps the smoke under the limit.")
    ap.add_argument(
        "--strip-thinking",
        action="store_true",
        help="Strip [THINKING] blocks. DEFAULT FALSE for SDPO: on real Claude "
             "Code traces the recovery turn is frequently pure thinking, so "
             "stripping empties ~67%% of error sites and the SDPO channel sees "
             "no signal. Keep thinking for hint-distillation.",
    )
    args = ap.parse_args()

    # setdefault: respect an explicit user override, otherwise stay offline.
    # Must happen before transformers is imported below.
    os.environ.setdefault("HF_HUB_OFFLINE", "1")
    os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

    # Cheap preconditions first: neither check needs the ML stack, so a SKIP
    # here should not pay for (or depend on) the heavy imports that follow.
    projects_dir = Path(args.projects_dir).expanduser()
    if not projects_dir.exists():
        print(f"projects dir not found: {projects_dir}")
        return 2
    sessions = _discover_error_sessions(projects_dir, args.max_sessions)
    if not sessions:
        print(f"no error-bearing sessions under {projects_dir}")
        return 2

    # Heavy / project imports are deferred to here on purpose (see above).
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from composer_replication import compose_loss
    from composer_replication.ingestion import ClaudeCodeIngester
    from composer_replication.ingestion.trace_examples import (
        claude_states_to_trace_examples,
    )
    from composer_replication.trainer.data_collator import (
        CollatorConfig,
        ComposerDataCollator,
    )

    try:
        tok = AutoTokenizer.from_pretrained(args.model)
    except Exception as e:  # noqa: BLE001
        print(f"could not load tokenizer {args.model}: {e!r}")
        return 2
    if not getattr(tok, "chat_template", None):
        print(f"{args.model} has no chat template; pick an -Instruct model")
        return 2

    # ------------------------------------------------------------------
    # Build error-bearing trace examples from the real sessions.
    # ------------------------------------------------------------------
    def hint_gen(kind, _meta):
        # Fixed, generic recovery hint; only the error kind is interpolated.
        return f"Recover from the {kind}: re-check the path/args before retrying."

    # Fall back to EOS as the pad id when the tokenizer defines no pad token.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    cfg = CollatorConfig(
        hint_generator=hint_gen,
        enable_replay_dpo=False,
        max_seq_len=args.max_seq_len,
        pad_token_id=pad_id,
    )
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)

    err_examples: list[dict] = []
    for path in sessions:
        # Best effort per session: one unparsable file must not abort the run.
        try:
            ing = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=args.strip_thinking)
            states = list(ing.ingest(path))
            examples = claude_states_to_trace_examples(states)
            for ex in examples:
                if any(t.get("tool_error") for t in ex["turns"]):
                    err_examples.append(ex)
        except Exception as e:  # noqa: BLE001
            print(f" skip {path.name[:18]}: {e!r}")
    if not err_examples:
        print("no error-turn examples extracted — nothing to train on")
        return 2
    print(f"extracted {len(err_examples)} error-bearing examples from "
          f"{len(sessions)} sessions")

    # Build ONE collated batch, capped at --max-examples to keep the CPU step
    # cheap.
    batch = collator(err_examples[:args.max_examples])
    if "sdpo_loss_mask" not in batch or "ctx_teacher_input_ids" not in batch:
        print("collated batch has no SDPO channel (no usable error sites) — "
              "cannot run the SDPO training smoke")
        return 2
    in_loss = int((batch["sdpo_loss_mask"] == 1).sum().item())
    if in_loss == 0:
        print("sdpo_loss_mask is all-zero (empty-recovery sites) — re-run "
              "WITHOUT --strip-thinking to recover SDPO signal")
        return 2
    print(f"collated batch: input_ids {tuple(batch['input_ids'].shape)}, "
          f"sdpo_loss_mask in-loss positions = {in_loss}")

    # ------------------------------------------------------------------
    # Real HF model + optimizer. CPU, dtype per --dtype, tiny LR.
    # ------------------------------------------------------------------
    print(f"loading {args.model} (CPU, {args.dtype}) ...")
    model_dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
    try:
        model = AutoModelForCausalLM.from_pretrained(args.model, dtype=model_dtype)
    except Exception as e:  # noqa: BLE001
        print(f"could not load model {args.model}: {e!r}")
        return 2
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=args.lr)

    # Snapshot the first trainable parameter to prove it moves after .step().
    watch_name, watch_param = next(
        (n, p) for n, p in model.named_parameters() if p.requires_grad
    )
    before = watch_param.detach().clone()

    # Move batch tensors to model device; non-tensor entries pass through.
    dev = next(model.parameters()).device
    inputs = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in batch.items()}

    sdpo_fired = False
    finite_all = True
    print("=" * 64)
    print(f"SDPO REAL-TRACE TRAINING SMOKE (alpha_sdpo={args.alpha_sdpo}, "
          f"steps={args.max_steps}, lr={args.lr})")
    print("=" * 64)
    try:
        for step in range(args.max_steps):
            opt.zero_grad(set_to_none=True)
            comps = compose_loss(
                model,
                inputs,
                alpha_sdpo=args.alpha_sdpo,
                beta_replay=0.0,  # DPO channel off — no ref logprobs in smoke
            )
            d = comps.detached()
            # `x == x` is False only for NaN; the membership test rejects
            # +/-inf.  Applied to every reported component, not just `total`.
            finite = all(
                (x == x) and (x not in (float("inf"), float("-inf")))
                for x in d.values()
            )
            finite_all = finite_all and finite
            if d["sdpo_jsd"] > 0.0:
                sdpo_fired = True
            comps.total.backward()
            opt.step()
            print(f" step {step}: total={d['total']:.5f} lm_ce={d['lm_ce']:.5f} "
                  f"sdpo_jsd={d['sdpo_jsd']:.5f} finite={finite}")
    except Exception as e:  # noqa: BLE001
        # Gate 1: any exception in forward/backward/step is a hard FAIL.
        print(f" CRASH during training: {e!r}")
        traceback.print_exc()
        return 1

    moved = not torch.equal(before, watch_param.detach())
    delta = float((watch_param.detach() - before).abs().max())
    print("-" * 64)
    print(f" all losses finite: {finite_all}")
    print(f" SDPO channel fired (>0): {sdpo_fired}")
    print(f" param '{watch_name[:40]}' moved: {moved} (max|Δ|={delta:.2e})")
    ok = finite_all and sdpo_fired and moved
    print(f" RESULT: {'PASS ✅' if ok else 'FAIL ❌'}")
    if not ok:
        if not finite_all:
            print(" ✗ a loss was non-finite")
        if not sdpo_fired:
            print(" ✗ SDPO channel never fired (>0) — shape-gate skipped it "
                  "or mask was empty; the distillation signal did not flow")
        if not moved:
            print(" ✗ watched parameter did not move — no real training step")
    return 0 if ok else 1
# Script entry point: propagate main()'s gate result (0 PASS / 1 FAIL /
# 2 SKIP) to the shell as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())