Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto") - Notebooks
- Google Colab
- Kaggle
File size: 10,077 Bytes
03bf323 6806cf7 03bf323 6806cf7 03bf323 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | """Adapter: ClaudeCodeIngester output → ComposerDataCollator input.
The ingester (`composer_replication.ingestion.claude_code.ClaudeCodeIngester`)
emits `TraceState` dicts with a `messages` field — a list of OpenAI-style
chat dicts. The data collator (`composer_replication.trainer.data_collator
.ComposerDataCollator`) expects `TraceExample` dicts with a `turns` field —
a list of `TraceTurn` dicts where each turn carries its own role, content,
and (critically) `tool_error` field for SDPO error-site detection.
This module bridges the two. The adapter:
1. Consumes a `TraceState` from the ingester.
2. Converts its `messages` (chat dicts) → `turns` (TraceTurns).
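3. Detects tool-error sites from the structural `tool_error` flag on a
user message when one is present, and otherwise by looking for the
`[TOOL_RESULT (ERROR)]` tag the ingester writes (per Claude Code's
`is_error: true` flag in the source JSONL).
4. Marks the first assistant turn that follows an error tool-result
(intervening non-error user turns are skipped over; any other role
in between cancels the match) with `tool_error="<error_kind>"` so the
data collator's `_build_hint_injected_trace` recognizes it as an SDPO
error site.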
Usage:
from composer_replication.ingestion import ClaudeCodeIngester
from composer_replication.ingestion.trace_examples import (
claude_states_to_trace_examples,
)
from composer_replication.trainer.data_collator import (
ComposerDataCollator, CollatorConfig,
)
ingester = ClaudeCodeIngester()
states = list(ingester.ingest(session_jsonl_path))
examples = claude_states_to_trace_examples(states)
config = CollatorConfig(
hint_generator=lambda kind, meta: "Hint: try a different path.",
enable_replay_dpo=False,
)
collator = ComposerDataCollator(tokenizer=tok, config=config)
batch = collator(examples)
# batch now has properly-aligned ctx_teacher_input_ids + sdpo_loss_mask
This is the production-grade alignment path. Wave 18's
`examples/sdpo_with_real_traces/` is a wiring smoke that bypasses this
adapter; Wave 19's `examples/sdpo_with_real_traces_production/` uses
this adapter for the real alignment.
"""
from __future__ import annotations
import re
from typing import Any, Iterable, Mapping
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# The ingester writes this tag for tool_results where the source JSONL had
# is_error: true. Matches the `tag = "[TOOL_RESULT (ERROR)]"` literal
# in `composer_replication.ingestion.claude_code._serialize_user_content`.
#
# The tag is only a FALLBACK signal: `_user_turn_has_error` string-matches
# it in the rendered user-turn content when the message dict carries no
# structural `tool_error` flag (or carries it as None).
TOOL_ERROR_TAG = "[TOOL_RESULT (ERROR)]"
# Heuristic: classify the error_kind by simple keyword match on the error
# content. The data collator's `hint_generator` receives this string as
# its first argument so the hint can be tailored. These categories are a
# minimal v0 set; users can extend by passing their own classifier
# function via the `error_kind_fn` parameter.
#
# Each entry is a `(kind, compiled_pattern)` pair. Every pattern is
# case-insensitive (inline `(?i)`), and `default_classify_error` returns
# the kind of the FIRST entry whose pattern is found anywhere in the text.
_ERROR_KIND_PATTERNS = [
    # First match wins, so list order decides the reported kind whenever a
    # message matches more than one pattern.
    # NOTE(review): the file_not_found pattern below only matches its three
    # full phrases, not a generic "not found", so it does not currently
    # overlap with command_not_found; command_not_found is kept first in
    # case that pattern is ever broadened.
    ("command_not_found", re.compile(r"(?i)command not found")),
    ("file_not_found", re.compile(r"(?i)\b(file does not exist|no such file or directory|file not found)\b")),
    ("permission_denied", re.compile(r"(?i)permission denied")),
    # `\s*` accepts "syntax error", "SyntaxError", and "syntaxerror" alike.
    ("syntax_error", re.compile(r"(?i)syntax\s*error")),
    # Requires exactly one space between the two words, e.g.
    # "connection refused" or "network error".
    ("connection_error", re.compile(r"(?i)\b(connection|network|timeout) (error|refused)\b")),
]
def default_classify_error(content: str) -> str:
    """Map a tool-error message onto a short ``error_kind`` label.

    The message is tested against each ``(kind, pattern)`` entry of
    ``_ERROR_KIND_PATTERNS`` in list order, and the kind of the first
    pattern found anywhere in ``content`` is returned.

    Args:
        content: rendered text of the failing tool result.

    Returns:
        The matching category name, or the generic ``"tool_error"`` when
        no pattern matches. Callers wanting different categories can pass
        their own ``error_kind_fn`` to ``claude_states_to_trace_examples``.
    """
    # Lazy generator: patterns after the first hit are never evaluated.
    matching_kinds = (
        kind
        for kind, pattern in _ERROR_KIND_PATTERNS
        if pattern.search(content)
    )
    return next(matching_kinds, "tool_error")
def _user_turn_has_error(msg: Mapping[str, Any], flat_content: str) -> bool:
"""Decide whether a user-role turn is a tool-error site.
Precedence (Wave 20 — eliminate TOOL_ERROR_TAG string-coupling):
1. **Structural flag** — if the message dict carries an explicit
``tool_error`` key, trust it as the source of truth. The ingester sets
``tool_error: True`` whenever the source JSONL had ``is_error: true``.
A third-party producer can set ``tool_error: False`` to assert "no
error here" even if the rendered text happens to contain the tag.
2. **String-tag fallback** — only when no structural flag is present
(older serialized traces, or producers that never learned the boolean
contract) do we fall back to matching ``TOOL_ERROR_TAG`` in the
rendered content. This keeps backward compatibility without making the
brittle string match the primary path.
Returns True iff the turn should trigger SDPO error-site handling.
"""
structural = msg.get("tool_error")
if structural is not None:
return bool(structural)
return TOOL_ERROR_TAG in flat_content
# ---------------------------------------------------------------------------
# Adapter
# ---------------------------------------------------------------------------
def _flatten_content(content: Any) -> Any:
    """Render a chat message's ``content`` value as one string where possible.

    * ``None`` (allowed by OpenAI-style chat dicts) becomes ``""`` so the
      downstream tag match and classifier never receive ``None``.
    * A list of parts is joined with newlines; a dict part contributes its
      ``"text"`` value (or its own ``str()`` when it has none), any other
      part contributes ``str(part)``.
    * Anything else (normally already a ``str``) is returned unchanged.
    """
    if content is None:
        return ""
    if isinstance(content, list):
        # Defensive: some tokenizers / chat formats hand back lists.
        return "\n".join(
            str(part.get("text", part)) if isinstance(part, dict) else str(part)
            for part in content
        )
    return content


def claude_states_to_trace_examples(
    states: Iterable[Mapping[str, Any]],
    *,
    error_kind_fn=default_classify_error,
    final_reward: float = 0.0,
) -> list[dict[str, Any]]:
    """Convert ClaudeCodeIngester TraceState dicts → TraceExample dicts.

    Each input state's `messages` list (OpenAI-style chat dicts) is
    rewritten as a `turns` list of `{role, content}` dicts, with list or
    `None` content rendered to a single string.

    Error-site rule: an assistant turn is marked when the unbroken run of
    user-role messages directly before it contains a message that
    `_user_turn_has_error` accepts (structural `tool_error` flag first,
    `[TOOL_RESULT (ERROR)]` tag as fallback). Non-error user messages
    between the error and the assistant are skipped over, so a follow-up
    tool_result does not hide the error; any message with another role
    ends the run. When several user messages in the run are errors, the
    one nearest the assistant is used. The marked turn receives:

    * `tool_error`: the string returned by `error_kind_fn` (this is the
      field `ComposerDataCollator._build_hint_injected_trace` checks via
      `_is_error_turn`);
    * `error_meta`: `{"source_role": "user", "source_content_excerpt":
      <first 200 chars of the error content>}`.

    If `error_kind_fn` returns a falsy value the turn is left unmarked.

    Args:
        states: iterable of TraceState dicts (from `ClaudeCodeIngester.ingest`).
            A missing or `None` `messages` entry is treated as empty.
        error_kind_fn: callable(error_content) -> str for classifying
            errors. Defaults to the keyword-match classifier
            `default_classify_error`.
        final_reward: scalar reward for the final assistant turn (the
            collator threads this into the GRPO channel; defaults to 0
            since Claude Code traces don't carry RLVR rewards natively).

    Returns:
        list[TraceExample] — one `{trace_id, turns, final_reward}` dict per
        state that produced at least one turn; states with no messages are
        skipped. `dpo_pairs` is omitted (Claude Code traces don't carry
        chosen/rejected pairs; use `teacher_replay.extract_dpo_pairs` for
        that channel separately).
    """
    examples: list[dict[str, Any]] = []
    for state in states:
        turns: list[dict[str, Any]] = []
        # Content of the error-flagged user message nearest to the current
        # position within the current unbroken run of user messages, or
        # None when the run holds no error. Tracking this in one forward
        # pass is equivalent to walking backward from every assistant turn,
        # but flattens and inspects each message only once.
        pending_error: str | None = None
        for msg in state.get("messages") or []:
            role = msg.get("role", "")
            content = _flatten_content(msg.get("content", ""))
            turn: dict[str, Any] = {"role": role, "content": content}

            if role == "user":
                # A later error in the same run overrides an earlier one,
                # i.e. the error closest to the assistant turn wins.
                if _user_turn_has_error(msg, content):
                    pending_error = content
            else:
                if role == "assistant" and pending_error is not None:
                    error_kind = error_kind_fn(pending_error)
                    if error_kind:
                        turn["tool_error"] = error_kind
                        turn["error_meta"] = {
                            "source_role": "user",
                            "source_content_excerpt": (pending_error or "")[:200],
                        }
                # Any non-user message (including this assistant turn)
                # terminates the run of user messages.
                pending_error = None

            turns.append(turn)

        if not turns:
            continue
        examples.append({
            "trace_id": str(state.get("state_id", "")),
            "turns": turns,
            "final_reward": float(final_reward),
        })
    return examples
# Public API of this module. Underscore-prefixed names
# (`_ERROR_KIND_PATTERNS`, `_user_turn_has_error`) are not exported.
__all__ = [
    "claude_states_to_trace_examples",
    "default_classify_error",
    "TOOL_ERROR_TAG",
]
|