"""SDPO wiring smoke through `ClaudeCodeIngester` on Qwen2.5-0.5B-Instruct (CPU, ~30s). This is the third example in the SDPO progression: examples/gsm8k_grpo/ -- plain GRPO, no SDPO (alpha=0) examples/gsm8k_grpo_with_sdpo/ -- SDPO on hand-crafted GSM8K problems examples/sdpo_with_real_traces/ -- SDPO column fed by `ClaudeCodeIngester` Where the previous SDPO example fabricated chat messages from 2 GSM8K prompts, this example uses `ClaudeCodeIngester` to read a Claude Code session JSONL fixture (the same one pinned by Spike 007's tests) and build an SDPO-shaped batch from the actual `TraceState` objects the ingester emits. NOTE — what this example does and does NOT prove: ✅ The `ClaudeCodeIngester` → `compose_loss` plumbing works end-to-end. ✅ The SDPO column fires (sdpo_jsd > 0); gradients flow; loss decreases. ❌ Production-quality SDPO alignment. The HINT is appended at the end of the messages list, so student/teacher contexts diverge in their right-edge tokens. The SDPO mask therefore covers DIFFERENT content in student vs teacher (see README §"Why the SDPO signal here is 'wiring proof,' not 'production-quality'"). For real training, use `composer_replication.trainer.data_collator.ComposerDataCollator` which builds properly-aligned `ctx_teacher_input_ids` and `sdpo_loss_mask` columns. ❌ Validation on a real-content session. The default fixture is hand-authored (PII reasons). To run on real content, point `FIXTURE_PATH` at one of your own `~/.claude/projects/...` files. This *partially* closes the V5 gap from `docs/VISION_VALIDATION.md` § 4.3 — the ingestion pipeline is validated; the real-data run depends on the user. 
Usage: pip install -e ".[train]" python examples/sdpo_with_real_traces/run.py Cross-references: - composer_replication.ingestion.claude_code.ClaudeCodeIngester - composer_replication.trainer.data_collator (production-grade alignment) - examples/gsm8k_grpo_with_sdpo/run.py (sibling — synthetic prompts) - spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl (the fixture Spike 007's tests pin against) """ from __future__ import annotations import logging import math import sys import time from pathlib import Path import torch from transformers import AutoModelForCausalLM, AutoTokenizer from composer_replication import compose_loss from composer_replication.ingestion.claude_code import ClaudeCodeIngester # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- MODEL_REPO = "Qwen/Qwen2.5-0.5B-Instruct" N_STEPS = 5 T = 64 # sequence length — bumped vs gsm8k example because real # Claude Code traces have longer system + tool-result turns LR = 1e-5 ALPHA_SDPO = 0.5 BETA_REPLAY = 0.0 # DPO column off — focus on SDPO with real traces OUTPUT_DIR = Path(__file__).resolve().parent / "output" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) # The Spike 007 fixture is the only "real" Claude Code session JSONL we # can ship in the repo (real user sessions contain PII). It's a short # 8-message session (find Python files > 1MB) hand-authored to match # the actual Claude Code v2.1.143 wire format. FIXTURE_PATH = ( Path(__file__).resolve().parents[2] / "spikes" / "007-real-trace-ingestion" / "fixtures" / "synthetic_session.jsonl" ) # Hint inserted at the END of the trace context (just before the # assistant generation marker). 
# This produces a non-zero SDPO signal
# but with a methodological caveat documented in the README:
#
#   The SDPO loss-mask region (right-most ALIGN_LEN positions) covers
#   DIFFERENT token content in student vs teacher — student's last
#   tokens are the user's tool-result, teacher's last tokens are the
#   HINT. So the JSD measured here reflects the model's prediction
#   divergence on different inputs, NOT a clean per-position teacher-
#   vs-student divergence on the same content.
#
# This is acceptable for a WIRING SMOKE TEST (proves the channel fires
# on real-trace input, gradients flow, code path doesn't crash) but
# would NOT be acceptable for a production training run. Production
# SDPO requires a `ComposerDataCollator` that emits properly-aligned
# `ctx_teacher_input_ids` and `sdpo_loss_mask` columns — see
# `composer_replication/trainer/data_collator.py:_build_sdpo_fields`.
#
# A more rigorous version of this example would: (a) take the assistant
# turn from the trace as the "target action" the student is predicting,
# (b) align student/teacher contexts so the assistant turn is at the
# same position in both, (c) place the HINT before that turn in the
# teacher only, (d) mask only the assistant-response positions. That's
# what the data_collator does in production. Out of scope for this
# wiring smoke test.
HINT = "Hint: prefer using `rg` over `find` for speed and simpler globs."


# ---------------------------------------------------------------------------
# Build inputs from a real Claude Code trace
# ---------------------------------------------------------------------------


