add code/ loader snapshot
Browse files- code/data/__init__.py +0 -0
- code/data/dataset.py +171 -0
- code/data/ground_truth.py +106 -0
- code/data/sources.py +38 -0
- code/data/trace_format.py +229 -0
- code/eval/__init__.py +0 -0
- code/eval/eval_cruxeval_codi.py +119 -0
- code/eval/eval_cruxeval_sft.py +129 -0
- code/tokens.py +39 -0
- code/train/__init__.py +0 -0
- code/train/train_codi.py +280 -0
- code/wb.py +18 -0
code/data/__init__.py
ADDED
|
File without changes
|
code/data/dataset.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
CRUXEval-O dataset: deterministic train/val split + ground-truth execution
|
| 5 |
+
traces -> (input_ids, labels) for teacher-forcing CODI.
|
| 6 |
+
|
| 7 |
+
Neutral data layer shared by training (``cwm.training.data``) and eval
|
| 8 |
+
(``evals.cruxeval.run_eval_codi``); depends on nothing in either, so the
|
| 9 |
+
split and trace format never drift. Thin HuggingFace-tokenizer wrapper over
|
| 10 |
+
the verbatim Table 9 trace generator (``.ground_truth`` / ``.trace_format``):
|
| 11 |
+
build the seeded prompt, tokenize ``prompt + render_frames_to_generation(frames)``,
|
| 12 |
+
and mask the prompt out of the labels (teacher-forced, so labels == input_ids
|
| 13 |
+
with the prompt prefix set to ``-100``).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from .ground_truth import ground_truth_trace, make_trace_context
|
| 19 |
+
from .trace_format import (
|
| 20 |
+
ACTION_SEP,
|
| 21 |
+
LINE_SEP,
|
| 22 |
+
TraceEvent,
|
| 23 |
+
render_frames_to_generation,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
IGNORE_INDEX = -100
|
| 27 |
+
def _prompt_str(code: str, input_str: str) -> str:
    """Build the textual trace prompt for ``code`` executed on ``input_str``.

    The prompt is the trace context (from ``make_trace_context``) wrapped in
    ``<|trace_context_start|>`` and followed by a pre-seeded call frame for
    ``def main():`` with empty locals (``{}``), so the continuation starts at
    the first frame after the seeded call.
    """
    context = make_trace_context(code, input_str)
    # Seeded "call main()" frame: empty locals, action is the ``def main():`` line.
    seeded_call = "<|frame_sep|><|call_sep|>{}<|action_sep|>def main():\n<|frame_sep|>"
    return f"<|trace_context_start|>{context}{seeded_call}"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _tokenize_trace(code, input_str, tokenizer, *, max_seq_len, max_frames):
    """Tokenize one ground-truth trace; return ``(prompt_ids, trace_ids, spans)`` or None.

    The example is skipped (None) when:
      * tracing produced no frames or hit the ``max_frames`` limit,
      * the last frame is not a RETURN or EXCEPTION frame,
      * prompt + trace is longer than ``max_seq_len`` tokens,
      * no complete LINE span was found.

    A span ``(i, j)`` means ``trace_ids[i]`` is ``<|line_sep|>`` and
    ``trace_ids[j]`` is the next ``<|action_sep|>``; ``trace_ids[i+1:j]`` are
    the locals tokens of that frame. Both the SFT and CODI builders call this
    function, so they keep or skip exactly the same rows.
    """
    frames, error = ground_truth_trace(code, input_str, align_to_prompt=True, max_frames=max_frames)
    if error == "frames_exceeded" or not frames:
        return None
    terminal_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    if frames[-1].event not in terminal_events:
        return None

    prompt_ids = tokenizer.encode(_prompt_str(code, input_str), add_special_tokens=False)
    # Some tokenizers define no BOS token (bos_token_id is None); prepend only if present.
    bos_id = tokenizer.bos_token_id
    if bos_id is not None:
        prompt_ids = [bos_id] + prompt_ids
    trace_ids = tokenizer.encode(render_frames_to_generation(frames), add_special_tokens=False)
    if len(prompt_ids) + len(trace_ids) > max_seq_len:
        return None

    line_sep_id = tokenizer.convert_tokens_to_ids(LINE_SEP)
    action_sep_id = tokenizer.convert_tokens_to_ids(ACTION_SEP)
    spans = []
    span_start = None  # index of a <|line_sep|> still waiting for its <|action_sep|>
    for pos, token_id in enumerate(trace_ids):
        if span_start is None:
            if token_id == line_sep_id:
                span_start = pos
        elif token_id == action_sep_id:
            spans.append((span_start, pos))
            span_start = None
    # A trailing <|line_sep|> without a matching <|action_sep|> is dropped.
    if not spans:
        return None
    return prompt_ids, trace_ids, spans
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def build_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Return SFT ``(input_ids, labels)`` for one row, or None if the row is skipped.

    ``input_ids`` is prompt followed by trace; ``labels`` equals ``input_ids``
    except that every prompt position is replaced by ``IGNORE_INDEX``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, _spans = tokenized
    input_ids = prompt_ids + trace_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + trace_ids
    return input_ids, labels
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def build_codi_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Return the multi-span CODI example for one row, or None if the row is skipped.

    The result is a dict with keys ``prompt_ids``, ``trace_ids`` and ``spans``
    exactly as produced by ``_tokenize_trace``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, spans = tokenized
    example = {}
    example["prompt_ids"] = prompt_ids
    example["trace_ids"] = trace_ids
    example["spans"] = spans
    return example
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _load_cache(cache_dir, n_samples):
    """Load precomputed examples saved with ``save_to_disk`` from ``cache_dir``.

    Returns all examples as a list, or only the first ``n_samples`` when
    ``n_samples`` is positive.
    """
    # Imported lazily so importing this module does not require ``datasets``.
    from datasets import load_from_disk

    examples = list(load_from_disk(cache_dir))
    if n_samples > 0:
        return examples[:n_samples]
    return examples
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def build_codi_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Build multi-span CODI examples over ``sources``, or read them from a cache.

    Each returned dict has the keys ``prompt_ids``, ``trace_ids`` and ``spans``
    (see ``build_codi_example``).

    Args:
        tokenizer: tokenizer passed through to ``build_codi_example``.
        sources: dataset names understood by ``rows_for_sources``.
        n_samples: if positive, only the first ``n_samples`` rows (or cached
            examples) are considered; the limit is applied *before* filtering.
        max_seq_len: examples whose prompt + trace exceed this are dropped.
        max_frames: frame limit forwarded to the tracer (``-1`` = unlimited).
        cache_dir: if set, load precomputed examples from this directory
            instead of tracing.

    Rows whose trace construction raises are skipped (best effort), but the
    number of skipped rows is logged instead of being silently discarded.
    """
    import logging  # local import: the module header is not modified by this block

    if cache_dir:
        ex = _load_cache(cache_dir, n_samples)
        return [e for e in ex if len(e["prompt_ids"]) + len(e["trace_ids"]) <= max_seq_len]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    log = logging.getLogger(__name__)
    out = []
    n_failed = 0
    for r in rows:
        try:
            example = build_codi_example(
                r["code"], r["input"], tokenizer,
                max_seq_len=max_seq_len, max_frames=max_frames,
            )
        except Exception as exc:
            # Deliberate best effort: a row that cannot be traced must not
            # abort the whole build, but it should be visible in the logs.
            n_failed += 1
            log.debug("skipping row %s: %s: %s", r.get("id"), type(exc).__name__, exc)
            continue
        if example is not None:
            out.append(example)
    if n_failed:
        log.warning(
            "build_codi_dataset: skipped %d of %d rows because trace construction raised",
            n_failed, len(rows),
        )
    return out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def build_codi_single_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Single-block CODI examples derived from the multi-span dataset.

    Every trace from ``build_codi_dataset`` is cut at its last
    ``<|return_sep|>`` token: ``reasoning_ids`` is everything before that
    token and ``answer_ids`` is everything from that token onwards. Traces
    with no ``<|return_sep|>``, or whose only one sits at index 0 (which would
    leave the reasoning empty), are dropped. All arguments are forwarded to
    ``build_codi_dataset`` unchanged.
    """
    return_sep_id = tokenizer.convert_tokens_to_ids("<|return_sep|>")
    multi_span = build_codi_dataset(
        tokenizer, sources=sources, n_samples=n_samples,
        max_seq_len=max_seq_len, max_frames=max_frames, cache_dir=cache_dir,
    )
    singles = []
    for example in multi_span:
        trace = example["trace_ids"]
        # Scan from the end for the last return separator.
        cut = None
        for pos in range(len(trace) - 1, -1, -1):
            if trace[pos] == return_sep_id:
                cut = pos
                break
        if cut is None or cut == 0:
            continue
        singles.append({
            "prompt_ids": example["prompt_ids"],
            "reasoning_ids": trace[:cut],
            "answer_ids": trace[cut:],
        })
    return singles
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def rows_for_sources(sources):
    """Concatenate validated ``{id, code, input, output}`` rows from every source.

    Rows are loaded with ``sources.load_one`` in the order the names are given.

    Raises:
        ValueError: a row lacks one of ``id``, ``code``, ``input``, ``output``.
        TypeError: ``code``, ``input`` or ``output`` is not a string.

    Each returned row is a shallow copy whose ``id`` has been converted to str.
    """
    # Aliased because the ``sources`` parameter shadows the sibling module name.
    from . import sources as _src

    required_keys = ("id", "code", "input", "output")
    text_keys = ("code", "input", "output")
    merged = []
    for name in sources:
        for i, row in enumerate(_src.load_one(name)):
            missing = [k for k in required_keys if k not in row]
            if missing:
                raise ValueError(f"{name} row {i} missing keys: {missing}")
            if any(not isinstance(row[k], str) for k in text_keys):
                raise TypeError(f"{name} row {i} must use string code/input/output")
            merged.append({**row, "id": str(row["id"])})
    return merged
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def build_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 8192, max_frames: int = -1, cache_dir: str | None = None
) -> list[tuple[list[int], list[int]]]:
    """Build SFT ``(input_ids, labels)`` pairs over ``sources``, or read a cache.

    With ``cache_dir`` set, cached examples are loaded (limited to
    ``n_samples`` when positive) and those longer than ``max_seq_len`` are
    dropped. Otherwise each row from ``rows_for_sources`` is passed to
    ``build_example`` and rows for which it returns None are left out.
    """
    if cache_dir:
        cached = _load_cache(cache_dir, n_samples)
        return [
            (e["input_ids"], e["labels"])
            for e in cached
            if len(e["input_ids"]) <= max_seq_len
        ]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    kept = []
    for row in rows:
        example = build_example(
            row["code"], row["input"], tokenizer,
            max_seq_len=max_seq_len, max_frames=max_frames,
        )
        if example is not None:
            kept.append(example)
    return kept
