# Scraped page header (not part of the script itself): commit 54efac8,
# "Wave 18: 14 backlog items closed + 3-reviewer cross-family review", 13.2 kB.
"""SDPO wiring smoke through `ClaudeCodeIngester` on Qwen2.5-0.5B-Instruct (CPU, ~30s).
This is the third example in the SDPO progression:
examples/gsm8k_grpo/ -- plain GRPO, no SDPO (alpha=0)
examples/gsm8k_grpo_with_sdpo/ -- SDPO on hand-crafted GSM8K problems
examples/sdpo_with_real_traces/ -- SDPO column fed by `ClaudeCodeIngester`
Where the previous SDPO example fabricated chat messages from 2 GSM8K
prompts, this example uses `ClaudeCodeIngester` to read a Claude Code
session JSONL fixture (the same one pinned by Spike 007's tests) and
build an SDPO-shaped batch from the actual `TraceState` objects the
ingester emits.
NOTE — what this example does and does NOT prove:
✅ The `ClaudeCodeIngester` → `compose_loss` plumbing works end-to-end.
✅ The SDPO column fires (sdpo_jsd > 0) and gradients flow (non-zero, finite) at every step; the total loss is logged per step, but a decrease is not asserted.
❌ Production-quality SDPO alignment. The HINT is appended at the end
of the messages list, so student/teacher contexts diverge in their
right-edge tokens. The SDPO mask therefore covers DIFFERENT content
in student vs teacher (see README §"Why the SDPO signal here is
'wiring proof,' not 'production-quality'"). For real training,
use `composer_replication.trainer.data_collator.ComposerDataCollator`
which builds properly-aligned `ctx_teacher_input_ids` and
`sdpo_loss_mask` columns.
❌ Validation on a real-content session. The default fixture is
hand-authored (PII reasons). To run on real content, point
`FIXTURE_PATH` at one of your own `~/.claude/projects/...` files.
This *partially* closes the V5 gap from `docs/VISION_VALIDATION.md`
§ 4.3 — the ingestion pipeline is validated; the real-data run depends
on the user.
Usage:
pip install -e ".[train]"
python examples/sdpo_with_real_traces/run.py
Cross-references:
- composer_replication.ingestion.claude_code.ClaudeCodeIngester
- composer_replication.trainer.data_collator (production-grade alignment)
- examples/gsm8k_grpo_with_sdpo/run.py (sibling — synthetic prompts)
- spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl
(the fixture Spike 007's tests pin against)
"""
from __future__ import annotations
import logging
import math
import sys
import time
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from composer_replication import compose_loss
from composer_replication.ingestion.claude_code import ClaudeCodeIngester
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Model checkpoint and optimisation settings for the smoke run.
MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct"
N_STEPS = 5
# Sequence length in tokens. Larger than in the gsm8k example because real
# Claude Code traces carry longer system + tool-result turns.
T = 64
LR = 1e-5
ALPHA_SDPO = 0.5
# DPO column switched off so the run focuses on SDPO with real traces.
BETA_REPLAY = 0.0

# `output/` lives next to this script and is created at import time.
OUTPUT_DIR = Path(__file__).resolve().with_name("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# The Spike 007 fixture is the only "real" Claude Code session JSONL we
# can ship in the repo (real user sessions contain PII). It's a short
# 8-message session (find Python files > 1MB) hand-authored to match
# the actual Claude Code v2.1.143 wire format.
# Resolved relative to the directory two levels above this script's folder.
FIXTURE_PATH = Path(__file__).resolve().parents[2].joinpath(
    "spikes",
    "007-real-trace-ingestion",
    "fixtures",
    "synthetic_session.jsonl",
)
# Hint inserted at the END of the trace context (just before the
# assistant generation marker). This produces a non-zero SDPO signal
# but with a methodological caveat documented in the README:
#
# The SDPO loss-mask region (right-most ALIGN_LEN positions) covers
# DIFFERENT token content in student vs teacher — student's last
# tokens are the user's tool-result, teacher's last tokens are the
# HINT. So the JSD measured here reflects the model's prediction
# divergence on different inputs, NOT a clean per-position teacher-
# vs-student divergence on the same content.
#
# This is acceptable for a WIRING SMOKE TEST (proves the channel fires
# on real-trace input, gradients flow, code path doesn't crash) but
# would NOT be acceptable for a production training run. Production
# SDPO requires a `ComposerDataCollator` that emits properly-aligned
# `ctx_teacher_input_ids` and `sdpo_loss_mask` columns — see
# `composer_replication/trainer/data_collator.py:_build_sdpo_fields`.
#
# A more rigorous version of this example would: (a) take the assistant
# turn from the trace as the "target action" the student is predicting,
# (b) align student/teacher contexts so the assistant turn is at the
# same position in both, (c) place the HINT before that turn in the
# teacher only, (d) mask only the assistant-response positions. That's
# what the data_collator does in production. Out of scope for this
# wiring smoke test.
# Teacher-only system message; the student context never contains this text.
HINT = (
    "Hint: prefer using `rg` over `find` for speed and simpler globs."
)
# ---------------------------------------------------------------------------
# Build inputs from a real Claude Code trace
# ---------------------------------------------------------------------------
def _tokenize_left_aligned(tokenizer, texts: list[str], max_length: int):
    """Tokenize *texts* to exactly *max_length* tokens, padding and truncating on the left.

    The tokenizer's `padding_side` / `truncation_side` are switched to
    "left" for the duration of the call and restored afterwards, even if
    tokenization raises.

    `add_special_tokens=False` is passed because *texts* are already
    rendered by `apply_chat_template`, which emits every special token the
    template needs; letting the tokenizer add its own BOS/EOS on top can
    duplicate them on tokenizers that do so by default.
    """
    saved_padding_side = tokenizer.padding_side
    saved_truncation_side = tokenizer.truncation_side
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        return tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
            add_special_tokens=False,
        )
    finally:
        tokenizer.padding_side = saved_padding_side
        tokenizer.truncation_side = saved_truncation_side