def build_inputs_from_trace(
    tokenizer,
    fixture_path: Path,
    *,
    batch_size: int = 2,
) -> tuple[dict[str, torch.Tensor], int]:
    """Ingest ``fixture_path`` → TraceStates → SDPO-shaped batch.

    Each TraceState has a ``["messages"]`` field (list of OpenAI-style chat
    dicts up to and including the assistant turn the state was sliced at).
    The LAST ``batch_size`` TraceStates form the batch, with:

      - student context = state["messages"]                    (no hint)
      - teacher context = state["messages"] + [{"role": "system", "content": HINT}]

    If the ingester yields fewer than ``batch_size`` states, the states are
    repeated in order to fill the batch (with the default ``batch_size=2``
    and a single state, both rows are that state).

    NOTE(review): the module-level HINT comment says the student's last
    tokens are the user's tool-result, while this docstring says
    ``messages`` includes the assistant turn — confirm which holds for
    ``ClaudeCodeIngester`` output.

    Args:
        tokenizer: Hugging Face tokenizer with a chat template. Its
            ``padding_side`` / ``truncation_side`` are temporarily set to
            "left" and restored afterwards, even on error.
        fixture_path: Claude Code session JSONL file to ingest.
        batch_size: Number of rows in the batch (default 2, mirroring
            B=2 in gsm8k_grpo_with_sdpo).

    Returns:
        ``(inputs, n_states_total)``. ``inputs`` holds ``input_ids``,
        ``response_mask``, ``ctx_teacher_input_ids`` and ``sdpo_loss_mask``,
        each of shape ``(batch_size, T)``; ``n_states_total`` is how many
        TraceStates the ingester yielded.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
        RuntimeError: if the ingester yields no TraceState.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    ingester = ClaudeCodeIngester(skip_sidechain=True, strip_thinking=True)
    states = list(ingester.ingest(fixture_path))
    if not states:
        raise RuntimeError(
            f"No TraceState yielded from {fixture_path}; check the fixture is valid."
        )

    # Take the last `batch_size` states. Tiling first makes the short-trace
    # case fall out of the same slice: when enough states exist, the tail of
    # `states * batch_size` is exactly `states[-batch_size:]`.
    batch_states = (states * batch_size)[-batch_size:]

    student_msgs = [list(s["messages"]) for s in batch_states]
    # Insert HINT as a system turn at the END of the messages list (just
    # before apply_chat_template appends the assistant marker), so it stays
    # inside the left-truncated window and sits near the compared positions.
    # `msgs + [...]` builds a new list; the student lists are not mutated.
    teacher_msgs = [
        msgs + [{"role": "system", "content": HINT}] for msgs in student_msgs
    ]

    student_strs = [
        tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
        for m in student_msgs
    ]
    teacher_strs = [
        tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True)
        for m in teacher_msgs
    ]

    # LEFT-pad + LEFT-truncate so the right edge (the assistant marker + the
    # final turns) aligns across student and teacher and across the batch.
    # Same alignment discipline as gsm8k_grpo_with_sdpo per Wave 17b.
    original_pad = tokenizer.padding_side
    original_trunc = tokenizer.truncation_side
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    try:
        tok_kwargs = {
            "max_length": T,
            "truncation": True,
            "padding": "max_length",
            "return_tensors": "pt",
        }
        s_tok = tokenizer(student_strs, **tok_kwargs)
        t_tok = tokenizer(teacher_strs, **tok_kwargs)
    finally:
        tokenizer.padding_side = original_pad
        tokenizer.truncation_side = original_trunc

    B = s_tok["input_ids"].shape[0]
    ALIGN_LEN = T // 2
    # Loss region = the right-most ALIGN_LEN positions, intersected with the
    # tokenizer attention masks so left-padding (rows shorter than T) never
    # enters the loss. When every row fills T this is a no-op.
    tail = torch.zeros(B, T, dtype=torch.long)
    tail[:, -ALIGN_LEN:] = 1
    response_mask = tail * s_tok["attention_mask"]
    sdpo_loss_mask = response_mask * t_tok["attention_mask"]

    inputs = {
        "input_ids": s_tok["input_ids"],
        "response_mask": response_mask,
        "ctx_teacher_input_ids": t_tok["input_ids"],
        "sdpo_loss_mask": sdpo_loss_mask,
    }
    return inputs, len(states)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def _grad_l1_norm(model) -> float:
    """Return the L1 norm of the gradients: sum of |g| over every parameter with a ``.grad``."""
    return sum(
        p.grad.abs().sum().item() for p in model.parameters() if p.grad is not None
    )


def _verify_history(history: list[dict[str, float]], log: logging.Logger) -> bool:
    """Check the SDPO wiring invariants over per-step ``history``; log each outcome.

    Uses explicit checks rather than ``assert`` so the verification still runs
    under ``python -O`` (which strips assert statements).

    Returns:
        True when every invariant holds; False (after logging the failure) otherwise.
    """
    sdpo_values = [h["sdpo_jsd"] for h in history]
    if not all(s > 0.0 for s in sdpo_values):
        log.error(
            "SDPO column is not positive at every step — channel did not fire on "
            "real trace. sdpo_jsd values: %s",
            sdpo_values,
        )
        return False
    log.info(
        " ✓ sdpo_jsd > 0 at every step (min=%.4f max=%.4f)",
        min(sdpo_values), max(sdpo_values),
    )

    diffs = [abs(h["total"] - h["lm_ce"]) for h in history]
    if not all(d > 1e-6 for d in diffs):
        log.error(
            "total ≈ lm_ce — SDPO contribution negligible. abs(total-lm_ce)=%s", diffs
        )
        return False
    log.info(
        " ✓ total != lm_ce at every step (min |diff|=%.4f max=%.4f)",
        min(diffs), max(diffs),
    )

    gnorms = [h["grad_norm"] for h in history]
    if not all(g > 0.0 for g in gnorms):
        log.error("Some steps had zero |grad|: %s", gnorms)
        return False
    if not all(math.isfinite(g) for g in gnorms):
        log.error("Non-finite |grad|: %s", gnorms)
        return False
    log.info(
        " ✓ |grad| > 0 and finite at every step (min=%.2e max=%.2e)",
        min(gnorms), max(gnorms),
    )
    return True


def main() -> int:
    """Run the SDPO wiring smoke test end-to-end on the trace fixture.

    Checks that the fixture exists, loads the model and tokenizer, builds the
    SDPO batch from ingested TraceStates, and runs N_STEPS plain-SGD steps
    through ``compose_loss``. It then verifies the wiring invariants and logs
    a summary. Logs go to stdout and to ``run.log`` in this script's directory.

    Returns:
        Process exit code: 0 when every check passes; 1 when the fixture is
        missing or a wiring check fails.
    """
    torch.manual_seed(42)
    log_path = OUTPUT_DIR.parent / "run.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout), logging.FileHandler(log_path, mode="w")],
    )
    log = logging.getLogger("sdpo_with_real_traces")

    log.info("=" * 64)
    log.info("SDPO + REAL Claude Code traces + Qwen2.5-0.5B-Instruct (CPU)")
    log.info("=" * 64)

    # Verify fixture exists before model load — fast-fail
    if not FIXTURE_PATH.is_file():
        log.error("Fixture not found at %s", FIXTURE_PATH)
        log.error(
            "This example requires the Spike 007 fixture at "
            "spikes/007-real-trace-ingestion/fixtures/synthetic_session.jsonl. "
            "If you're running outside the repo, point this script at any "
            "Claude Code session JSONL file."
        )
        return 1
    log.info("[1/5] Fixture path: %s (size=%d bytes)", FIXTURE_PATH, FIXTURE_PATH.stat().st_size)

    log.info("[2/5] Loading model + tokenizer ...")
    t0 = time.time()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO, torch_dtype=torch.float32)
    model.to("cpu")
    n_params = sum(p.numel() for p in model.parameters())
    log.info(" loaded in %.1fs (%.3fB params)", time.time() - t0, n_params / 1e9)

    log.info("[3/5] Ingesting trace + building SDPO batch (T=%d) ...", T)
    inputs, n_states = build_inputs_from_trace(tokenizer, FIXTURE_PATH)
    for k, v in inputs.items():
        log.info(" %s: shape=%s dtype=%s", k, tuple(v.shape), v.dtype)
    B = inputs["input_ids"].shape[0]

    log.info("[4/5] Running %d SGD steps with alpha_sdpo=%.2f on B=%d ...",
             N_STEPS, ALPHA_SDPO, B)
    optim = torch.optim.SGD(model.parameters(), lr=LR)
    history: list[dict[str, float]] = []
    model.train()
    t0 = time.time()
    for step in range(N_STEPS):
        optim.zero_grad()
        out = compose_loss(
            model, inputs,
            alpha_sdpo=ALPHA_SDPO,
            beta_replay=BETA_REPLAY,
        )
        out.total.backward()
        # Measured before optim.step(); an L1 norm, reported as |grad|.
        gnorm = _grad_l1_norm(model)
        optim.step()

        components = out.detached()
        components["grad_norm"] = gnorm
        history.append(components)
        log.info(
            " step %d/%d: total=%.4f lm_ce=%.4f sdpo_jsd=%.4f trace_replay_dpo=%.4f |grad|=%.2e",
            step + 1, N_STEPS,
            components["total"], components["lm_ce"],
            components["sdpo_jsd"], components["trace_replay_dpo"],
            gnorm,
        )
    dt = time.time() - t0
    log.info("Training complete in %.1fs (avg %.1fs/step)", dt, dt / N_STEPS)

    log.info("[5/5] Verifying SDPO column wiring on real trace ...")
    if not _verify_history(history, log):
        return 1

    log.info("=" * 64)
    log.info("Summary")
    log.info("=" * 64)
    log.info(" trace fixture: %s", FIXTURE_PATH.name)
    log.info(" ingested states: %d (used last %d for batch)", n_states, B)
    log.info(" steps: %d", N_STEPS)
    log.info(" alpha_sdpo: %.2f", ALPHA_SDPO)
    log.info(" model params: %.3fB", n_params / 1e9)
    log.info(" total step 1: %.4f", history[0]["total"])
    log.info(" total step %d: %.4f", N_STEPS, history[-1]["total"])
    log.info(" wall-clock: %.1fs", dt)
    log.info("=" * 64)
    log.info("✅ SDPO column wiring verified end-to-end on ClaudeCodeIngester output.")
    return 0


if __name__ == "__main__":
    sys.exit(main())