|
code/data/ground_truth.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
"""Ground-truth execution traces in CWM's frame format.
|
| 4 |
+
|
| 5 |
+
Runs ``f(input)`` under ``sys.settrace`` and records CALL/LINE/RETURN/EXCEPTION
|
| 6 |
+
frames with diff-based locals (unchanged vars render as ``".."``), values via
|
| 7 |
+
``repr``. A synthetic ``def main(): return f(<input>)`` wraps the function; the
|
| 8 |
+
seeded ``call main()`` frame is dropped by default to align with the trace
|
| 9 |
+
prompt. Not a bit-exact replica of Meta's internal tracer (see README.md).
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import linecache
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from .trace_format import DIFF_PLACEHOLDER, TraceEvent, TraceFrame, normalize_source
|
| 19 |
+
|
| 20 |
+
_FILENAME = "<cwm_trace>"
|
| 21 |
+
_ENTRY = "main"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class _FramesExceeded(BaseException):
    """Raised from the trace function to abort a run that exceeds ``max_frames``.

    Derives from ``BaseException`` rather than ``Exception`` because it is
    raised *inside* the traced program: an ``except Exception:`` handler in
    that program would otherwise swallow the abort and let the program keep
    running with tracing switched off.
    """
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def make_trace_context(code: str, input_str: str) -> str:
    """Return the program text that gets traced.

    The text is ``code`` followed by a synthetic ``main`` whose ``def`` line
    carries the ``# << START_OF_TRACE`` marker and whose body returns
    ``f(<input_str>)``.
    """
    entry_point = f"def main(): # << START_OF_TRACE\n return f({input_str})\n"
    return f"\n{code}\n{entry_point}"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def render_value(value: Any) -> str:
    """Return ``repr(value)``, or ``"<unrepr>"`` if ``repr`` itself raises."""
    try:
        text = repr(value)
    except Exception:
        # A user-defined __repr__ may fail; never let that break the trace.
        text = "<unrepr>"
    return text
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def ground_truth_trace(
    code: str, input_str: str, align_to_prompt: bool = True, max_frames: int = -1
) -> tuple[list[TraceFrame], str | None]:
    """Execute ``f(input_str)`` under ``sys.settrace`` and return ``(frames, error)``.

    Args:
        code: program text that defines ``f``.
        input_str: argument text spliced into ``f(...)`` by ``make_trace_context``.
        align_to_prompt: drop the leading ``call main()`` frame, which the
            trace prompt already contains.
        max_frames: abort once this many frames were recorded (``<= 0`` means
            no limit).

    Returns:
        ``frames`` recorded so far and ``error``: ``None`` on success,
        ``"frames_exceeded"`` when the limit was hit, otherwise
        ``"<ExceptionType>: <message>"`` for an exception raised by the
        program.

    Exceptions raised while *defining* the program (e.g. ``SyntaxError``) are
    not caught here, because the initial ``exec`` runs outside the ``try``.

    NOTE(review): this executes ``code`` in-process via ``exec``; only use it
    on trusted datasets or inside a sandbox.
    NOTE(review): per the ``sys.settrace`` documentation a ``return`` event
    with ``arg=None`` is also delivered when a frame unwinds because of an
    exception, so such frames are recorded as RETURN frames with arg
    ``'None'`` — confirm this matches the intended trace format.
    """
    import inspect  # local import: only needed for the code-flag check below

    context = make_trace_context(code, input_str)
    # Register the synthetic file so linecache can serve its source lines.
    linecache.cache[_FILENAME] = (len(context), None, context.splitlines(keepends=True), _FILENAME)

    frames: list[TraceFrame] = []
    # frame object -> last rendered locals of that scope. Keyed by the frame
    # itself (not id(frame)): CPython reuses ids of freed frames, which would
    # make a fresh call diff against the locals of an unrelated, dead frame.
    scope_prev: dict[Any, dict[str, str]] = {}
    entry = None
    # Frames of generators/coroutines emit "return" on every yield and are
    # resumed later, so their diff state must survive a "return" event.
    resumable_flags = inspect.CO_GENERATOR | inspect.CO_COROUTINE | inspect.CO_ASYNC_GENERATOR

    def source(frame):
        return normalize_source(linecache.getline(_FILENAME, frame.f_lineno))

    def diff_locals(frame):
        prev = scope_prev.get(frame, {})
        out, rendered = {}, {}
        for name, val in frame.f_locals.items():
            r = render_value(val)
            rendered[name] = r
            out[name] = DIFF_PLACEHOLDER if prev.get(name) == r else r
        scope_prev[frame] = rendered
        return out

    def trace(frame, event, arg):
        nonlocal entry
        # Abort loop-heavy programs, but only from our file (not GC/__del__ frames).
        if max_frames > 0 and len(frames) >= max_frames and frame.f_code.co_filename == _FILENAME:
            raise _FramesExceeded
        if entry is None:
            # Ignore everything until the synthetic entry point is called.
            if event == "call" and frame.f_code.co_name == _ENTRY:
                entry = id(frame)
            else:
                return None
        # Only trace user code from our context, not library frames.
        if frame.f_code.co_filename != _FILENAME:
            return None
        if event == "call":
            frames.append(TraceFrame(event=TraceEvent.CALL, source=source(frame), locals=diff_locals(frame)))
        elif event == "line":
            frames.append(TraceFrame(event=TraceEvent.LINE, source=source(frame), locals=diff_locals(frame)))
        elif event == "return":
            frames.append(TraceFrame(event=TraceEvent.RETURN, source=source(frame), arg=render_value(arg)))
            if not frame.f_code.co_flags & resumable_flags:
                # The scope is finished: release its diff state (and the frame).
                scope_prev.pop(frame, None)
        elif event == "exception":
            name = getattr(arg[0], "__name__", str(arg[0]))
            frames.append(TraceFrame(event=TraceEvent.EXCEPTION, source=source(frame), arg=render_value(name)))
        return trace

    ns: dict[str, Any] = {}
    exec(compile(context, _FILENAME, "exec"), ns)  # define f, main untraced
    error = None
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        ns[_ENTRY]()
    except _FramesExceeded:
        error = "frames_exceeded"
    except Exception as e:
        error = f"{type(e).__name__}: {e}"
    finally:
        # Always restore whatever trace function was installed before.
        sys.settrace(old)

    # Drop the seeded ``call main()`` frame so frames align with the prompt.
    if align_to_prompt and frames and frames[0].event == TraceEvent.CALL and frames[0].source.startswith("def main()"):
        frames = frames[1:]
    return frames, error
|
code/data/sources.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset name(s) -> merged {id, code, input, output} rows.
|
| 2 |
+
|
| 3 |
+
Add a converted dataset by running its folder's convert.py (saves ./data via
|
| 4 |
+
save_to_disk) and listing it in _LOCAL. cruxeval keeps its own Hub-fallback
|
| 5 |
+
loader and is held out entirely for eval (eval_cruxeval_*.py), never trained on.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
_LOCAL = {"mbpp": "MBPP", "humaneval": "HumanEval", "pyx": "PyX"} # name -> folder, data in ./data
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_cruxeval():
    """Return the CRUXEval test split as a list of rows.

    If the ``CRUXEVAL_DIR`` environment variable names an existing directory,
    the dataset is read from it with ``load_from_disk``; otherwise the
    ``test`` split of ``cruxeval-org/cruxeval`` is loaded with
    ``load_dataset``.
    """
    local_copy = os.environ.get("CRUXEVAL_DIR")
    if not (local_copy and os.path.isdir(local_copy)):
        from datasets import load_dataset

        return list(load_dataset("cruxeval-org/cruxeval", split="test"))

    from datasets import load_from_disk

    return list(load_from_disk(local_copy))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def load_one(name: str) -> list[dict]:
    """Load all rows of the data source called ``name`` (case-insensitive).

    ``cruxeval`` is delegated to ``load_cruxeval``. Names listed in ``_LOCAL``
    are read with ``load_from_disk`` from the directory given by the
    ``<NAME>_DIR`` environment variable, falling back to
    ``<this package>/<folder>/data``.

    Raises:
        ValueError: ``name`` is neither ``cruxeval`` nor a key of ``_LOCAL``.
    """
    key = name.strip().lower()
    if key == "cruxeval":
        return load_cruxeval()
    if key not in _LOCAL:
        raise ValueError(f"unknown data source {name!r}; pick from {['cruxeval', *_LOCAL]}")

    from datasets import load_from_disk

    override = os.environ.get(key.upper() + "_DIR")
    data_dir = override or str(Path(__file__).parent / _LOCAL[key] / "data")
    return list(load_from_disk(data_dir))
|
code/data/trace_format.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
Shared CWM execution-trace representation and parsing.
|
| 5 |
+
|
| 6 |
+
CWM predicts an execution trace as a sequence of *frames*, each consisting of an
|
| 7 |
+
*observation* (the local-variable state) and an *action* (the executed source
|
| 8 |
+
line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is:
|
| 9 |
+
|
| 10 |
+
<|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
|
| 11 |
+
<|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
|
| 12 |
+
<|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
|
| 13 |
+
<|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
|
| 14 |
+
|
| 15 |
+
`$LOCALS` is a JSON object mapping variable names to *string* values; each value
|
| 16 |
+
is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`,
|
| 17 |
+
`"[1, 2]"`). Locals use a diff-based representation: a variable whose value is
|
| 18 |
+
unchanged since the previous frame in the same scope is rendered as the
|
| 19 |
+
placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON
|
| 20 |
+
encoding of the returned/raised value, stored as a JSON string.
|
| 21 |
+
|
| 22 |
+
This module is GPU-free and import-light so it can be unit-tested directly.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import json
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
from enum import Enum
|
| 30 |
+
|
| 31 |
+
# Literal piece strings as they appear when a generation is decoded with
|
| 32 |
+
# cut_at_stop_tokens=False (matches CWMInstructTokenizer.*_ID constants).
|
| 33 |
+
CALL_SEP = "<|call_sep|>"
|
| 34 |
+
LINE_SEP = "<|line_sep|>"
|
| 35 |
+
RETURN_SEP = "<|return_sep|>"
|
| 36 |
+
EXCEPTION_SEP = "<|exception_sep|>"
|
| 37 |
+
ACTION_SEP = "<|action_sep|>"
|
| 38 |
+
ARG_SEP = "<|arg_sep|>"
|
| 39 |
+
FRAME_SEP = "<|frame_sep|>"
|
| 40 |
+
END_OF_TEXT = "<|end_of_text|>"
|
| 41 |
+
|
| 42 |
+
DIFF_PLACEHOLDER = ".."
|
| 43 |
+
_START_MARKER = " # << START_OF_TRACE"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TraceEvent(Enum):
    """The four frame kinds of an execution trace.

    The values are the lowercase event names; each member corresponds to one
    ``<|..._sep|>`` event token (see ``_EVENT_TOKENS``).
    """

    # Frames that carry a locals observation.
    CALL = "call"
    LINE = "line"
    # Frames that carry an argument (returned value / raised exception).
    RETURN = "return"
    EXCEPTION = "exception"
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
_EVENT_TOKENS: dict[str, TraceEvent] = {
|
| 54 |
+
CALL_SEP: TraceEvent.CALL,
|
| 55 |
+
LINE_SEP: TraceEvent.LINE,
|
| 56 |
+
RETURN_SEP: TraceEvent.RETURN,
|
| 57 |
+
EXCEPTION_SEP: TraceEvent.EXCEPTION,
|
| 58 |
+
}
|
| 59 |
+
_EVENT_TO_TOKEN: dict[TraceEvent, str] = {v: k for k, v in _EVENT_TOKENS.items()}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class TraceFrame:
    """One frame of an execution trace.

    Only ``event`` and ``source`` are required; everything else defaults to
    an "empty" value so both the parser and the ground-truth tracer can fill
    in just the parts they know.
    """

    # Kind of frame (call / line / return / exception).
    event: TraceEvent
    # Action line, with the START_OF_TRACE marker and trailing newline stripped.
    source: str
    # Raw ``$LOCALS`` text between the event token and ``<|action_sep|>``
    # (empty for return/exception frames).
    locals_str: str = ""
    # Parsed locals (name -> string value), or None if parsing failed.
    locals: dict[str, str] | None = None
    # Returned / raised value for return and exception frames.
    arg: str | None = None
    # Set by the parser when the frame could not be parsed cleanly.
    malformed: bool = False
    # Token counts (filled when a tokenizer is available); used for the
    # "Avg State/Action Length (Token)" statistics rows of Table 9.
    state_tokens: int = 0
    action_tokens: int = 0

    @property
    def has_locals(self) -> bool:
        """True for CALL and LINE frames, the two kinds that carry locals."""
        return self.event is TraceEvent.CALL or self.event is TraceEvent.LINE
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def normalize_source(source: str) -> str:
    """Strip the trace start marker and trailing whitespace from a source line.

    The marker is removed only when it is an actual *suffix* of the line.
    (``str.rstrip(marker)`` must not be used for this: it treats its argument
    as a set of characters and would also eat the end of ordinary lines such
    as ``x = CONST``.)
    """
    line = source.rstrip()
    # Compare against the marker without its leading space so the match does
    # not depend on how much whitespace precedes the ``#``.
    marker = _START_MARKER.strip()
    if marker and line.endswith(marker):
        line = line[: -len(marker)]
    return line.rstrip()
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def parse_locals(locals_str: str) -> dict[str, str] | None:
    """Parse a ``$LOCALS`` payload.

    Returns ``{}`` for an empty/whitespace payload, ``None`` when the payload
    is not valid JSON or not a JSON object, and otherwise a dict whose keys
    are strings and whose non-string values have been re-encoded with
    ``json.dumps``.
    """
    payload = locals_str.strip()
    if not payload:
        return {}
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        return None
    if not isinstance(decoded, dict):
        return None
    coerced = {}
    for name, value in decoded.items():
        # Values are expected to be strings already; re-encode anything else.
        coerced[str(name)] = value if isinstance(value, str) else json.dumps(value)
    return coerced
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]:
    """Split a generated trace into frames and report whether it is well formed.

    Text after the first ``<|end_of_text|>`` is ignored. The remainder is
    split on ``<|frame_sep|>``; each non-blank piece is handed to
    ``_parse_segment``.

    ``well_formed`` is False when any of the following holds:
      * non-whitespace text follows the last ``<|frame_sep|>``,
      * a segment has no event token (it is then left out of ``frames``),
      * ``_parse_segment`` flags a segment as not ok,
      * no frame was parsed at all.

    Frames flagged as not ok are still returned, so other metrics can be
    computed over them.
    """
    body = generation.split(END_OF_TEXT, 1)[0]
    # Everything after the last frame separator must be blank for a clean trace.
    *segments, trailing = body.split(FRAME_SEP)
    well_formed = not trailing.strip()

    frames: list[TraceFrame] = []
    for segment in segments:
        if not segment.strip():
            # Blank piece between separators: nothing to parse.
            continue
        frame, ok = _parse_segment(segment)
        if frame is None:
            well_formed = False
            continue
        if not ok:
            well_formed = False
        frames.append(frame)

    return frames, well_formed and bool(frames)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]:
    """Parse one ``<|frame_sep|>``-delimited segment into ``(frame, ok)``.

    The frame kind is given by the event token that occurs *earliest* in the
    segment; text before it is discarded. Returns ``(None, False)`` when the
    segment contains no event token. ``ok`` is False when ``<|action_sep|>``
    is missing, or when a return/exception frame lacks ``<|arg_sep|>``.
    """
    # Pick the event token with the smallest position. Taking the first token
    # *type* that matches anywhere would misclassify a frame whose locals or
    # source merely contain another event-token literal.
    event: TraceEvent | None = None
    start = -1
    token = ""
    for tok, evt in _EVENT_TOKENS.items():
        idx = seg.find(tok)
        if idx != -1 and (event is None or idx < start):
            event, start, token = evt, idx, tok
    if event is None:
        return None, False
    seg = seg[start + len(token):]

    ok = True
    if event in (TraceEvent.CALL, TraceEvent.LINE):
        if ACTION_SEP not in seg:
            return (
                TraceFrame(event=event, source="", malformed=True),
                False,
            )
        locals_str, source = seg.split(ACTION_SEP, 1)
        parsed = parse_locals(locals_str)
        return (
            TraceFrame(
                event=event,
                source=normalize_source(source),
                locals_str=locals_str.strip(),
                locals=parsed,
                # Unparseable locals mark the frame, but do not affect ``ok``.
                malformed=parsed is None,
            ),
            ok,
        )

    # RETURN / EXCEPTION: anything before <|action_sep|> is ignored.
    if ACTION_SEP not in seg:
        return TraceFrame(event=event, source="", malformed=True), False
    seg = seg.split(ACTION_SEP, 1)[1]
    if ARG_SEP in seg:
        source, arg = seg.split(ARG_SEP, 1)
        arg = _parse_arg(arg)
    else:
        # The frame is kept, but a missing argument makes it ill-formed.
        source, arg = seg, None
        ok = False
    return (
        TraceFrame(event=event, source=normalize_source(source), arg=arg),
        ok,
    )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def render_frames_to_generation(frames: list[TraceFrame]) -> str:
    """Serialize ``frames`` to the on-the-wire trace string.

    Each frame becomes: event token, JSON-encoded locals (CALL/LINE frames
    only; ``None`` locals are written as ``{}``), ``<|action_sep|>``, the
    source line, then for RETURN/EXCEPTION frames ``<|arg_sep|>`` plus the
    JSON-encoded ``arg``, and finally ``<|frame_sep|>``. The whole string is
    terminated with ``<|end_of_text|>``.
    """
    with_arg = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    rendered: list[str] = []
    for frame in frames:
        parts = [_EVENT_TO_TOKEN[frame.event]]
        if frame.has_locals:
            observation = {} if frame.locals is None else frame.locals
            parts.append(json.dumps(observation))
        parts += [ACTION_SEP, frame.source]
        if frame.event in with_arg:
            parts += [ARG_SEP, json.dumps(frame.arg)]
        parts.append(FRAME_SEP)
        rendered.append("".join(parts))
    rendered.append(END_OF_TEXT)
    return "".join(rendered)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _parse_arg(arg_str: str) -> str | None:
    """Decode the ``$VALUE`` text of a return/exception frame.

    Returns None for blank text. If the text is a JSON string, one level of
    JSON encoding is removed (``'"17"'`` -> ``'17'``); any other text, valid
    JSON or not, is returned stripped but otherwise unchanged.
    """
    text = arg_str.strip()
    if not text:
        return None
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return text
    if isinstance(decoded, str):
        return decoded
    return text
|
code/eval/__init__.py
ADDED
|
File without changes
|
code/eval/eval_cruxeval_codi.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CRUXEval-O latent eval: the CODI student generates the trace, but at every
|
| 2 |
+
<|line_sep|> the frame's $LOCALS is replaced by a latent block (latent_start +
|
| 3 |
+
latent_steps recurrent latents + latent_end), mirroring training _student.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from datetime import timedelta
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
from data.dataset import _prompt_str
|
| 16 |
+
from data.sources import load_cruxeval
|
| 17 |
+
from eval.eval_cruxeval_sft import check_correct, extract_answer_trace_full
|
| 18 |
+
from tokens import add_trace_tokens, token_ids
|
| 19 |
+
from train.train_codi import CodiModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_codi(m, latent_steps, dev):
    """Load tokenizer, special-token ids and a ``CodiModel`` from directory ``m``.

    Two checkpoint layouts are handled:
      * ``m/pytorch_model.bin`` exists: it is loaded as the state dict of the
        whole ``CodiModel``;
      * otherwise the backbone is loaded with ``from_pretrained(m)`` and the
        projector weights come from ``m/thought_projector.pt``.

    Returns ``(tokenizer, ids, model)`` with the model moved to ``dev`` and
    switched to eval mode.

    NOTE(review): ``torch.load`` unpickles the checkpoint; only load files
    from a trusted source.
    """
    tok = AutoTokenizer.from_pretrained(m, use_fast=True)
    add_trace_tokens(tok)
    ids = token_ids(tok)

    # Build the architecture from the config only; weights are filled in below.
    config = AutoConfig.from_pretrained(m)
    skeleton = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    model = CodiModel(
        skeleton,
        latent_start_id=ids["<|latent_start|>"],
        latent_end_id=ids["<|latent_end|>"],
        latent_steps=latent_steps,
    )

    full_checkpoint = f"{m}/pytorch_model.bin"
    if os.path.exists(full_checkpoint):  # epoch checkpoint: full CodiModel
        state = torch.load(full_checkpoint, map_location="cpu")
        model.load_state_dict(state)
    else:  # final export: backbone safetensors + separate projector
        model.model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16)
        projector_state = torch.load(f"{m}/thought_projector.pt", map_location="cpu")
        model.prj.load_state_dict(projector_state)

    model = model.to(dev).eval()
    return tok, ids, model
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@torch.no_grad()
def gen_latent(model, prompt_ids, ls_id, eot, max_new):
    """Greedy decoding that runs a latent block after every ``ls_id`` token.

    The prompt is run once through ``model.model`` to fill the KV cache. Then
    up to ``max_new`` tokens are chosen by argmax; decoding stops early at
    ``eot`` (which is not included in the result). Each chosen token is fed
    back through the backbone; when the token is ``ls_id`` the next logits
    and cache come from ``model._latent_block(cache)`` instead of the
    backbone output.

    Returns the list of generated token ids.
    # NOTE(review): assumes ``prompt_ids`` is a 1-D tensor of token ids (it is
    # indexed with ``[None]`` to add a leading dimension) — confirm with callers.
    """
    device = prompt_ids.device
    step = model.model(input_ids=prompt_ids[None], use_cache=True)
    kv_cache = step.past_key_values
    next_logits = step.logits[:, -1]

    generated = []
    for _ in range(max_new):
        token = int(next_logits.argmax(-1))
        if token == eot:
            break
        generated.append(token)

        step = model.model(
            input_ids=torch.tensor([[token]], device=device),
            past_key_values=kv_cache,
            use_cache=True,
        )
        kv_cache = step.past_key_values
        if token != ls_id:
            next_logits = step.logits[:, -1]
            continue
        # After <|line_sep|>: the latent block stands in for the $LOCALS text
        # and its logits are used to pick the following token.
        kv_cache, next_logits = model._latent_block(kv_cache)
    return generated
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def main():
    """Greedy CRUXEval-O evaluation of a CODI checkpoint with latent decoding.

    Rows are split round-robin across ranks when launched with RANK /
    WORLD_SIZE / LOCAL_RANK in the environment. Each rank decodes its shard one
    prompt at a time; counts are all-reduced and per-row results gathered on
    rank 0, which prints the metrics and optionally writes them to ``--out``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--n_samples", type=int, default=-1)
    parser.add_argument("--max_new_tokens", type=int, default=8192)
    parser.add_argument("--latent_steps", type=int, default=1)
    parser.add_argument("--out", default="")
    args = parser.parse_args()

    env = os.environ
    ddp = "RANK" in env
    rank = int(env.get("RANK", 0))
    world = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    if ddp:
        # ranks finish at different times under long gens
        dist.init_process_group("nccl", timeout=timedelta(hours=1))
        torch.cuda.set_device(local_rank)

    tok, ids, model = load_codi(args.model, args.latent_steps, local_rank)
    ls_id = ids["<|line_sep|>"]
    eot = ids["<|end_of_text|>"]

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]

    n_correct = n_fmt = 0
    results = []
    for done, r in enumerate(shard, start=1):
        prompt = _prompt_str(r["code"], r["input"])
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(local_rank)
        new_ids = gen_latent(model, enc["input_ids"][0], ls_id, eot, args.max_new_tokens)
        gen = tok.decode(new_ids, skip_special_tokens=False)

        pred = extract_answer_trace_full(gen)
        has_fmt = pred is not None
        ok = has_fmt and check_correct(r["code"], r["output"], pred)
        n_fmt += has_fmt
        n_correct += ok
        results.append({"id": r["id"], "expected": r["output"], "predicted": pred, "correct": ok, "generation": gen})
        if rank == 0 and done % 20 == 0:
            print(f" rank0 {done}/{len(shard)} pass@1={n_correct/done:.4f}", flush=True)

    if ddp:
        counts = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(counts)
        n_correct, n_fmt = int(counts[0]), int(counts[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            merged = []
            for part in gathered:
                merged.extend(part)
            results = merged

    if rank == 0:
        pass_at_1 = n_correct / n
        valid_format = n_fmt / n
        print(f"\nCRUXEval-O latent pass@1={pass_at_1:.4f} "
              f"valid_format={valid_format:.4f} (n={n}, greedy)")
        if args.out:
            report = {"pass_at_1": pass_at_1, "valid_format": valid_format,
                      "n": n, "results": results}
            with open(args.out, "w") as f:
                json.dump(report, f, indent=2)
    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
code/eval/eval_cruxeval_sft.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stage 1 baseline eval: CRUXEval-O output prediction via full-trace generation.
|
| 2 |
+
|
| 3 |
+
Feed the training prompt (seeds frame 0), let the SFT model generate the trace,
|
| 4 |
+
take main()'s last return value as the predicted output, score by execution.
|
| 5 |
+
Greedy => pass@1 is the exact-match fraction. Reuses cwm_andre eval logic.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import subprocess
|
| 12 |
+
import sys
|
| 13 |
+
from datetime import timedelta
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.distributed as dist
|
| 17 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 18 |
+
|
| 19 |
+
from data.dataset import _prompt_str
|
| 20 |
+
from data.sources import load_cruxeval
|
| 21 |
+
from tokens import add_trace_tokens, token_ids
|
| 22 |
+
|
| 23 |
+
ARG_SEP, FRAME_SEP, RETURN_SEP = "<|arg_sep|>", "<|frame_sep|>", "<|return_sep|>"
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def extract_answer_trace_full(gen: str) -> str | None:
|
| 27 |
+
"""Value of main()'s last RETURN frame: ...<|arg_sep|>"value"<|frame_sep|>."""
|
| 28 |
+
r = gen.rfind(RETURN_SEP)
|
| 29 |
+
if r == -1:
|
| 30 |
+
return None
|
| 31 |
+
a = gen.find(ARG_SEP, r)
|
| 32 |
+
if a == -1:
|
| 33 |
+
return None
|
| 34 |
+
rest = gen[a + len(ARG_SEP):]
|
| 35 |
+
end = rest.find(FRAME_SEP)
|
| 36 |
+
val = (rest[:end] if end != -1 else rest).strip()
|
| 37 |
+
if not val:
|
| 38 |
+
return None
|
| 39 |
+
try:
|
| 40 |
+
return json.loads(val)
|
| 41 |
+
except json.JSONDecodeError:
|
| 42 |
+
return val
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def check_correct(code: str, expected: str, predicted: str, timeout: float = 3.0) -> bool:
    """Score a prediction by running ``code`` followed by ``assert expected == predicted``.

    The program is executed with the current interpreter in a subprocess.

    Args:
        code: Python source placed before the assertion.
        expected: expression text for the left-hand side of the assertion.
        predicted: expression text for the right-hand side of the assertion.
        timeout: seconds before the subprocess is abandoned.

    Returns:
        True only when the subprocess exits with status 0; any failure to run
        it (including a timeout) counts as False.
    """
    program = "{}\nassert {} == {}".format(code, expected, predicted)
    cmd = [sys.executable, "-c", program]
    try:
        proc = subprocess.run(cmd, timeout=timeout, capture_output=True)
    except Exception:
        return False
    return proc.returncode == 0
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def main():
    """Greedy CRUXEval-O evaluation of the SFT model via full-trace generation.

    Prompts are left-padded and generated in batches of ``--batch_size``. Rows
    are split round-robin across ranks when RANK / WORLD_SIZE / LOCAL_RANK are
    set; counts are all-reduced and per-row results gathered on rank 0, which
    prints the metrics and optionally writes them as JSON to ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--n_samples", type=int, default=-1)
    ap.add_argument("--max_new_tokens", type=int, default=8192)
    ap.add_argument("--batch_size", type=int, default=8)
    ap.add_argument("--out", default="")
    args = ap.parse_args()

    # DDP-style data parallelism for inference: torchrun sets RANK/WORLD_SIZE/LOCAL_RANK.
    ddp = "RANK" in os.environ
    rank = int(os.environ.get("RANK", 0))
    world = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if ddp:
        dist.init_process_group("nccl", timeout=timedelta(hours=1))  # ranks finish at different times under long gens
        torch.cuda.set_device(local_rank)

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent; ensures trace tokens present
    tok.padding_side = "left"  # left-pad so all generated tokens start at the same offset
    if tok.pad_token is None:
        # Batched encoding with padding=True fails without a pad token; reuse
        # <|end_of_text|>, the same token generate() is told to pad with below.
        tok.pad_token = "<|end_of_text|>"
    eot_id = token_ids(tok)["<|end_of_text|>"]
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16).to(local_rank).eval()

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]  # disjoint round-robin split across ranks

    n_correct = n_fmt = 0
    results = []
    for bi, batch_start in enumerate(range(0, len(shard), args.batch_size)):
        batch = shard[batch_start: batch_start + args.batch_size]
        enc = tok([_prompt_str(r["code"], r["input"]) for r in batch],
                  return_tensors="pt", padding=True, add_special_tokens=False).to(local_rank)
        with torch.no_grad():
            out = model.generate(**enc, max_new_tokens=args.max_new_tokens, do_sample=False,
                                 eos_token_id=eot_id, pad_token_id=eot_id)
        for j, r in enumerate(batch):
            # Keep only the continuation: everything after the (padded) prompt width.
            gen = tok.decode(out[j, enc["input_ids"].shape[1]:], skip_special_tokens=False)
            pred = extract_answer_trace_full(gen)
            ok = pred is not None and check_correct(r["code"], r["output"], pred)
            n_fmt += pred is not None
            n_correct += ok
            results.append({"id": r["id"], "expected": r["output"], "predicted": pred, "correct": ok, "generation": gen})
        if rank == 0 and (bi + 1) % 5 == 0:
            done = batch_start + len(batch)
            print(f" rank0 {done}/{len(shard)} pass@1={n_correct/done:.4f}", flush=True)

    # Reduce metrics and gather per-row results across ranks.
    if ddp:
        t = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(t)
        n_correct, n_fmt = int(t[0]), int(t[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            results = [x for part in gathered for x in part]

    if rank == 0:
        print(f"\nCRUXEval-O pass@1={n_correct / n:.4f} "
              f"valid_format={n_fmt / n:.4f} (n={n}, greedy)")
        if args.out:
            with open(args.out, "w") as f:
                json.dump({"pass_at_1": n_correct / n, "valid_format": n_fmt / n,
                           "n": n, "results": results}, f, indent=2)
    if ddp:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
|
code/tokens.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""CWM trace special tokens + tokenizer/embedding setup for a non-CWM base."""
|
| 2 |
+
|
| 3 |
+
# Trace-format tokens (mirrors data/trace_format.py) + latent delimiters.
# NOTE(review): the list order presumably determines the ids assigned to newly
# added tokens -- confirm against the tokenizer before reordering.
TRACE_TOKENS = [
    "<|trace_context_start|>",
    "<|call_sep|>", "<|line_sep|>", "<|return_sep|>", "<|exception_sep|>",
    "<|action_sep|>", "<|arg_sep|>", "<|frame_sep|>", "<|end_of_text|>",
    "<|latent_start|>", "<|latent_end|>",
]


def add_trace_tokens(tokenizer) -> int:
    """Add the trace tokens as special tokens. Returns the count newly added.

    Args:
        tokenizer: object exposing ``add_tokens(tokens, special_tokens=...)``.

    Returns:
        Whatever ``tokenizer.add_tokens`` returns for ``TRACE_TOKENS``.
    """
    return tokenizer.add_tokens(TRACE_TOKENS, special_tokens=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def resize_and_init(model, tokenizer, n_added: int) -> None:
    """Resize embeddings to the tokenizer; init new rows to the existing mean.

    Args:
        model: exposes ``get_input_embeddings``, ``get_output_embeddings`` and
            ``resize_token_embeddings``.
        tokenizer: its ``len()`` is the target number of embedding rows.
        n_added: number of tokens just added to the tokenizer; nothing is
            initialised when it is not positive.

    The rows to initialise start at the smaller of the old embedding size and
    ``len(tokenizer) - n_added``. Using the old size alone misses the new
    tokens when the embedding matrix was already at least as large as the
    tokenizer (padded vocabulary): the slice past the old size is then empty.
    NOTE(review): this assumes the added tokens occupy the last ``n_added``
    ids of the tokenizer -- confirm for tokenizers with non-contiguous ids.
    """
    old = model.get_input_embeddings().weight.shape[0]
    model.resize_token_embeddings(len(tokenizer))
    if n_added <= 0:
        return
    first_new = min(old, len(tokenizer) - n_added)
    seen = set()
    for emb in (model.get_input_embeddings(), model.get_output_embeddings()):
        if emb is None:
            continue
        w = emb.weight.data
        # Tied input/output embeddings are distinct modules sharing one weight,
        # so deduplicate on the storage pointer rather than the module.
        if w.data_ptr() in seen:
            continue
        seen.add(w.data_ptr())
        w[first_new:] = w[:first_new].mean(dim=0, keepdim=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def token_ids(tokenizer) -> dict[str, int]:
    """Return ``{trace token: id}`` for every entry of ``TRACE_TOKENS``.

    Each token is encoded without special tokens and must come back as exactly
    one id; otherwise an ``AssertionError`` naming the token is raised.
    """
    mapping: dict[str, int] = {}
    for name in TRACE_TOKENS:
        encoded = tokenizer.encode(name, add_special_tokens=False)
        assert len(encoded) == 1, f"{name!r} did not encode to a single id: {encoded}"
        mapping[name] = encoded[0]
    return mapping
|
code/train/__init__.py
ADDED
|
File without changes
|
code/train/train_codi.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stage 2b: per-frame CODI self-distillation (multi-span).
|
| 2 |
+
|
| 3 |
+
Shared-weight teacher+student initialized from the Stage-1 SFT model.
|
| 4 |
+
- Teacher reads the full explicit trace (prompt+trace), CE = L_teacher.
|
| 5 |
+
- Student replaces each LINE frame's $LOCALS with a latent block (latent_start +
|
| 6 |
+
`latent_steps` recurrent latents + latent_end; last hidden -> prj -> next embed)
|
| 7 |
+
and teacher-forces the rest, CE = L_student over the emitted (non-locals) text.
|
| 8 |
+
- KD aligns the hidden at each frame's `<|action_sep|>` (student after latents vs
|
| 9 |
+
teacher after locals), teacher detached. L = a*Lt + b*Ls + g*Lkd.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import os
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, Trainer, TrainingArguments
|
| 20 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 21 |
+
from transformers.utils import WEIGHTS_NAME
|
| 22 |
+
|
| 23 |
+
from data.dataset import IGNORE_INDEX, build_codi_dataset
|
| 24 |
+
from tokens import add_trace_tokens, token_ids
|
| 25 |
+
from wb import wandb_init
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CodiModel(nn.Module):
    """Causal-LM wrapper that trains a latent-block student against a teacher.

    ``forward`` combines three per-example terms, each averaged over the batch:
    ``a * teacher CE + b * student CE + g * KD``. The teacher is either the
    wrapped model itself (shared weights) or a separate frozen model passed as
    ``teacher``.
    """

    def __init__(self, base, *, latent_start_id, latent_end_id, latent_steps,
                 a=1.0, b=1.0, g=1.0, kd_layers=None, single_anchor=False,
                 ss_prob=0.0, ss_ramp_frac=0.5, teacher=None, kd_target="hidden", kd_temp=2.0,
                 line_sep_id=None, recon_w=0.0):
        # NOTE: line_sep_id and recon_w are accepted but not read anywhere in this class.
        super().__init__()
        self.model = base
        h = base.config.hidden_size
        # CODI thought projector (last hidden -> next latent input).
        self.prj = nn.Sequential(
            nn.Linear(h, h, bias=False), nn.GELU(),
            nn.Linear(h, h, bias=False), nn.LayerNorm(h),
        )
        ref = base.get_input_embeddings().weight
        self.prj.to(device=ref.device, dtype=ref.dtype)  # match the backbone's device/dtype
        self.latent_steps, self.a, self.b, self.g = latent_steps, a, b, g
        self.teacher = [teacher] if teacher is not None else None  # list -> hidden from state_dict/DDP/optim
        self.kd_target, self.kd_temp = kd_target, kd_temp  # hidden: smooth_l1 on kd_layers; logit: KL on lm_head
        if kd_target == "logit" or (teacher is not None and kd_layers is None):
            kd_layers = [-1]  # logit KD is defined on the last layer only; frozen default = key (last) hidden
        self.kd_layers = kd_layers  # None -> all layers
        self.single_anchor = single_anchor  # KD at last span only (vanilla-CODI ablation)
        # scheduled sampling: ss_p (ramped per step) of post-latent lines feed the student's own argmax
        self.ss_prob, self.ss_ramp_frac, self.ss_p = ss_prob, ss_ramp_frac, 0.0
        # Non-persistent buffers: they follow .to(device) but are left out of the state dict.
        self.register_buffer("_ls_tok", torch.tensor([[latent_start_id]], dtype=torch.long), persistent=False)
        self.register_buffer("_le_tok", torch.tensor([[latent_end_id]], dtype=torch.long), persistent=False)
        # Aliases of the backbone's submodules, captured once here. Reassigning
        # self.model later does NOT update them; _latent_block uses these aliases.
        self.body = base.model
        self.head = base.lm_head

    def _kd(self, hs):
        """Pick the hidden states used for KD: ``hs[1:]`` when ``kd_layers`` is
        None, otherwise the layers indexed by ``kd_layers``."""
        return hs[1:] if self.kd_layers is None else tuple(hs[l] for l in self.kd_layers)

    def _emb(self, ids):
        """Look up input embeddings for ``ids`` with the wrapped model's embedding layer."""
        return self.model.get_input_embeddings()(ids)

    def _teacher(self, full_ids, labels, kd_pos):
        """Teacher pass over the full sequence.

        Args:
            full_ids: 1-D tensor, prompt followed by trace.
            labels: per-token CE targets for ``full_ids``; unused (may be None)
                with a frozen teacher.
            kd_pos: positions in ``full_ids`` at which KD targets are read.

        Returns:
            ``(ce, kd)``: ``ce`` is None with a frozen teacher, else the
            next-token cross entropy; ``kd`` is a list of tensors indexed at
            ``kd_pos`` (hidden states per selected layer, or a single logits
            tensor when ``kd_target == "logit"``).
        """
        pos = torch.tensor(kd_pos, device=full_ids.device)
        if self.teacher is not None:  # frozen teacher: KD targets only, no teacher CE
            tch, dev = self.teacher[0], full_ids.device
            if next(tch.parameters()).device != dev:
                tch.to(dev)  # moved lazily: the teacher is hidden from this module's .to()
            with torch.no_grad():
                if self.kd_target == "logit":  # target = teacher's own next-token logits
                    return None, [tch(input_ids=full_ids[None], use_cache=False).logits[0, pos]]
                hs = tch(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
                return None, [l[0, pos] for l in self._kd(hs)]
        with torch.no_grad():  # KD targets are detached; take hiddens without a backward graph
            hs = self.model(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
            kd = [l[0, pos] for l in self._kd(hs)]
        # CE forward without output_hidden_states so grad-checkpointing actually frees layer acts.
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        logits = self.model(input_ids=full_ids[None], use_cache=False).logits
        self.model.gradient_checkpointing_disable()  # teacher-only; student keeps KV cache
        ce = F.cross_entropy(logits[0, :-1], labels[1:], ignore_index=IGNORE_INDEX)
        return ce, kd

    def _latent_block(self, cache):
        """latent_start + `latent_steps` recurrent latents + latent_end on top of
        `cache`. Returns (new cache, logits predicting the next real token)."""
        o = self.body(inputs_embeds=self._emb(self._ls_tok), past_key_values=cache, use_cache=True)
        cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        for _ in range(self.latent_steps):
            # Recurrence: the projected last hidden state is fed back as the next input embedding.
            o = self.body(inputs_embeds=self.prj(h), past_key_values=cache, use_cache=True)
            cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        o = self.body(inputs_embeds=self._emb(self._le_tok), past_key_values=cache, use_cache=True)
        return o.past_key_values, self.head(o.last_hidden_state[:, -1])

    def _student(self, prompt_ids, trace_ids, spans):
        """Student pass: teacher-force the trace with each span replaced by a latent block.

        Args:
            prompt_ids: 1-D tensor of prompt token ids.
            trace_ids: 1-D tensor of trace token ids.
            spans: ``(i, j)`` index pairs into ``trace_ids``; tokens strictly
                between ``i`` and ``j`` are dropped.

        Returns:
            ``(ce, s_kd)``: cross entropy over the kept trace tokens and, per
            selected layer, the stacked hidden states taken at the first token
            of every post-latent segment.
        """
        # Segments cover trace_ids in order; locals (trace_ids[i+1:j]) are dropped
        # and replaced by a latent block. kd=True marks a frame's <|action_sep|>.
        segs, prev, kd = [], 0, False
        for i, j in spans:
            segs.append(("text", trace_ids[prev:i + 1], kd))
            segs.append(("latent", None, False))
            prev, kd = j, True
        segs.append(("text", trace_ids[prev:], kd))
        last = len(segs) - 1

        out = self.model(inputs_embeds=self._emb(prompt_ids[None]), use_cache=True)
        cache, prev_logits = out.past_key_values, out.logits[:, -1]  # predicts trace_ids[0]
        ce_logits, ce_targets, kd_vecs = [], [], []
        for s, (kind, ids, kd) in enumerate(segs):
            if kind == "latent":  # prev_logits predicted dropped locals; overwrite, no CE
                cache, prev_logits = self._latent_block(cache)
                continue
            inp = ids
            if kd and 0 < self.ss_p and random.random() < self.ss_p:
                # scheduled sampling: replace the code (not action_sep / line_sep) with the student's own
                # argmax via a no-grad pass on a detached cache clone; CE targets below stay GT.
                end = ids.numel() if s == last else ids.numel() - 1
                c = DynamicCache()
                for i, ly in enumerate(cache.layers):
                    c.update(ly.keys.detach(), ly.values.detach(), i)
                with torch.no_grad():
                    pred = self.model(inputs_embeds=self._emb(ids[None]), past_key_values=c, use_cache=True).logits[0].argmax(-1)
                inp = ids.clone(); inp[1:end] = pred[:end - 1]
            # The logits carried over from the previous step score this segment's first token.
            ce_logits.append(prev_logits); ce_targets.append(ids[:1])
            out = self.model(inputs_embeds=self._emb(inp[None]), past_key_values=cache,
                             use_cache=True, output_hidden_states=kd)  # hiddens only for KD anchors
            cache, logits = out.past_key_values, out.logits[0]
            if ids.numel() > 1:
                ce_logits.append(logits[:-1]); ce_targets.append(ids[1:])
            prev_logits = logits[-1:]
            if kd:  # action_sep is this segment's first token
                kd_vecs.append([hs[0, 0] for hs in self._kd(out.hidden_states)])
        ce = F.cross_entropy(torch.cat(ce_logits), torch.cat(ce_targets))
        # Regroup from per-anchor lists of layers to per-layer stacks over anchors.
        # NOTE(review): kd_vecs[0] assumes at least one span -- confirm the dataset guarantees it.
        s_kd = [torch.stack([v[l] for v in kd_vecs]) for l in range(len(kd_vecs[0]))]
        return ce, s_kd

    def _kd_loss(self, s_kd, t_kd):
        """Distillation loss between student anchors ``s_kd`` and detached teacher targets ``t_kd``.

        ``kd_target == "logit"``: temperature-scaled KL between the student's
        ``head`` logits and the teacher logits, multiplied by ``T * T``.
        Otherwise: smooth-L1 between the stacked hidden states.
        """
        s, t = torch.stack(s_kd), torch.stack(t_kd).detach()
        if self.kd_target == "logit":  # s=student hidden, t=frozen-teacher logits; KL on distributions
            T = self.kd_temp
            sl, tl = self.head(s).flatten(0, -2) / T, t.flatten(0, -2) / T
            return F.kl_div(F.log_softmax(sl, -1), F.softmax(tl, -1), reduction="batchmean") * T * T
        return F.smooth_l1_loss(s, t)

    def forward(self, examples):
        """Compute the combined loss over a list of examples.

        Args:
            examples: sequence of dicts with ``"prompt_ids"``, ``"trace_ids"``
                and ``"spans"``; each example is processed on its own.

        Returns:
            Dict with ``"loss"`` plus detached ``"teacher_loss"``,
            ``"student_loss"`` and ``"kd_loss"`` batch means.
        """
        dev = self.model.get_input_embeddings().weight.device
        tl = sl = kl = 0.0
        for ex in examples:
            prompt = torch.tensor(ex["prompt_ids"], device=dev)
            trace = torch.tensor(ex["trace_ids"], device=dev)
            spans = ex["spans"]
            full = torch.cat([prompt, trace])
            # Teacher CE is computed on trace tokens only; prompt positions are ignored.
            labels = None if self.teacher else torch.cat([full.new_full((len(prompt),), IGNORE_INDEX), trace])
            kd_pos = [len(prompt) + j for _, j in spans]  # span end j, shifted into full-sequence coordinates
            t_ce, t_kd = self._teacher(full, labels, kd_pos)
            s_ce, s_kd = self._student(prompt, trace, spans)
            if self.single_anchor:  # keep only the last frame's anchor (per layer)
                t_kd, s_kd = [t[-1:] for t in t_kd], [s[-1:] for s in s_kd]
            tl = tl + (t_ce if t_ce is not None else 0.0)  # frozen teacher -> no teacher CE
            sl, kl = sl + s_ce, kl + self._kd_loss(s_kd, t_kd)
        n = len(examples)
        loss = self.a * tl / n + self.b * sl / n + self.g * kl / n
        t_log = (tl / n).detach() if torch.is_tensor(tl) else torch.tensor(0.0)  # 0 under frozen teacher
        return {"loss": loss, "teacher_loss": t_log,
                "student_loss": (sl / n).detach(), "kd_loss": (kl / n).detach()}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class CodiTrainer(Trainer):
    """Trainer glue for ``CodiModel``: feeds raw examples, ramps scheduled
    sampling, logs the sub-losses and writes eval-loadable checkpoints."""

    def compute_loss(self, model, inputs, return_outputs=False, **kw):
        """Run the model on ``inputs["examples"]`` and return its ``"loss"``
        (plus the full output dict when ``return_outputs`` is set)."""
        core = getattr(model, "module", model)  # unwrap a wrapper exposing .module
        if core.ss_prob:
            # linear ramp 0 -> ss_prob over the first ss_ramp_frac of training
            ramp_steps = max(1.0, core.ss_ramp_frac * self.state.max_steps)
            progress = min(1.0, self.state.global_step / ramp_steps)
            current = core.ss_prob * progress
            core.ss_p = current
            self._ss = current

        out = model(inputs["examples"])
        sub_losses = {}
        for name in ("teacher_loss", "student_loss", "kd_loss"):
            sub_losses[name] = out[name].detach()
        self._sub = sub_losses

        if return_outputs:
            return out["loss"], out
        return out["loss"]

    def log(self, logs, *a, **k):
        """Add the latest sub-losses and ``ss_p`` to ``logs`` before delegating
        to the parent logger (console + wandb)."""
        if hasattr(self, "_sub"):
            logs.update({name: value.item() for name, value in self._sub.items()})
        if hasattr(self, "_ss"):
            logs["ss_p"] = self._ss
        super().log(logs, *a, **k)

    def _save(self, output_dir=None, state_dict=None):
        """Write weights, training args, backbone config, tokenizer and the
        thought projector into ``output_dir`` (default: ``args.output_dir``)."""
        # tied backbone weights -> safetensors (5.x default) rejects shared tensors; torch.save instead.
        target = output_dir or self.args.output_dir
        os.makedirs(target, exist_ok=True)

        weights = state_dict or self.model.state_dict()
        torch.save(weights, os.path.join(target, WEIGHTS_NAME))
        torch.save(self.args, os.path.join(target, "training_args.bin"))

        # also write config/tokenizer/projector so each ckpt is eval-loadable (small, no weight dup).
        self.model.model.config.save_pretrained(target)
        self.tok.save_pretrained(target)
        projector_path = os.path.join(target, "thought_projector.pt")
        torch.save(self.model.prj.state_dict(), projector_path)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def main():
    """CLI entry point for Stage 2b CODI training.

    Parses the arguments, loads the tokenizer and Stage-1 backbone (plus an
    optional frozen teacher), builds the CODI dataset, then trains with
    ``CodiTrainer``, resuming from the last checkpoint in ``--output_dir`` when
    one exists, and finally saves a checkpoint for the last step.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)  # Stage-1 SFT dir
    ap.add_argument("--output_dir", required=True)
    ap.add_argument("--sources", nargs="+", default=["mbpp", "humaneval", "pyx"])
    ap.add_argument("--cache_dir", default="data/cache/codi_train")  # offline tokenized examples from precompute.py
    ap.add_argument("--n_samples", type=int, default=-1)
    ap.add_argument("--max_seq_len", type=int, default=4096)
    ap.add_argument("--max_frames", type=int, default=-1)
    ap.add_argument("--latent_steps", type=int, default=1)
    ap.add_argument("--epochs", type=float, default=10.0)
    ap.add_argument("--lr", type=float, default=1e-5)
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--grad_accum", type=int, default=4)
    ap.add_argument("--max_steps", type=int, default=-1)
    ap.add_argument("--save_steps", type=int, default=500)
    ap.add_argument("--alpha", type=float, default=1.0)
    ap.add_argument("--beta", type=float, default=1.0)
    ap.add_argument("--gamma", type=float, default=1.0)
    ap.add_argument("--kd_layers", nargs="+", type=int, default=None)  # default: all layers (frozen -> last)
    ap.add_argument("--frozen_teacher", default="")  # path to frozen SFT teacher; "" -> shared-weight (legacy)
    ap.add_argument("--kd_target", default="hidden", choices=["hidden", "logit"])  # key-hidden align: smooth_l1 vs KL
    ap.add_argument("--kd_temp", type=float, default=2.0)  # logit-KD temperature
    ap.add_argument("--single_anchor", action="store_true")  # KD at last frame only (vanilla CODI)
    ap.add_argument("--ss_prob", type=float, default=0.0)  # scheduled-sampling max prob (0 = off)
    ap.add_argument("--ss_ramp_frac", type=float, default=0.5)  # ramp ss_prob over this frac of steps
    args = ap.parse_args()

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent
    ids = token_ids(tok)
    base = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    base.config.use_cache = True  # the student pass threads a KV cache between segments
    teacher = None
    if args.frozen_teacher:
        teacher = AutoModelForCausalLM.from_pretrained(args.frozen_teacher, torch_dtype=torch.bfloat16)
        teacher.config.use_cache = False
        teacher.eval().requires_grad_(False)  # frozen: eval mode, no gradients
    model = CodiModel(base, latent_start_id=ids["<|latent_start|>"], latent_end_id=ids["<|latent_end|>"],
                      latent_steps=args.latent_steps, a=args.alpha, b=args.beta, g=args.gamma,
                      kd_layers=args.kd_layers, single_anchor=args.single_anchor,
                      ss_prob=args.ss_prob, ss_ramp_frac=args.ss_ramp_frac,
                      teacher=teacher, kd_target=args.kd_target, kd_temp=args.kd_temp)

    ds = build_codi_dataset(tok, sources=args.sources, cache_dir=args.cache_dir,
                            n_samples=args.n_samples, max_seq_len=args.max_seq_len, max_frames=args.max_frames)
    print(f"{len(ds)} codi examples, latent_steps={args.latent_steps}")

    report_to = wandb_init(args, "codi")

    targs = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        learning_rate=args.lr,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.1,
        max_grad_norm=1.0,
        bf16=True,
        optim="paged_adamw_8bit",
        ddp_find_unused_parameters=False,
        logging_steps=5,
        save_strategy="steps",
        save_steps=args.save_steps,
        save_total_limit=None,
        report_to=report_to,
        remove_unused_columns=False,  # examples are passed through untouched (see data_collator below)
        label_names=[],
    )
    trainer = CodiTrainer(
        model=model, args=targs, train_dataset=ds,
        data_collator=lambda b: {"examples": b},  # no tensor collation: the model receives the raw example list
    )
    trainer.tok = tok  # CodiTrainer._save writes this tokenizer into every checkpoint
    # Native checkpoints (CodiModel wrapper + optimizer) auto-resume if interrupted.
    ckpt = get_last_checkpoint(args.output_dir) if os.path.isdir(args.output_dir) else None
    trainer.train(resume_from_checkpoint=ckpt)
    trainer._save_checkpoint(trainer.model, trial=None)  # final step as a resumable, eval-loadable checkpoint-<step>


# Run training only when executed as a script.
if __name__ == "__main__":
    main()
|
code/wb.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""wandb: default-on, offline (compute nodes have no internet -> `wandb sync` later),
|
| 2 |
+
never blocks training. Returns report_to for TrainingArguments."""
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def wandb_init(args, stage):
    """Start a wandb run on rank 0 and return the ``report_to`` list.

    ``WANDB_MODE`` defaults to ``offline`` unless already set. Any failure
    (including wandb not being importable) is printed and disables reporting
    instead of raising.

    Args:
        args: namespace with ``output_dir``; all of its attributes are stored
            as the run config.
        stage: prefix of the run name.

    Returns:
        ``["wandb"]`` when the run was started, otherwise ``[]``.
    """
    is_rank0 = int(os.environ.get("RANK", "0")) == 0
    if not is_rank0:  # rank0 only under DDP
        return []
    try:
        import wandb

        os.environ.setdefault("WANDB_MODE", "offline")
        run_name = f"{stage}-{os.path.basename(args.output_dir)}"
        wandb.init(project="codi_trace", name=run_name,
                   dir=args.output_dir, config=vars(args))
    except Exception as e:
        print(f"wandb disabled: {e}")
        return []
    return ["wandb"]
|