def build_inputs_from_trace(tokenizer, fixture_path: Path) -> tuple[dict[str, torch.Tensor], int]:
    """Ingest fixture_path → TraceStates → SDPO-shaped batch.

    Every state yielded by the ingester is indexed with `["messages"]`
    (a list of chat-message dicts). The LAST two states form the batch:

      - student context = state["messages"] (no hint)
      - teacher context = state["messages"] + [{"role": "system", "content": HINT}]

    Both contexts are rendered with the tokenizer's chat template (with a
    generation prompt appended) and then left-padded / left-truncated to
    `T` tokens so their right edges line up.

    Args:
        tokenizer: A Hugging Face tokenizer exposing `apply_chat_template`,
            `padding_side` and `truncation_side`.
        fixture_path: Path to a Claude Code session JSONL file.

    Returns:
        `(inputs, n_states_total)` where `inputs` holds `input_ids`,
        `response_mask`, `ctx_teacher_input_ids` and `sdpo_loss_mask`, each
        of shape `(B, T)`, and `n_states_total` is the number of states the
        ingester yielded.

    Raises:
        RuntimeError: If the ingester yields no states for `fixture_path`.
    """
    ingester = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=True)
    states = list(ingester.ingest(fixture_path))
    if not states:
        raise RuntimeError(
            f"No TraceState yielded from {fixture_path}; check the fixture is valid."
        )

    # Use the last 2 states (mimicking B=2 from gsm8k_grpo_with_sdpo) for
    # direct comparison. A trace with a single state is replicated to
    # still fill B=2.
    batch_states = states[-2:] if len(states) >= 2 else states + states

    student_msgs = [list(state["messages"]) for state in batch_states]
    # The HINT goes in as a system turn at the END of the message list
    # (just before apply_chat_template appends the assistant marker), so it
    # survives left-truncation and sits next to the position being predicted.
    hint_turn = {"role": "system", "content": HINT}
    teacher_msgs = [list(state["messages"]) + [hint_turn] for state in batch_states]

    def render(conversations: list[list[dict]]) -> list[str]:
        return [
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            for messages in conversations
        ]

    # LEFT-pad + LEFT-truncate so the right edge (assistant marker + final
    # turn) aligns across student and teacher and across the batch.
    s_tok = _tokenize_left_aligned(tokenizer, render(student_msgs), T)
    t_tok = _tokenize_left_aligned(tokenizer, render(teacher_msgs), T)

    B = s_tok["input_ids"].shape[0]
    # Mask covers the right-most half of every sequence. The name ALIGN_LEN
    # is kept because the module-level comment on HINT refers to it.
    ALIGN_LEN = T // 2
    response_mask = torch.zeros(B, T, dtype=torch.long)
    response_mask[:, -ALIGN_LEN:] = 1
    sdpo_loss_mask = response_mask.clone()

    # NOTE(review): the tokenizer's attention_mask is not forwarded, so the
    # left-padding tokens are not masked at this level — confirm whether
    # compose_loss expects or derives an attention mask.
    inputs = {
        "input_ids": s_tok["input_ids"],
        "response_mask": response_mask,
        "ctx_teacher_input_ids": t_tok["input_ids"],
        "sdpo_loss_mask": sdpo_loss_mask,
    }
    return inputs, len(states)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _require(condition: bool, message: str) -> None:
    """Raise `AssertionError(message)` unless *condition* holds.

    Used instead of a bare `assert` so the wiring checks still run under
    `python -O`, which strips `assert` statements.
    """
    if not condition:
        raise AssertionError(message)


