Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
composer-replication-framework / composer_replication /trainer /tests /test_sdpo_alignment_indices.py
| """ADR-011 — collator-emitted SDPO alignment indices + loss sentinel-masking. | |
| These tests close the strict-SDPO-raises regression: the SDPO loss requires | |
| explicit `student_response_idx`/`teacher_response_idx` (B,K) LongTensors, and | |
| the production collator must emit them. Covered acceptance gates: | |
| 1. `_mask_to_padded_indices` ragged-K shape + sentinel/valid semantics. | |
| 2. Real `ComposerDataCollator` emits the 4 alignment keys with correct | |
| shapes; student_response_idx == teacher_response_idx at valid positions. | |
| 3. THE REGRESSION: real collator → batch → `_compute_sdpo_loss` in STRICT | |
| mode (default) runs WITHOUT raising and returns a finite, non-negative loss. | |
| 4. Ragged-K: a 2-row batch with different K per row → finite loss, the K=1 | |
| row's sentinel padding does not leak into the JSD. | |
| All CPU-only and fast (stub tokenizer + tiny model — no model download). | |
| """ | |
| from __future__ import annotations | |
| import pytest | |
| import torch | |
| from composer_replication.trainer.data_collator import ( | |
| CollatorConfig, | |
| ComposerDataCollator, | |
| _mask_to_padded_indices, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Stubs (mirror the patterns in test_chat_template_alignment.py / | |
| # test_dr_grpo_config_and_alignment.py so these tests need no model cache). | |
| # --------------------------------------------------------------------------- | |
| class _StubTok: | |
| """Word-level deterministic tokenizer; apply_chat_template space-joins.""" | |
| pad_token_id = 0 | |
| def __init__(self) -> None: | |
| self._v: dict[str, int] = {"<pad>": 0, "<bos>": 1, "<eos>": 2} | |
| def _id(self, w: str) -> int: | |
| if w not in self._v: | |
| self._v[w] = len(self._v) | |
| return self._v[w] | |
| def __call__(self, text, **_k): | |
| return {"input_ids": [self._id(w) for w in text.split()] if text else []} | |
| def apply_chat_template(self, messages, tokenize=True, **_k): # noqa: ARG002 | |
| return [self._id(w) for w in " ".join(m.get("content", "") for m in messages).split()] | |
| class _TinyLM(torch.nn.Module): | |
| """Minimal HF-style model: model(input_ids=...).logits. | |
| Position-DEPENDENT: adds a learned positional bias so identical token ids at | |
| DIFFERENT sequence positions produce DIFFERENT logits. This matters for the | |
| SDPO regression test — student and teacher share the same response token ids | |
| but at different absolute positions (the hint/placeholder shifts them), so a | |
| position-independent model would give JSD≈0 and mask a real misalignment bug. | |
| """ | |
| def __init__(self, vocab: int = 64, hidden: int = 8, max_pos: int = 512): | |
| super().__init__() | |
| self.embed = torch.nn.Embedding(vocab, hidden) | |
| self.pos = torch.nn.Embedding(max_pos, hidden) | |
| self.head = torch.nn.Linear(hidden, vocab) | |
| def forward(self, input_ids: torch.Tensor): | |
| T = input_ids.size(1) | |
| positions = torch.arange(T, device=input_ids.device).unsqueeze(0) | |
| h = self.embed(input_ids) + self.pos(positions) | |
| logits = self.head(h) | |
| class _Out: | |
| pass | |
| out = _Out() | |
| out.logits = logits | |
| return out | |
| def _hint_gen(kind, _meta): | |
| return "HINT search before reading" | |
def _make_sdpo_trainer():
    """Build a ComposerReplicationTrainer without running its ``__init__``.

    The instance is created with ``__new__`` so no trainer construction runs;
    only the attributes set below are populated. The tests in this module use
    it solely to call ``_compute_sdpo_loss`` in STRICT alignment mode.
    """
    # Imported lazily so collecting this module does not import the trainer.
    from composer_replication.trainer.composer_trainer import ComposerReplicationTrainer

    obj = ComposerReplicationTrainer.__new__(ComposerReplicationTrainer)
    obj.alpha_sdpo = 1.0
    obj.sdpo_jsd_beta = 0.5
    obj.sdpo_temperature = 1.0
    obj.sdpo_token_clip = None
    obj.strict_sdpo_alignment = True  # the default / production setting
    return obj
| def _error_trace(trace_id: str, recovery: str = "let me use a real tool instead"): | |
| return { | |
| "trace_id": trace_id, | |
| "turns": [ | |
| {"role": "user", "content": "do the task now"}, | |
| {"role": "user", "content": "tool not found error occurred"}, | |
| { | |
| "role": "assistant", | |
| "content": recovery, | |
| "tool_error": "tool_not_found", | |
| "error_meta": {}, | |
| }, | |
| ], | |
| "final_reward": 0.0, | |
| } | |
| # --------------------------------------------------------------------------- | |
| # Gate 1 — _mask_to_padded_indices ragged-K semantics | |
| # --------------------------------------------------------------------------- | |
def test_mask_to_padded_indices_ragged_k():
    """2 rows, K=3 and K=1 → (2,3) idx; row1 tail padded with -1;
    valid[1] == [True, False, False]."""
    mask = torch.tensor(
        [
            [0, 1, 1, 0, 1],  # K=3 at positions 1,2,4
            [0, 0, 1, 0, 0],  # K=1 at position 2
        ],
        dtype=torch.long,
    )
    idx, valid = _mask_to_padded_indices(mask)

    # Both outputs are padded to the widest row (K_max == 3).
    assert idx.shape == (2, 3)
    assert valid.shape == (2, 3)

    # Row 0 is fully valid; row 1 carries the -1 sentinel in its padded tail.
    assert idx[0].tolist() == [1, 2, 4]
    assert idx[1].tolist() == [2, -1, -1]
    assert valid[0].tolist() == [True, True, True]
    assert valid[1].tolist() == [True, False, False]

    assert idx.dtype == torch.long
    assert valid.dtype == torch.bool
def test_mask_to_padded_indices_empty_returns_b0():
    """K_max == 0 (no valid positions) returns (B,0) tensors."""
    mask = torch.zeros(3, 5, dtype=torch.long)
    idx, valid = _mask_to_padded_indices(mask)
    # The batch dimension is kept; the K dimension collapses to zero.
    assert idx.shape == (3, 0)
    assert valid.shape == (3, 0)
| # --------------------------------------------------------------------------- | |
| # Gate 2 — collator emits the 4 alignment keys with correct shapes | |
| # --------------------------------------------------------------------------- | |
def test_collator_emits_alignment_indices_keys():
    """The real collator emits all 4 alignment keys with matching shapes and
    dtypes, and student/teacher indices agree at valid positions."""
    tok = _StubTok()
    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([_error_trace("align-1")])

    for key in (
        "student_response_idx",
        "teacher_response_idx",
        "student_response_valid",
        "teacher_response_valid",
    ):
        assert key in batch, f"collator did not emit {key!r}"

    s_idx = batch["student_response_idx"]
    t_idx = batch["teacher_response_idx"]
    s_valid = batch["student_response_valid"]
    t_valid = batch["teacher_response_valid"]

    # Index and validity tensors must all share one shape.
    assert s_idx.shape == t_idx.shape
    assert s_idx.shape == s_valid.shape
    assert t_idx.shape == t_valid.shape

    # Indices are LongTensors, validity masks are bool — on both sides, so the
    # teacher keys are checked for more than mere presence.
    assert s_idx.dtype == torch.long
    assert t_idx.dtype == torch.long
    assert s_valid.dtype == torch.bool
    assert t_valid.dtype == torch.bool

    # There must be at least one valid aligned position.
    assert int(s_valid.sum()) > 0

    # At valid positions the placeholder-trick makes the two indices identical.
    assert torch.equal(s_idx[s_valid], t_idx[s_valid]), (
        "student/teacher indices diverge at valid positions; the placeholder "
        "alignment trick is broken."
    )
| # --------------------------------------------------------------------------- | |
| # Gate 3 — THE REGRESSION TEST: real collator → strict _compute_sdpo_loss | |
| # --------------------------------------------------------------------------- | |
def test_strict_sdpo_loss_runs_on_real_collator_batch():
    """Real ComposerDataCollator batch → _compute_sdpo_loss in STRICT mode
    (default) runs WITHOUT raising and returns a finite, non-negative,
    differentiable loss.

    This is the whole point of ADR-011."""
    tok = _StubTok()
    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([_error_trace("regression-1")])

    # vocab must cover every token id the stub tokenizer produced.
    vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
    model = _TinyLM(vocab=max(vocab, 8))
    obj = _make_sdpo_trainer()

    loss = obj._compute_sdpo_loss(model, batch)  # must NOT raise
    val = float(loss.detach())

    # One check covers NaN, +inf and -inf.
    assert torch.isfinite(loss.detach()).all(), f"SDPO loss must be finite, got {val}"

    # JSD is always >= 0. The stub model has no attention, so its logits
    # depend only on (token id, position); when student and teacher gather the
    # same token ids at the same positions the JSD floors at ~0. That is the
    # CORRECT answer for a perfectly-aligned identical model, not a bug, so
    # only non-negativity (with a small float tolerance) is asserted here.
    # Strict positivity needs an attention model (covered by
    # examples/composer_grpo_sdpo_smoke on Qwen2.5-0.5B).
    assert val >= -1e-6, f"JSD must be non-negative, got {val}"
    assert loss.requires_grad, "SDPO loss must be differentiable (grad path)"
| # --------------------------------------------------------------------------- | |
| # Gate 4 — ragged-K batch: K=1 row padding must not leak into the loss | |
| # --------------------------------------------------------------------------- | |
def test_ragged_k_batch_finite_loss_no_padding_leak():
    """A 2-row batch with different recovery lengths → ragged K. The loss must
    be finite and the short row's sentinel padding must not contribute."""
    tok = _StubTok()
    cfg = CollatorConfig(hint_generator=_hint_gen, enable_replay_dpo=False)
    collator = ComposerDataCollator(tokenizer=tok, config=cfg)
    batch = collator([
        _error_trace("ragged-long", recovery="recover with a real working tool now please"),
        _error_trace("ragged-short", recovery="ok"),
    ])
    s_idx = batch["student_response_idx"]
    s_valid = batch["student_response_valid"]

    # Guard against a vacuous pass: `.all()` over an empty selection is True,
    # so without a padded slot the sentinel check below would prove nothing.
    assert bool((~s_valid).any()), "batch is not ragged: no padded (invalid) slot present"
    assert (s_idx == -1)[~s_valid].all(), "invalid positions must hold sentinel -1"

    vocab = int(max(batch["input_ids"].max(), batch["ctx_teacher_input_ids"].max())) + 1
    model = _TinyLM(vocab=max(vocab, 8))
    obj = _make_sdpo_trainer()

    loss = obj._compute_sdpo_loss(model, batch)
    val = float(loss.detach())

    # The leak failure mode this guards against is a sentinel (-1) index
    # reaching the gather, or a padding position contributing garbage
    # (NaN/inf). A finite, non-negative scalar (JSD floor, with a small float
    # tolerance) shows neither happened.
    assert torch.isfinite(loss.detach()).all(), f"ragged-K loss must be finite, got {val}"
    assert val >= -1e-6

    # Second pass with every alignment tensor truncated to the per-batch
    # min-K. The full-batch loss includes the long row's extra valid tokens,
    # so the two losses need not be equal; only finiteness is asserted.
    min_k = int(s_valid.sum(dim=1).min())
    truncated = dict(batch)
    for key in (
        "student_response_idx",
        "teacher_response_idx",
        "student_response_valid",
        "teacher_response_valid",
    ):
        truncated[key] = batch[key][:, :min_k]

    # Same model state (no grad step taken) → deterministic forward.
    loss_trunc = obj._compute_sdpo_loss(model, truncated)
    vt = float(loss_trunc.detach())
    assert torch.isfinite(loss_trunc.detach()).all(), f"truncated loss must be finite, got {vt}"