def main() -> int:
    """Run the SDPO wiring smoke test and verify the SDPO column fired.

    Steps: check the fixture exists, load the model and tokenizer on CPU,
    build the batch from the trace, run `N_STEPS` SGD steps through
    `compose_loss`, then check the recorded per-step components.

    Logs go to stdout and to `run.log` in the parent directory of
    `OUTPUT_DIR` (the file is overwritten on every run).

    Returns:
        0 on success, 1 if `FIXTURE_PATH` does not point at a file.

    Raises:
        AssertionError: If `sdpo_jsd` is not > 0 at every step, if `total`
            and `lm_ce` differ by no more than 1e-6 at any step, or if the
            gradient magnitude is zero or non-finite at any step.
    """
    torch.manual_seed(42)

    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(log_path, mode="w")],
    )
    log = logging.getLogger("sdpo_with_real_traces")
    banner = "=" * 64

    log.info(banner)
    log.info("SDPO + REAL Claude Code traces + Qwen2.5-0.5B-Instruct (CPU)")
    log.info(banner)

    # Verify fixture exists before model load — fast-fail.
    if not FIXTURE_PATH.is_file():
        log.error("Fixture not found at %s", FIXTURE_PATH)
        log.error(
            "This example requires the Spike 007 fixture at "
            "spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl. "
            "If you're running outside the repo, point this script at any "
            "Claude Code session JSONL file."
        )
        return 1
    log.info("[1/5] Fixture path: %s (size=%d bytes)",
             FIXTURE_PATH, FIXTURE_PATH.stat().st_size)

    log.info("[2/5] Loading model + tokenizer ...")
    # perf_counter is monotonic, so elapsed times cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        # Padding to max_length needs a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    n_params = sum(p.numel() for p in model.parameters())
    log.info(" loaded in %.1fs (%.3fB params)", time.perf_counter() - t0, n_params / 1e9)

    log.info("[3/5] Ingesting trace + building SDPO batch (T=%d) ...", T)
    inputs, n_states = build_inputs_from_trace(tokenizer, FIXTURE_PATH)
    for k, v in inputs.items():
        log.info(" %s: shape=%s dtype=%s", k, tuple(v.shape), v.dtype)
    B = inputs["input_ids"].shape[0]

    log.info("[4/5] Running %d SGD steps with alpha_sdpo=%.2f on B=%d ...",
             N_STEPS, ALPHA_SDPO, B)
    optim = torch.optim.SGD(model.parameters(), lr=LR)
    history: list[dict[str, float]] = []

    model.train()
    t0 = time.perf_counter()
    for step in range(N_STEPS):
        optim.zero_grad()
        out = compose_loss(
            model, inputs,
            alpha_sdpo=ALPHA_SDPO, beta_replay=BETA_REPLAY,
        )
        out.total.backward()
        # Sum of absolute gradient entries over all parameters (an L1
        # magnitude, not the usual L2 norm); used only as a "gradients
        # flowed" signal.
        gnorm = sum(
            p.grad.abs().sum().item() for p in model.parameters() if p.grad is not None
        )
        optim.step()

        # presumably a plain dict of floats keyed by component name — it is
        # item-assigned and indexed as such below; confirm against compose_loss.
        components = out.detached()
        components["grad_norm"] = gnorm
        history.append(components)
        log.info(
            " step %d/%d: total=%.4f lm_ce=%.4f sdpo_jsd=%.4f trace_replay_dpo=%.4f |grad|=%.2e",
            step + 1, N_STEPS,
            components["total"], components["lm_ce"],
            components["sdpo_jsd"], components["trace_replay_dpo"],
            gnorm,
        )
    dt = time.perf_counter() - t0
    log.info("Training complete in %.1fs (avg %.1fs/step)", dt, dt / N_STEPS)

    log.info("[5/5] Verifying SDPO column wiring on real trace ...")
    sdpo_values = [h["sdpo_jsd"] for h in history]
    _require(
        all(s > 0.0 for s in sdpo_values),
        "SDPO column is identically zero — channel did not fire on real trace. "
        f"sdpo_jsd values: {sdpo_values}",
    )
    log.info(" ✓ sdpo_jsd > 0 at every step (min=%.4f max=%.4f)",
             min(sdpo_values), max(sdpo_values))

    diffs = [abs(h["total"] - h["lm_ce"]) for h in history]
    _require(
        all(d > 1e-6 for d in diffs),
        f"total ≈ lm_ce — SDPO contribution negligible. abs(total-lm_ce)={diffs}",
    )
    log.info(" ✓ total != lm_ce at every step (min |diff|=%.4f max=%.4f)",
             min(diffs), max(diffs))

    gnorms = [h["grad_norm"] for h in history]
    _require(all(g > 0.0 for g in gnorms), f"Some steps had zero |grad|: {gnorms}")
    _require(all(math.isfinite(g) for g in gnorms), f"Non-finite |grad|: {gnorms}")
    log.info(" ✓ |grad| > 0 and finite at every step (min=%.2e max=%.2e)",
             min(gnorms), max(gnorms))

    log.info(banner)
    log.info("Summary")
    log.info(banner)
    log.info(" trace fixture: %s", FIXTURE_PATH.name)
    log.info(" ingested states: %d (used last %d for batch)", n_states, B)
    log.info(" steps: %d", N_STEPS)
    log.info(" alpha_sdpo: %.2f", ALPHA_SDPO)
    log.info(" model params: %.3fB", n_params / 1e9)
    log.info(" total step 1: %.4f", history[0]["total"])
    log.info(" total step %d: %.4f", N_STEPS, history[-1]["total"])
    log.info(" wall-clock: %.1fs", dt)
    log.info(banner)
    log.info("✅ SDPO column wiring verified end-to-end on ClaudeCodeIngester output.")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())