sirui6011 committed on
Commit
aedd6ab
·
verified ·
1 Parent(s): bed155d

add code/ loader snapshot

Browse files
code/data/__init__.py ADDED
File without changes
code/data/dataset.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """
4
+ CRUXEval-O dataset: deterministic train/val split + ground-truth execution
5
+ traces -> (input_ids, labels) for teacher-forcing CODI.
6
+
7
+ Neutral data layer shared by training (``cwm.training.data``) and eval
8
+ (``evals.cruxeval.run_eval_codi``); depends on nothing in either, so the
9
+ split and trace format never drift. Thin HuggingFace-tokenizer wrapper over
10
+ the verbatim Table 9 trace generator (``.ground_truth`` / ``.trace_format``):
11
+ build the seeded prompt, tokenize ``prompt + render_frames_to_generation(frames)``,
12
+ and mask the prompt out of the labels (teacher-forced, so labels == input_ids
13
+ with the prompt prefix set to ``-100``).
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from .ground_truth import ground_truth_trace, make_trace_context
19
+ from .trace_format import (
20
+ ACTION_SEP,
21
+ LINE_SEP,
22
+ TraceEvent,
23
+ render_frames_to_generation,
24
+ )
25
+
26
+ IGNORE_INDEX = -100
27
def _prompt_str(code: str, input_str: str) -> str:
    """Return the trace prompt for ``code`` called on ``input_str``.

    The prompt is the trace context from ``make_trace_context`` wrapped in
    ``<|trace_context_start|>`` ... ``<|frame_sep|>``, followed by a seeded
    first frame: a ``<|call_sep|>`` with empty locals ``{}`` whose action is
    ``def main():``.
    """
    context = make_trace_context(code, input_str)
    seeded_call = "<|call_sep|>{}<|action_sep|>def main():\n<|frame_sep|>"
    return "<|trace_context_start|>" + f"{context}" + "<|frame_sep|>" + seeded_call
30
+
31
+
32
def _tokenize_trace(code, input_str, tokenizer, *, max_seq_len, max_frames):
    """Trace one example and tokenize it into ``(prompt_ids, trace_ids, spans)``.

    Returns None (skip the example) when the ground-truth trace is empty, hit
    the ``max_frames`` limit, does not end on a RETURN/EXCEPTION frame, exceeds
    ``max_seq_len`` tokens together with the prompt, or has no complete LINE
    span.

    Each span ``(i, j)`` indexes ``trace_ids``: ``trace_ids[i]`` is a
    ``<|line_sep|>`` token and ``trace_ids[j]`` the next ``<|action_sep|>``,
    so ``trace_ids[i+1:j]`` are the locals tokens between them. Both example
    builders call this function, so they accept the same examples.
    """
    frames, error = ground_truth_trace(
        code, input_str, align_to_prompt=True, max_frames=max_frames
    )
    if error == "frames_exceeded" or not frames:
        return None
    terminal_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    if frames[-1].event not in terminal_events:
        return None

    # Some tokenizers define no BOS token (bos_token_id is None); prepend one
    # only when it exists.
    bos_id = tokenizer.bos_token_id
    prefix = [] if bos_id is None else [bos_id]
    prompt_ids = prefix + tokenizer.encode(
        _prompt_str(code, input_str), add_special_tokens=False
    )
    trace_ids = tokenizer.encode(
        render_frames_to_generation(frames), add_special_tokens=False
    )
    if len(prompt_ids) + len(trace_ids) > max_seq_len:
        return None

    line_id = tokenizer.convert_tokens_to_ids(LINE_SEP)
    action_id = tokenizer.convert_tokens_to_ids(ACTION_SEP)
    spans = []
    cursor, total = 0, len(trace_ids)
    while cursor < total:
        if trace_ids[cursor] != line_id:
            cursor += 1
            continue
        closing = cursor + 1
        while closing < total and trace_ids[closing] != action_id:
            closing += 1
        if closing == total:
            # A <|line_sep|> with no following <|action_sep|>: stop scanning.
            break
        spans.append((cursor, closing))
        cursor = closing + 1

    if not spans:
        return None
    return prompt_ids, trace_ids, spans
66
+
67
+
68
def build_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one SFT pair ``(input_ids, labels)``, or None to skip the example.

    ``input_ids`` is prompt followed by trace; ``labels`` is the same sequence
    with every prompt position replaced by ``IGNORE_INDEX``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, _spans = tokenized
    input_ids = prompt_ids + trace_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + trace_ids
    return input_ids, labels
75
+
76
+
77
def build_codi_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one multi-span CODI example, or None to skip it.

    The result is a dict with keys ``prompt_ids``, ``trace_ids`` and ``spans``
    holding the three values returned by ``_tokenize_trace``.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, spans = tokenized
    example = {}
    example["prompt_ids"] = prompt_ids
    example["trace_ids"] = trace_ids
    example["spans"] = spans
    return example
84
+
85
+
86
def _load_cache(cache_dir, n_samples):
    """Load the dataset saved at ``cache_dir`` as a list of examples.

    Only the first ``n_samples`` examples are returned when ``n_samples`` is
    positive; otherwise all of them are.
    """
    from datasets import load_from_disk

    examples = list(load_from_disk(cache_dir))
    if n_samples > 0:
        return examples[:n_samples]
    return examples
92
+
93
+
94
def build_codi_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Build multi-span CODI examples over ``sources``, or load them from a cache.

    Each example is a dict ``{prompt_ids, trace_ids, spans}`` (see
    ``build_codi_example``).

    With ``cache_dir`` set, the cached examples are loaded (limited to the
    first ``n_samples`` when positive) and those longer than ``max_seq_len``
    tokens (prompt + trace) are dropped. Otherwise the rows of ``sources`` are
    traced and tokenized; rows that ``build_codi_example`` rejects (None) are
    dropped.

    Tracing executes arbitrary dataset code, so a row whose tracing or
    tokenization raises is skipped rather than aborting the build. Skips are
    logged (per row at DEBUG, a summary at WARNING) instead of being dropped
    silently.
    """
    import logging

    if cache_dir:
        cached = _load_cache(cache_dir, n_samples)
        return [
            e for e in cached
            if len(e["prompt_ids"]) + len(e["trace_ids"]) <= max_seq_len
        ]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    log = logging.getLogger(__name__)
    examples = []
    n_failed = 0
    for row in rows:
        try:
            example = build_codi_example(
                row["code"], row["input"], tokenizer,
                max_seq_len=max_seq_len, max_frames=max_frames,
            )
        except Exception:
            # Best effort: one bad row must not kill the whole dataset build.
            n_failed += 1
            log.debug("skipping row %r: tracing/tokenization raised",
                      row.get("id"), exc_info=True)
            continue
        if example is not None:
            examples.append(example)
    if n_failed:
        log.warning("build_codi_dataset: skipped %d of %d rows that raised",
                    n_failed, len(rows))
    return examples
113
+
114
+
115
def build_codi_single_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Single-block CODI examples derived from the multi-span ones.

    Every example from ``build_codi_dataset`` is cut at the last
    ``<|return_sep|>`` token of its trace: ``reasoning_ids`` is everything
    before that token and ``answer_ids`` is that token and everything after
    it. Examples with no ``<|return_sep|>``, or whose only one is the very
    first trace token (empty reasoning), are dropped.
    """
    return_id = tokenizer.convert_tokens_to_ids("<|return_sep|>")
    multi_span = build_codi_dataset(
        tokenizer, sources=sources, n_samples=n_samples,
        max_seq_len=max_seq_len, max_frames=max_frames, cache_dir=cache_dir,
    )
    singles = []
    for example in multi_span:
        trace = example["trace_ids"]
        # Position of the last <|return_sep|>, or -1 when there is none.
        cut = max(
            (pos for pos, tok_id in enumerate(trace) if tok_id == return_id),
            default=-1,
        )
        if cut <= 0:
            continue
        singles.append({
            "prompt_ids": example["prompt_ids"],
            "reasoning_ids": trace[:cut],
            "answer_ids": trace[cut:],
        })
    return singles
132
+
133
+
134
def rows_for_sources(sources):
    """Concatenate the rows of every named source, validating each row.

    Each row must contain the keys ``id``, ``code``, ``input`` and ``output``
    (ValueError otherwise), and ``code``/``input``/``output`` must be strings
    (TypeError otherwise). Rows are copied and their ``id`` is converted to
    ``str``. No train/test split happens here; all rows of each source are
    returned.
    """
    from . import sources as _src

    required_keys = ("id", "code", "input", "output")
    text_keys = ("code", "input", "output")
    merged = []
    for name in sources:
        for index, raw in enumerate(_src.load_one(name)):
            absent = [key for key in required_keys if key not in raw]
            if absent:
                raise ValueError(f"{name} row {index} missing keys: {absent}")
            if any(not isinstance(raw[key], str) for key in text_keys):
                raise TypeError(f"{name} row {index} must use string code/input/output")
            copied = dict(raw)
            copied["id"] = str(copied["id"])
            merged.append(copied)
    return merged
151
+
152
+
153
def build_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 8192, max_frames: int = -1, cache_dir: str | None = None
) -> list[tuple[list[int], list[int]]]:
    """Build SFT ``(input_ids, labels)`` pairs over ``sources``, or load a cache.

    With ``cache_dir`` set, the cached examples are loaded (limited to the
    first ``n_samples`` when positive) and those with more than
    ``max_seq_len`` input tokens are dropped. Otherwise the rows of
    ``sources`` are traced and tokenized with ``build_example``; rows it
    rejects (None) are dropped.

    Tracing executes arbitrary dataset code. A row whose tracing or
    tokenization raises is skipped and logged rather than aborting the build,
    so this builder selects the same rows as ``build_codi_dataset``.
    """
    import logging

    if cache_dir:
        cached = _load_cache(cache_dir, n_samples)
        return [
            (e["input_ids"], e["labels"])
            for e in cached
            if len(e["input_ids"]) <= max_seq_len
        ]

    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    log = logging.getLogger(__name__)
    examples = []
    n_failed = 0
    for row in rows:
        try:
            example = build_example(
                row["code"], row["input"], tokenizer,
                max_seq_len=max_seq_len, max_frames=max_frames,
            )
        except Exception:
            # Best effort, mirroring build_codi_dataset: skip the bad row.
            n_failed += 1
            log.debug("skipping row %r: tracing/tokenization raised",
                      row.get("id"), exc_info=True)
            continue
        if example is not None:
            examples.append(example)
    if n_failed:
        log.warning("build_dataset: skipped %d of %d rows that raised",
                    n_failed, len(rows))
    return examples
code/data/ground_truth.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """Ground-truth execution traces in CWM's frame format.
4
+
5
+ Runs ``f(input)`` under ``sys.settrace`` and records CALL/LINE/RETURN/EXCEPTION
6
+ frames with diff-based locals (unchanged vars render as ``".."``), values via
7
+ ``repr``. A synthetic ``def main(): return f(<input>)`` wraps the function; the
8
+ seeded ``call main()`` frame is dropped by default to align with the trace
9
+ prompt. Not a bit-exact replica of Meta's internal tracer (see README.md).
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import linecache
15
+ import sys
16
+ from typing import Any
17
+
18
+ from .trace_format import DIFF_PLACEHOLDER, TraceEvent, TraceFrame, normalize_source
19
+
20
+ _FILENAME = "<cwm_trace>"
21
+ _ENTRY = "main"
22
+
23
+
24
class _FramesExceeded(Exception):
    """Raised by the tracer to abort a run once the frame limit is reached."""
26
+
27
+
28
def make_trace_context(code: str, input_str: str) -> str:
    """Return ``code`` followed by a synthetic ``main()`` that returns ``f(input_str)``.

    The ``def main():`` line carries the ``# << START_OF_TRACE`` marker.
    """
    template = "\n{}\ndef main(): # << START_OF_TRACE\n return f({})\n"
    return template.format(code, input_str)
30
+
31
+
32
def render_value(value: Any) -> str:
    """Return ``repr(value)``, or ``"<unrepr>"`` if computing the repr raises."""
    try:
        text = repr(value)
    except Exception:
        text = "<unrepr>"
    return text
37
+
38
+
39
def ground_truth_trace(
    code: str, input_str: str, align_to_prompt: bool = True, max_frames: int = -1
) -> tuple[list[TraceFrame], str | None]:
    """Execute ``f(input_str)`` under ``sys.settrace`` and return ``(frames, error)``.

    ``frames`` holds one TraceFrame per CALL/LINE/RETURN/EXCEPTION event from
    code compiled out of the trace context. CALL/LINE frames carry diff-based
    locals: a variable whose rendered value is unchanged since the previous
    event in the same frame is rendered as ``DIFF_PLACEHOLDER``.

    ``error`` is None on success, ``"frames_exceeded"`` when ``max_frames``
    (if positive) was reached, or ``"<ExceptionType>: <message>"`` when the
    program raised. Frames recorded up to that point are still returned.

    With ``align_to_prompt`` the leading ``call main()`` frame is dropped.

    Side effects: registers the context source in ``linecache`` under
    ``_FILENAME`` and temporarily replaces the global trace function (the
    previous one is restored afterwards).
    """
    import inspect

    context = make_trace_context(code, input_str)
    linecache.cache[_FILENAME] = (len(context), None, context.splitlines(keepends=True), _FILENAME)

    frames: list[TraceFrame] = []
    # Last rendered locals per live frame. Keyed by the frame OBJECT rather
    # than id(frame): holding the reference stops CPython from reusing the
    # address for a later call, which would otherwise make a fresh call
    # inherit a finished call's locals and render its arguments as "..".
    scope_prev: dict[Any, dict[str, str]] = {}
    entry = None
    # Generator/coroutine frames emit "return" on every yield and are resumed
    # later, so their diff state must survive "return" events.
    resumable = inspect.CO_GENERATOR | inspect.CO_COROUTINE | inspect.CO_ASYNC_GENERATOR

    def source(frame):
        return normalize_source(linecache.getline(_FILENAME, frame.f_lineno))

    def diff_locals(frame):
        prev = scope_prev.get(frame, {})
        out, rendered = {}, {}
        for name, val in frame.f_locals.items():
            r = render_value(val)
            rendered[name] = r
            out[name] = DIFF_PLACEHOLDER if prev.get(name) == r else r
        scope_prev[frame] = rendered
        return out

    def trace(frame, event, arg):
        nonlocal entry
        # Abort loop-heavy programs, but only from our file (not GC/__del__ frames).
        if max_frames > 0 and len(frames) >= max_frames and frame.f_code.co_filename == _FILENAME:
            raise _FramesExceeded
        if entry is None:
            # Ignore everything until the synthetic entry point is called.
            if event == "call" and frame.f_code.co_name == _ENTRY:
                entry = id(frame)
            else:
                return None
        # Only trace user code from our context, not library frames.
        if frame.f_code.co_filename != _FILENAME:
            return None
        if event == "call":
            frames.append(TraceFrame(event=TraceEvent.CALL, source=source(frame), locals=diff_locals(frame)))
        elif event == "line":
            frames.append(TraceFrame(event=TraceEvent.LINE, source=source(frame), locals=diff_locals(frame)))
        elif event == "return":
            frames.append(TraceFrame(event=TraceEvent.RETURN, source=source(frame), arg=render_value(arg)))
            if not frame.f_code.co_flags & resumable:
                # The call is over: release the frame so memory stays bounded.
                scope_prev.pop(frame, None)
        elif event == "exception":
            # arg is (type, value, traceback); record the exception type name.
            name = getattr(arg[0], "__name__", str(arg[0]))
            frames.append(TraceFrame(event=TraceEvent.EXCEPTION, source=source(frame), arg=render_value(name)))
        return trace

    ns: dict[str, Any] = {}
    exec(compile(context, _FILENAME, "exec"), ns)  # define f, main untraced
    error = None
    old = sys.gettrace()
    sys.settrace(trace)
    try:
        ns[_ENTRY]()
    except _FramesExceeded:
        error = "frames_exceeded"
    except Exception as e:
        error = f"{type(e).__name__}: {e}"
    finally:
        sys.settrace(old)

    # Drop the seeded ``call main()`` frame so frames align with the prompt.
    if align_to_prompt and frames and frames[0].event == TraceEvent.CALL and frames[0].source.startswith("def main()"):
        frames = frames[1:]
    return frames, error
code/data/sources.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset name(s) -> merged {id, code, input, output} rows.
2
+
3
+ Add a converted dataset by running its folder's convert.py (saves ./data via
4
+ save_to_disk) and listing it in _LOCAL. cruxeval keeps its own Hub-fallback
5
+ loader and is held out entirely for eval (eval_cruxeval_*.py), never trained on.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ from pathlib import Path
12
+
13
+ _LOCAL = {"mbpp": "MBPP", "humaneval": "HumanEval", "pyx": "PyX"} # name -> folder, data in ./data
14
+
15
+
16
def load_cruxeval():
    """Return the CRUXEval rows as a list.

    If the ``CRUXEVAL_DIR`` environment variable names an existing directory,
    the dataset saved there is loaded with ``load_from_disk``. Otherwise the
    ``test`` split of ``cruxeval-org/cruxeval`` is loaded with
    ``load_dataset``. The original note says a local copy is preferred because
    the HF builder FileLock fails on NFS caches.
    """
    local_dir = os.environ.get("CRUXEVAL_DIR")
    has_local_copy = bool(local_dir) and os.path.isdir(local_dir)
    if not has_local_copy:
        from datasets import load_dataset

        return list(load_dataset("cruxeval-org/cruxeval", split="test"))

    from datasets import load_from_disk

    return list(load_from_disk(local_dir))
27
+
28
+
29
def load_one(name: str) -> list[dict]:
    """Load all rows of the data source ``name`` (case/whitespace-insensitive).

    ``cruxeval`` is delegated to ``load_cruxeval``. A name listed in
    ``_LOCAL`` is loaded with ``load_from_disk`` from the directory in the
    ``<NAME>_DIR`` environment variable when set (and non-empty), else from
    ``<this file's folder>/<_LOCAL[name]>/data``. Any other name raises
    ValueError.
    """
    key = name.strip().lower()
    if key == "cruxeval":
        return load_cruxeval()
    if key not in _LOCAL:
        raise ValueError(f"unknown data source {name!r}; pick from {['cruxeval', *_LOCAL]}")

    from datasets import load_from_disk

    override = os.environ.get(key.upper() + "_DIR")
    data_dir = override or str(Path(__file__).parent / _LOCAL[key] / "data")
    return list(load_from_disk(data_dir))
code/data/trace_format.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ """
4
+ Shared CWM execution-trace representation and parsing.
5
+
6
+ CWM predicts an execution trace as a sequence of *frames*, each consisting of an
7
+ *observation* (the local-variable state) and an *action* (the executed source
8
+ line). The on-the-wire format (see PROMPTING_GUIDE.md and demos/cwmdbg.py) is:
9
+
10
+ <|call_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
11
+ <|line_sep|>$LOCALS<|action_sep|>$SOURCE<|frame_sep|>
12
+ <|return_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
13
+ <|exception_sep|><|action_sep|>$SOURCE<|arg_sep|>$VALUE<|frame_sep|>
14
+
15
+ `$LOCALS` is a JSON object mapping variable names to *string* values; each value
16
+ is the JSON encoding of the underlying Python value (e.g. `"5"`, `"\"abc\""`,
17
+ `"[1, 2]"`). Locals use a diff-based representation: a variable whose value is
18
+ unchanged since the previous frame in the same scope is rendered as the
19
+ placeholder string `".."`. `$VALUE` (return/exception frames) is the JSON
20
+ encoding of the returned/raised value, stored as a JSON string.
21
+
22
+ This module is GPU-free and import-light so it can be unit-tested directly.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import json
28
+ from dataclasses import dataclass, field
29
+ from enum import Enum
30
+
31
+ # Literal piece strings as they appear when a generation is decoded with
32
+ # cut_at_stop_tokens=False (matches CWMInstructTokenizer.*_ID constants).
33
+ CALL_SEP = "<|call_sep|>"
34
+ LINE_SEP = "<|line_sep|>"
35
+ RETURN_SEP = "<|return_sep|>"
36
+ EXCEPTION_SEP = "<|exception_sep|>"
37
+ ACTION_SEP = "<|action_sep|>"
38
+ ARG_SEP = "<|arg_sep|>"
39
+ FRAME_SEP = "<|frame_sep|>"
40
+ END_OF_TEXT = "<|end_of_text|>"
41
+
42
+ DIFF_PLACEHOLDER = ".."
43
+ _START_MARKER = " # << START_OF_TRACE"
44
+
45
+
46
class TraceEvent(Enum):
    """The four kinds of frame that can appear in an execution trace."""

    CALL = "call"
    LINE = "line"
    RETURN = "return"
    EXCEPTION = "exception"
51
+
52
+
53
+ _EVENT_TOKENS: dict[str, TraceEvent] = {
54
+ CALL_SEP: TraceEvent.CALL,
55
+ LINE_SEP: TraceEvent.LINE,
56
+ RETURN_SEP: TraceEvent.RETURN,
57
+ EXCEPTION_SEP: TraceEvent.EXCEPTION,
58
+ }
59
+ _EVENT_TO_TOKEN: dict[TraceEvent, str] = {v: k for k, v in _EVENT_TOKENS.items()}
60
+
61
+
62
@dataclass
class TraceFrame:
    """One frame of an execution trace.

    ``event`` is the frame kind and ``source`` the action (source line).
    ``locals_str`` is the raw ``$LOCALS`` text and ``locals`` its parsed form
    (name -> JSON-string value), or None when it is absent or failed to
    parse. ``arg`` is the value of a return/exception frame. ``malformed``
    marks frames that could not be parsed cleanly.
    """

    event: TraceEvent
    source: str
    locals_str: str = ""
    locals: dict[str, str] | None = None
    arg: str | None = None
    malformed: bool = False
    # Token counts; both default to 0 and are not set in this class.
    state_tokens: int = 0
    action_tokens: int = 0

    @property
    def has_locals(self) -> bool:
        """True for CALL and LINE frames, the two kinds that carry locals."""
        carries_state = (TraceEvent.CALL, TraceEvent.LINE)
        return self.event in carries_state
87
+
88
+
89
def normalize_source(source: str) -> str:
    """Strip trailing whitespace and a trailing trace start marker from a source line.

    The marker is removed only when it appears as a whole suffix.
    ``str.rstrip(marker)`` must not be used for this: it treats its argument
    as a set of characters and would eat the end of ordinary lines such as
    ``return RESULT``.
    """
    line = source.rstrip()
    # Compare against the marker without its surrounding whitespace so the
    # amount of space between the code and the marker does not matter.
    marker = _START_MARKER.strip()
    if marker and line.endswith(marker):
        line = line[: -len(marker)].rstrip()
    return line
92
+
93
+
94
def parse_locals(locals_str: str) -> dict[str, str] | None:
    """Decode a ``$LOCALS`` payload into a ``name -> value-string`` dict.

    Blank input yields ``{}``. Input that is not valid JSON, or whose top
    level is not a JSON object, yields None. Values that are not already
    strings are re-encoded with ``json.dumps``.
    """
    payload = locals_str.strip()
    if not payload:
        return {}
    try:
        decoded = json.loads(payload)
    except json.JSONDecodeError:
        return None
    if not isinstance(decoded, dict):
        return None
    result = {}
    for name, value in decoded.items():
        result[str(name)] = value if isinstance(value, str) else json.dumps(value)
    return result
107
+
108
+
109
def parse_generated_trace(generation: str) -> tuple[list[TraceFrame], bool]:
    """Split a generated trace into frames and report whether it is well formed.

    Text after the first ``END_OF_TEXT`` is ignored. The rest is split on
    ``FRAME_SEP`` and each non-blank segment is parsed by ``_parse_segment``.

    ``well_formed`` is False when non-whitespace text follows the last
    ``FRAME_SEP``, when a segment yields no frame or is reported as not ok,
    or when no frames were parsed at all. Frames that did parse are returned
    either way.
    """
    body = generation.split(END_OF_TEXT, 1)[0]

    # The piece after the last frame separator must be empty in a clean trace.
    *segments, trailing = body.split(FRAME_SEP)
    well_formed = trailing.strip() == ""

    frames: list[TraceFrame] = []
    for segment in segments:
        if not segment.strip():
            # Blank segments are ignored rather than counted as errors.
            continue
        frame, ok = _parse_segment(segment)
        if frame is None:
            well_formed = False
        else:
            well_formed = well_formed and ok
            frames.append(frame)

    if not frames:
        well_formed = False
    return frames, well_formed
147
+
148
+
149
def _parse_segment(seg: str) -> tuple[TraceFrame | None, bool]:
    """Parse one frame segment (the text between two frame separators).

    Returns ``(frame, ok)``. ``frame`` is None when the segment contains no
    event token. ``ok`` is False when ``ACTION_SEP`` is missing or, for
    RETURN/EXCEPTION frames, when ``ARG_SEP`` is missing. A CALL/LINE frame
    whose locals fail to parse is flagged ``malformed`` but still ``ok``.
    """
    # Pick the event token that occurs EARLIEST in the segment. Taking the
    # first hit in dict order would mis-type a segment whose text happens to
    # contain a second event token later on.
    event: TraceEvent | None = None
    start = -1
    token_len = 0
    for tok, evt in _EVENT_TOKENS.items():
        idx = seg.find(tok)
        if idx != -1 and (start == -1 or idx < start):
            event, start, token_len = evt, idx, len(tok)
    if event is None:
        return None, False
    seg = seg[start + token_len:]

    ok = True
    if event in (TraceEvent.CALL, TraceEvent.LINE):
        if ACTION_SEP not in seg:
            return (
                TraceFrame(event=event, source="", malformed=True),
                False,
            )
        locals_str, source = seg.split(ACTION_SEP, 1)
        parsed = parse_locals(locals_str)
        return (
            TraceFrame(
                event=event,
                source=normalize_source(source),
                locals_str=locals_str.strip(),
                locals=parsed,
                malformed=parsed is None,
            ),
            ok,
        )

    # RETURN / EXCEPTION: text before the action separator is discarded.
    if ACTION_SEP not in seg:
        return TraceFrame(event=event, source="", malformed=True), False
    seg = seg.split(ACTION_SEP, 1)[1]
    if ARG_SEP in seg:
        source, arg = seg.split(ARG_SEP, 1)
        arg = _parse_arg(arg)
    else:
        source, arg = seg, None
        ok = False
    return (
        TraceFrame(event=event, source=normalize_source(source), arg=arg),
        ok,
    )
195
+
196
+
197
def render_frames_to_generation(frames: list[TraceFrame]) -> str:
    """Serialize ``frames`` into the generation string format.

    Each frame becomes: its event token, the JSON-encoded locals (frames with
    ``has_locals`` only; ``{}`` when ``locals`` is None), ``ACTION_SEP``, the
    source, then ``ARG_SEP`` plus the JSON-encoded ``arg`` for
    RETURN/EXCEPTION frames, and finally ``FRAME_SEP``. ``END_OF_TEXT``
    terminates the string.
    """
    arg_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    pieces: list[str] = []
    for frame in frames:
        pieces.append(_EVENT_TO_TOKEN[frame.event])
        if frame.has_locals:
            state = {} if frame.locals is None else frame.locals
            pieces.append(json.dumps(state))
        pieces.extend((ACTION_SEP, frame.source))
        if frame.event in arg_events:
            pieces.extend((ARG_SEP, json.dumps(frame.arg)))
        pieces.append(FRAME_SEP)
    pieces.append(END_OF_TEXT)
    return "".join(pieces)
217
+
218
+
219
def _parse_arg(arg_str: str) -> str | None:
    """Decode the ``$VALUE`` text of a return/exception frame.

    Blank text yields None. Text that is a JSON string is unwrapped one level
    (``'"17"'`` -> ``'17'``). Anything else — invalid JSON or a JSON value
    that is not a string — is returned as the stripped text unchanged.
    """
    text = arg_str.strip()
    if not text:
        return None
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return text
    if isinstance(decoded, str):
        return decoded
    return text
code/eval/__init__.py ADDED
File without changes
code/eval/eval_cruxeval_codi.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CRUXEval-O latent eval: the CODI student generates the trace, but at every
2
+ <|line_sep|> the frame's $LOCALS is replaced by a latent block (latent_start +
3
+ latent_steps recurrent latents + latent_end), mirroring training _student.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ from datetime import timedelta
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+ from data.dataset import _prompt_str
16
+ from data.sources import load_cruxeval
17
+ from eval.eval_cruxeval_sft import check_correct, extract_answer_trace_full
18
+ from tokens import add_trace_tokens, token_ids
19
+ from train.train_codi import CodiModel
20
+
21
+
22
def load_codi(m, latent_steps, dev):
    """Load tokenizer, trace-token ids and a ``CodiModel`` from checkpoint dir ``m``.

    Two checkpoint layouts are handled: if ``<m>/pytorch_model.bin`` exists it
    is loaded as the state dict of the whole ``CodiModel``; otherwise the
    backbone is loaded with ``from_pretrained(m)`` and the projector state
    from ``<m>/thought_projector.pt``.

    Returns ``(tokenizer, ids, model)`` with the model moved to ``dev`` and
    put in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(m, use_fast=True)
    add_trace_tokens(tokenizer)
    ids = token_ids(tokenizer)

    # Build the backbone from its config only; weights are loaded below.
    config = AutoConfig.from_pretrained(m)
    backbone = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
    model = CodiModel(
        backbone,
        latent_start_id=ids["<|latent_start|>"],
        latent_end_id=ids["<|latent_end|>"],
        latent_steps=latent_steps,
    )

    if os.path.exists(f"{m}/pytorch_model.bin"):
        full_state = torch.load(f"{m}/pytorch_model.bin", map_location="cpu")
        model.load_state_dict(full_state)
    else:
        model.model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16)
        projector_state = torch.load(f"{m}/thought_projector.pt", map_location="cpu")
        model.prj.load_state_dict(projector_state)

    model = model.to(dev).eval()
    return tokenizer, ids, model
35
+
36
+
37
@torch.no_grad()
def gen_latent(model, prompt_ids, ls_id, eot, max_new):
    """Greedily decode up to ``max_new`` token ids after ``prompt_ids``.

    Decoding stops before appending ``eot``. Each chosen token is fed back
    through ``model.model`` with the KV cache. After a ``ls_id`` token, the
    next logits (and cache) come from ``model._latent_block(cache)`` instead
    of the backbone output — per the original comment, this replaces the
    frame's ``$LOCALS`` with a latent block whose logits predict
    ``<|action_sep|>``.

    Returns the list of generated token ids (``eot`` excluded).
    """
    device = prompt_ids.device
    step = model.model(input_ids=prompt_ids[None], use_cache=True)
    cache = step.past_key_values
    logits = step.logits[:, -1]

    generated = []
    for _ in range(max_new):
        token = int(logits.argmax(-1))
        if token == eot:
            break
        generated.append(token)
        next_input = torch.tensor([[token]], device=device)
        step = model.model(input_ids=next_input, past_key_values=cache, use_cache=True)
        cache = step.past_key_values
        if token != ls_id:
            logits = step.logits[:, -1]
        else:
            cache, logits = model._latent_block(cache)
    return generated
55
+
56
+
57
def main():
    """Run the latent CRUXEval-O evaluation from the command line.

    Rows are sharded round-robin across ranks when launched with RANK set in
    the environment. Each rank generates greedily with ``gen_latent``,
    extracts the predicted output and scores it with ``check_correct``.
    Rank 0 prints pass@1 and the valid-format rate and, with ``--out``,
    writes them plus the per-row results as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--n_samples", type=int, default=-1)
    parser.add_argument("--max_new_tokens", type=int, default=8192)
    parser.add_argument("--latent_steps", type=int, default=1)
    parser.add_argument("--out", default="")
    args = parser.parse_args()

    env = os.environ
    ddp = "RANK" in env
    rank = int(env.get("RANK", 0))
    world = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    if ddp:
        # Long timeout: ranks finish at different times under long generations.
        dist.init_process_group("nccl", timeout=timedelta(hours=1))
        torch.cuda.set_device(local_rank)

    tok, ids, model = load_codi(args.model, args.latent_steps, local_rank)
    ls_id = ids["<|line_sep|>"]
    eot = ids["<|end_of_text|>"]

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]

    n_correct = 0
    n_fmt = 0
    results = []
    for done, row in enumerate(shard, start=1):
        prompt = _prompt_str(row["code"], row["input"])
        enc = tok(prompt, return_tensors="pt", add_special_tokens=False).to(local_rank)
        new_ids = gen_latent(model, enc["input_ids"][0], ls_id, eot, args.max_new_tokens)
        gen = tok.decode(new_ids, skip_special_tokens=False)
        pred = extract_answer_trace_full(gen)
        ok = pred is not None and check_correct(row["code"], row["output"], pred)
        n_fmt += pred is not None
        n_correct += ok
        results.append({"id": row["id"], "expected": row["output"], "predicted": pred, "correct": ok, "generation": gen})
        if rank == 0 and done % 20 == 0:
            print(f" rank0 {done}/{len(shard)} pass@1={n_correct/done:.4f}", flush=True)

    if ddp:
        # Sum the counters over ranks and collect per-row results on rank 0.
        totals = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(totals)
        n_correct, n_fmt = int(totals[0]), int(totals[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            results = [item for part in gathered for item in part]

    if rank == 0:
        print(f"\nCRUXEval-O latent pass@1={n_correct / n:.4f} "
              f"valid_format={n_fmt / n:.4f} (n={n}, greedy)")
        if args.out:
            report = {"pass_at_1": n_correct / n, "valid_format": n_fmt / n,
                      "n": n, "results": results}
            with open(args.out, "w") as f:
                json.dump(report, f, indent=2)
    if ddp:
        dist.destroy_process_group()
116
+
117
+
118
+ if __name__ == "__main__":
119
+ main()
code/eval/eval_cruxeval_sft.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stage 1 baseline eval: CRUXEval-O output prediction via full-trace generation.
2
+
3
+ Feed the training prompt (seeds frame 0), let the SFT model generate the trace,
4
+ take main()'s last return value as the predicted output, score by execution.
5
+ Greedy => pass@1 is the exact-match fraction. Reuses cwm_andre eval logic.
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+ import os
11
+ import subprocess
12
+ import sys
13
+ from datetime import timedelta
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ from transformers import AutoModelForCausalLM, AutoTokenizer
18
+
19
+ from data.dataset import _prompt_str
20
+ from data.sources import load_cruxeval
21
+ from tokens import add_trace_tokens, token_ids
22
+
23
ARG_SEP, FRAME_SEP, RETURN_SEP = "<|arg_sep|>", "<|frame_sep|>", "<|return_sep|>"


def extract_answer_trace_full(gen: str) -> str | None:
    """Return the value of the last RETURN frame in ``gen``, or None.

    The value is the text between the first ``<|arg_sep|>`` after the last
    ``<|return_sep|>`` and the next ``<|frame_sep|>`` (or the end of the
    string). It is JSON-decoded when possible; otherwise the stripped text is
    returned. None is returned when either separator is missing or the value
    is blank.
    """
    last_return = gen.rfind(RETURN_SEP)
    if last_return == -1:
        return None
    _before, found_sep, after_arg = gen[last_return:].partition(ARG_SEP)
    if not found_sep:
        return None
    raw_value = after_arg.partition(FRAME_SEP)[0].strip()
    if not raw_value:
        return None
    try:
        decoded = json.loads(raw_value)
    except json.JSONDecodeError:
        return raw_value
    return decoded
43
+
44
+
45
def check_correct(code: str, expected: str, predicted: str, timeout: float = 3.0) -> bool:
    """Run ``code`` followed by ``assert expected == predicted`` in a subprocess.

    Returns True only when the child interpreter exits with status 0. Any
    exception while running it (including the timeout expiring) yields False.
    """
    program = f"{code}\nassert {expected} == {predicted}"
    command = [sys.executable, "-c", program]
    try:
        completed = subprocess.run(command, timeout=timeout, capture_output=True)
    except Exception:
        return False
    return completed.returncode == 0
54
+
55
+
56
def main():
    """Run the full-trace CRUXEval-O evaluation from the command line.

    Rows are sharded round-robin across ranks when launched with RANK set in
    the environment. Each rank generates greedily in left-padded batches,
    extracts the predicted output from each generation and scores it with
    ``check_correct``. Rank 0 prints pass@1 and the valid-format rate and,
    with ``--out``, writes them plus the per-row results as JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--n_samples", type=int, default=-1)
    parser.add_argument("--max_new_tokens", type=int, default=8192)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--out", default="")
    args = parser.parse_args()

    # Data-parallel inference: the launcher provides RANK/WORLD_SIZE/LOCAL_RANK.
    env = os.environ
    ddp = "RANK" in env
    rank = int(env.get("RANK", 0))
    world = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    if ddp:
        # Long timeout: ranks finish at different times under long generations.
        dist.init_process_group("nccl", timeout=timedelta(hours=1))
        torch.cuda.set_device(local_rank)

    tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
    add_trace_tokens(tok)
    # Left-pad so every row's generated tokens start at the same offset.
    tok.padding_side = "left"
    eot_id = token_ids(tok)["<|end_of_text|>"]
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    model = model.to(local_rank).eval()

    rows = load_cruxeval()
    if args.n_samples > 0:
        rows = rows[: args.n_samples]
    n = len(rows)
    shard = rows[rank::world]  # disjoint round-robin split across ranks

    n_correct = 0
    n_fmt = 0
    results = []
    batch_starts = range(0, len(shard), args.batch_size)
    for batch_index, batch_start in enumerate(batch_starts):
        batch = shard[batch_start: batch_start + args.batch_size]
        prompts = [_prompt_str(row["code"], row["input"]) for row in batch]
        enc = tok(prompts, return_tensors="pt", padding=True,
                  add_special_tokens=False).to(local_rank)
        with torch.no_grad():
            out = model.generate(**enc, max_new_tokens=args.max_new_tokens, do_sample=False,
                                 eos_token_id=eot_id, pad_token_id=eot_id)
        prompt_len = enc["input_ids"].shape[1]
        for j, row in enumerate(batch):
            gen = tok.decode(out[j, prompt_len:], skip_special_tokens=False)
            pred = extract_answer_trace_full(gen)
            ok = pred is not None and check_correct(row["code"], row["output"], pred)
            n_fmt += pred is not None
            n_correct += ok
            results.append({"id": row["id"], "expected": row["output"], "predicted": pred, "correct": ok, "generation": gen})
        if rank == 0 and (batch_index + 1) % 5 == 0:
            done = batch_start + len(batch)
            print(f" rank0 {done}/{len(shard)} pass@1={n_correct/done:.4f}", flush=True)

    # Sum the counters over ranks and collect per-row results on rank 0.
    if ddp:
        totals = torch.tensor([n_correct, n_fmt], device=local_rank)
        dist.all_reduce(totals)
        n_correct, n_fmt = int(totals[0]), int(totals[1])
        gathered = [None] * world
        dist.gather_object(results, gathered if rank == 0 else None, dst=0)
        if rank == 0:
            results = [item for part in gathered for item in part]

    if rank == 0:
        print(f"\nCRUXEval-O pass@1={n_correct / n:.4f} "
              f"valid_format={n_fmt / n:.4f} (n={n}, greedy)")
        if args.out:
            report = {"pass_at_1": n_correct / n, "valid_format": n_fmt / n,
                      "n": n, "results": results}
            with open(args.out, "w") as f:
                json.dump(report, f, indent=2)
    if ddp:
        dist.destroy_process_group()
126
+
127
+
128
+ if __name__ == "__main__":
129
+ main()
code/tokens.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CWM trace special tokens + tokenizer/embedding setup for a non-CWM base."""
2
+
3
# Trace-format tokens (mirrors data/trace_format.py) + latent delimiters.
# The list is handed to the tokenizer as special tokens by add_trace_tokens() and
# iterated by token_ids(), which requires each entry to encode to exactly one id.
TRACE_TOKENS = [
    "<|trace_context_start|>",
    "<|call_sep|>", "<|line_sep|>", "<|return_sep|>", "<|exception_sep|>",
    "<|action_sep|>", "<|arg_sep|>", "<|frame_sep|>", "<|end_of_text|>",
    "<|latent_start|>", "<|latent_end|>",
]
10
+
11
+
12
def add_trace_tokens(tokenizer) -> int:
    """Register every entry of ``TRACE_TOKENS`` on ``tokenizer`` as a special token.

    Returns:
        The value reported by ``tokenizer.add_tokens`` (the number of tokens newly added).
    """
    newly_added = tokenizer.add_tokens(TRACE_TOKENS, special_tokens=True)
    return newly_added
15
+
16
+
17
def resize_and_init(model, tokenizer, n_added: int) -> None:
    """Resize embeddings to the tokenizer; init new rows to the existing mean.

    Args:
        model: object exposing ``get_input_embeddings``, ``get_output_embeddings`` and
            ``resize_token_embeddings`` (the HuggingFace model interface).
        tokenizer: sized object; ``len(tokenizer)`` is the target number of embedding rows.
        n_added: number of tokens that were newly appended to the tokenizer (the return
            value of ``add_trace_tokens``). ``<= 0`` means "resize only".

    The rows of the ``n_added`` new tokens are overwritten in place with the mean of all
    rows that precede them, for the input embedding and (if present) the output embedding.
    """
    old = model.get_input_embeddings().weight.shape[0]
    model.resize_token_embeddings(len(tokenizer))
    if n_added <= 0:
        return
    # First row that belongs to a new token. Normally this equals `old`, but a checkpoint
    # whose matrix was already padded beyond the tokenizer size has old > len - n_added;
    # slicing from `old` would then select nothing and leave the new tokens uninitialised.
    start = min(old, len(tokenizer) - n_added)
    if start <= 0:  # no pre-existing rows to average over
        return
    seen = set()
    for emb in (model.get_input_embeddings(), model.get_output_embeddings()):
        if emb is None:
            continue
        # Tied embeddings are distinct modules sharing ONE weight Parameter, so the
        # dedup key has to be the weight, not the module.
        key = id(emb.weight)
        if key in seen:
            continue
        seen.add(key)
        w = emb.weight.data
        w[start:] = w[:start].mean(dim=0, keepdim=True)
30
+
31
+
32
def token_ids(tokenizer) -> dict[str, int]:
    """Return ``{token: id}`` for every entry of ``TRACE_TOKENS``.

    Each token is encoded without special-token wrapping and must come back as exactly
    one id; otherwise the assertion below fails with the offending token and encoding.
    """
    mapping: dict[str, int] = {}
    for t in TRACE_TOKENS:
        enc = tokenizer.encode(t, add_special_tokens=False)
        assert len(enc) == 1, f"{t!r} did not encode to a single id: {enc}"
        (mapping[t],) = enc
    return mapping
code/train/__init__.py ADDED
File without changes
code/train/train_codi.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stage 2b: per-frame CODI self-distillation (multi-span).
2
+
3
+ Shared-weight teacher+student initialized from the Stage-1 SFT model.
4
+ - Teacher reads the full explicit trace (prompt+trace), CE = L_teacher.
5
+ - Student replaces each LINE frame's $LOCALS with a latent block (latent_start +
6
+ `latent_steps` recurrent latents + latent_end; last hidden -> prj -> next embed)
7
+ and teacher-forces the rest, CE = L_student over the emitted (non-locals) text.
8
+ - KD aligns the hidden at each frame's `<|action_sep|>` (student after latents vs
9
+ teacher after locals), teacher detached. L = a*Lt + b*Ls + g*Lkd.
10
+ """
11
+
12
+ import argparse
13
+ import os
14
+ import random
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, Trainer, TrainingArguments
20
+ from transformers.trainer_utils import get_last_checkpoint
21
+ from transformers.utils import WEIGHTS_NAME
22
+
23
+ from data.dataset import IGNORE_INDEX, build_codi_dataset
24
+ from tokens import add_trace_tokens, token_ids
25
+ from wb import wandb_init
26
+
27
+
28
+ class CodiModel(nn.Module):
29
+ def __init__(self, base, *, latent_start_id, latent_end_id, latent_steps,
30
+ a=1.0, b=1.0, g=1.0, kd_layers=None, single_anchor=False,
31
+ ss_prob=0.0, ss_ramp_frac=0.5, teacher=None, kd_target="hidden", kd_temp=2.0,
32
+ line_sep_id=None, recon_w=0.0):
33
+ super().__init__()
34
+ self.model = base
35
+ h = base.config.hidden_size
36
+ # CODI thought projector (last hidden -> next latent input).
37
+ self.prj = nn.Sequential(
38
+ nn.Linear(h, h, bias=False), nn.GELU(),
39
+ nn.Linear(h, h, bias=False), nn.LayerNorm(h),
40
+ )
41
+ ref = base.get_input_embeddings().weight
42
+ self.prj.to(device=ref.device, dtype=ref.dtype)
43
+ self.latent_steps, self.a, self.b, self.g = latent_steps, a, b, g
44
+ self.teacher = [teacher] if teacher is not None else None # list -> hidden from state_dict/DDP/optim
45
+ self.kd_target, self.kd_temp = kd_target, kd_temp # hidden: smooth_l1 on kd_layers; logit: KL on lm_head
46
+ if kd_target == "logit" or (teacher is not None and kd_layers is None):
47
+ kd_layers = [-1] # logit KD is defined on the last layer only; frozen default = key (last) hidden
48
+ self.kd_layers = kd_layers # None -> all layers
49
+ self.single_anchor = single_anchor # KD at last span only (vanilla-CODI ablation)
50
+ # scheduled sampling: ss_p (ramped per step) of post-latent lines feed the student's own argmax
51
+ self.ss_prob, self.ss_ramp_frac, self.ss_p = ss_prob, ss_ramp_frac, 0.0
52
+ self.register_buffer("_ls_tok", torch.tensor([[latent_start_id]], dtype=torch.long), persistent=False)
53
+ self.register_buffer("_le_tok", torch.tensor([[latent_end_id]], dtype=torch.long), persistent=False)
54
+ self.body = base.model
55
+ self.head = base.lm_head
56
+
57
+ def _kd(self, hs):
58
+ return hs[1:] if self.kd_layers is None else tuple(hs[l] for l in self.kd_layers)
59
+
60
+ def _emb(self, ids):
61
+ return self.model.get_input_embeddings()(ids)
62
+
63
+ def _teacher(self, full_ids, labels, kd_pos):
64
+ pos = torch.tensor(kd_pos, device=full_ids.device)
65
+ if self.teacher is not None: # frozen teacher: KD targets only, no teacher CE
66
+ tch, dev = self.teacher[0], full_ids.device
67
+ if next(tch.parameters()).device != dev:
68
+ tch.to(dev)
69
+ with torch.no_grad():
70
+ if self.kd_target == "logit": # target = teacher's own next-token logits
71
+ return None, [tch(input_ids=full_ids[None], use_cache=False).logits[0, pos]]
72
+ hs = tch(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
73
+ return None, [l[0, pos] for l in self._kd(hs)]
74
+ with torch.no_grad(): # KD targets are detached; take hiddens without a backward graph
75
+ hs = self.model(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
76
+ kd = [l[0, pos] for l in self._kd(hs)]
77
+ # CE forward without output_hidden_states so grad-checkpointing actually frees layer acts.
78
+ self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
79
+ logits = self.model(input_ids=full_ids[None], use_cache=False).logits
80
+ self.model.gradient_checkpointing_disable() # teacher-only; student keeps KV cache
81
+ ce = F.cross_entropy(logits[0, :-1], labels[1:], ignore_index=IGNORE_INDEX)
82
+ return ce, kd
83
+
84
+ def _latent_block(self, cache):
85
+ """latent_start + `latent_steps` recurrent latents + latent_end on top of
86
+ `cache`. Returns (new cache, logits predicting the next real token)."""
87
+ o = self.body(inputs_embeds=self._emb(self._ls_tok), past_key_values=cache, use_cache=True)
88
+ cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
89
+ for _ in range(self.latent_steps):
90
+ o = self.body(inputs_embeds=self.prj(h), past_key_values=cache, use_cache=True)
91
+ cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
92
+ o = self.body(inputs_embeds=self._emb(self._le_tok), past_key_values=cache, use_cache=True)
93
+ return o.past_key_values, self.head(o.last_hidden_state[:, -1])
94
+
95
+ def _student(self, prompt_ids, trace_ids, spans):
96
+ # Segments cover trace_ids in order; locals (trace_ids[i+1:j]) are dropped
97
+ # and replaced by a latent block. kd=True marks a frame's <|action_sep|>.
98
+ segs, prev, kd = [], 0, False
99
+ for i, j in spans:
100
+ segs.append(("text", trace_ids[prev:i + 1], kd))
101
+ segs.append(("latent", None, False))
102
+ prev, kd = j, True
103
+ segs.append(("text", trace_ids[prev:], kd))
104
+ last = len(segs) - 1
105
+
106
+ out = self.model(inputs_embeds=self._emb(prompt_ids[None]), use_cache=True)
107
+ cache, prev_logits = out.past_key_values, out.logits[:, -1] # predicts trace_ids[0]
108
+ ce_logits, ce_targets, kd_vecs = [], [], []
109
+ for s, (kind, ids, kd) in enumerate(segs):
110
+ if kind == "latent": # prev_logits predicted dropped locals; overwrite, no CE
111
+ cache, prev_logits = self._latent_block(cache)
112
+ continue
113
+ inp = ids
114
+ if kd and 0 < self.ss_p and random.random() < self.ss_p:
115
+ # scheduled sampling: replace the code (not action_sep / line_sep) with the student's own
116
+ # argmax via a no-grad pass on a detached cache clone; CE targets below stay GT.
117
+ end = ids.numel() if s == last else ids.numel() - 1
118
+ c = DynamicCache()
119
+ for i, ly in enumerate(cache.layers):
120
+ c.update(ly.keys.detach(), ly.values.detach(), i)
121
+ with torch.no_grad():
122
+ pred = self.model(inputs_embeds=self._emb(ids[None]), past_key_values=c, use_cache=True).logits[0].argmax(-1)
123
+ inp = ids.clone(); inp[1:end] = pred[:end - 1]
124
+ ce_logits.append(prev_logits); ce_targets.append(ids[:1])
125
+ out = self.model(inputs_embeds=self._emb(inp[None]), past_key_values=cache,
126
+ use_cache=True, output_hidden_states=kd) # hiddens only for KD anchors
127
+ cache, logits = out.past_key_values, out.logits[0]
128
+ if ids.numel() > 1:
129
+ ce_logits.append(logits[:-1]); ce_targets.append(ids[1:])
130
+ prev_logits = logits[-1:]
131
+ if kd: # action_sep is this segment's first token
132
+ kd_vecs.append([hs[0, 0] for hs in self._kd(out.hidden_states)])
133
+ ce = F.cross_entropy(torch.cat(ce_logits), torch.cat(ce_targets))
134
+ s_kd = [torch.stack([v[l] for v in kd_vecs]) for l in range(len(kd_vecs[0]))]
135
+ return ce, s_kd
136
+
137
+ def _kd_loss(self, s_kd, t_kd):
138
+ s, t = torch.stack(s_kd), torch.stack(t_kd).detach()
139
+ if self.kd_target == "logit": # s=student hidden, t=frozen-teacher logits; KL on distributions
140
+ T = self.kd_temp
141
+ sl, tl = self.head(s).flatten(0, -2) / T, t.flatten(0, -2) / T
142
+ return F.kl_div(F.log_softmax(sl, -1), F.softmax(tl, -1), reduction="batchmean") * T * T
143
+ return F.smooth_l1_loss(s, t)
144
+
145
+ def forward(self, examples):
146
+ dev = self.model.get_input_embeddings().weight.device
147
+ tl = sl = kl = 0.0
148
+ for ex in examples:
149
+ prompt = torch.tensor(ex["prompt_ids"], device=dev)
150
+ trace = torch.tensor(ex["trace_ids"], device=dev)
151
+ spans = ex["spans"]
152
+ full = torch.cat([prompt, trace])
153
+ labels = None if self.teacher else torch.cat([full.new_full((len(prompt),), IGNORE_INDEX), trace])
154
+ kd_pos = [len(prompt) + j for _, j in spans]
155
+ t_ce, t_kd = self._teacher(full, labels, kd_pos)
156
+ s_ce, s_kd = self._student(prompt, trace, spans)
157
+ if self.single_anchor: # keep only the last frame's anchor (per layer)
158
+ t_kd, s_kd = [t[-1:] for t in t_kd], [s[-1:] for s in s_kd]
159
+ tl = tl + (t_ce if t_ce is not None else 0.0) # frozen teacher -> no teacher CE
160
+ sl, kl = sl + s_ce, kl + self._kd_loss(s_kd, t_kd)
161
+ n = len(examples)
162
+ loss = self.a * tl / n + self.b * sl / n + self.g * kl / n
163
+ t_log = (tl / n).detach() if torch.is_tensor(tl) else torch.tensor(0.0) # 0 under frozen teacher
164
+ return {"loss": loss, "teacher_loss": t_log,
165
+ "student_loss": (sl / n).detach(), "kd_loss": (kl / n).detach()}
166
+
167
+
168
+ class CodiTrainer(Trainer):
169
+ def compute_loss(self, model, inputs, return_outputs=False, **kw):
170
+ core = model.module if hasattr(model, "module") else model
171
+ if core.ss_prob: # linear ramp 0 -> ss_prob over the first ss_ramp_frac of training
172
+ core.ss_p = self._ss = core.ss_prob * min(1.0, self.state.global_step / max(1.0, core.ss_ramp_frac * self.state.max_steps))
173
+ out = model(inputs["examples"])
174
+ self._sub = {k: out[k].detach() for k in ("teacher_loss", "student_loss", "kd_loss")}
175
+ return (out["loss"], out) if return_outputs else out["loss"]
176
+
177
+ def log(self, logs, *a, **k): # surface sub-losses to console + wandb
178
+ if hasattr(self, "_sub"):
179
+ logs.update({k: v.item() for k, v in self._sub.items()})
180
+ if hasattr(self, "_ss"):
181
+ logs["ss_p"] = self._ss
182
+ super().log(logs, *a, **k)
183
+
184
+ def _save(self, output_dir=None, state_dict=None):
185
+ # tied backbone weights -> safetensors (5.x default) rejects shared tensors; torch.save instead.
186
+ output_dir = output_dir or self.args.output_dir
187
+ os.makedirs(output_dir, exist_ok=True)
188
+ torch.save(state_dict or self.model.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
189
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
190
+ # also write config/tokenizer/projector so each ckpt is eval-loadable (small, no weight dup).
191
+ self.model.model.config.save_pretrained(output_dir)
192
+ self.tok.save_pretrained(output_dir)
193
+ torch.save(self.model.prj.state_dict(), os.path.join(output_dir, "thought_projector.pt"))
194
+
195
+
196
+ def main():
197
+ ap = argparse.ArgumentParser()
198
+ ap.add_argument("--model", required=True) # Stage-1 SFT dir
199
+ ap.add_argument("--output_dir", required=True)
200
+ ap.add_argument("--sources", nargs="+", default=["mbpp", "humaneval", "pyx"])
201
+ ap.add_argument("--cache_dir", default="data/cache/codi_train") # offline tokenized examples from precompute.py
202
+ ap.add_argument("--n_samples", type=int, default=-1)
203
+ ap.add_argument("--max_seq_len", type=int, default=4096)
204
+ ap.add_argument("--max_frames", type=int, default=-1)
205
+ ap.add_argument("--latent_steps", type=int, default=1)
206
+ ap.add_argument("--epochs", type=float, default=10.0)
207
+ ap.add_argument("--lr", type=float, default=1e-5)
208
+ ap.add_argument("--batch_size", type=int, default=1)
209
+ ap.add_argument("--grad_accum", type=int, default=4)
210
+ ap.add_argument("--max_steps", type=int, default=-1)
211
+ ap.add_argument("--save_steps", type=int, default=500)
212
+ ap.add_argument("--alpha", type=float, default=1.0)
213
+ ap.add_argument("--beta", type=float, default=1.0)
214
+ ap.add_argument("--gamma", type=float, default=1.0)
215
+ ap.add_argument("--kd_layers", nargs="+", type=int, default=None) # default: all layers (frozen -> last)
216
+ ap.add_argument("--frozen_teacher", default="") # path to frozen SFT teacher; "" -> shared-weight (legacy)
217
+ ap.add_argument("--kd_target", default="hidden", choices=["hidden", "logit"]) # key-hidden align: smooth_l1 vs KL
218
+ ap.add_argument("--kd_temp", type=float, default=2.0) # logit-KD temperature
219
+ ap.add_argument("--single_anchor", action="store_true") # KD at last frame only (vanilla CODI)
220
+ ap.add_argument("--ss_prob", type=float, default=0.0) # scheduled-sampling max prob (0 = off)
221
+ ap.add_argument("--ss_ramp_frac", type=float, default=0.5) # ramp ss_prob over this frac of steps
222
+ args = ap.parse_args()
223
+
224
+ tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
225
+ add_trace_tokens(tok) # idempotent
226
+ ids = token_ids(tok)
227
+ base = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16)
228
+ base.config.use_cache = True
229
+ teacher = None
230
+ if args.frozen_teacher:
231
+ teacher = AutoModelForCausalLM.from_pretrained(args.frozen_teacher, torch_dtype=torch.bfloat16)
232
+ teacher.config.use_cache = False
233
+ teacher.eval().requires_grad_(False)
234
+ model = CodiModel(base, latent_start_id=ids["<|latent_start|>"], latent_end_id=ids["<|latent_end|>"],
235
+ latent_steps=args.latent_steps, a=args.alpha, b=args.beta, g=args.gamma,
236
+ kd_layers=args.kd_layers, single_anchor=args.single_anchor,
237
+ ss_prob=args.ss_prob, ss_ramp_frac=args.ss_ramp_frac,
238
+ teacher=teacher, kd_target=args.kd_target, kd_temp=args.kd_temp)
239
+
240
+ ds = build_codi_dataset(tok, sources=args.sources, cache_dir=args.cache_dir,
241
+ n_samples=args.n_samples, max_seq_len=args.max_seq_len, max_frames=args.max_frames)
242
+ print(f"{len(ds)} codi examples, latent_steps={args.latent_steps}")
243
+
244
+ report_to = wandb_init(args, "codi")
245
+
246
+ targs = TrainingArguments(
247
+ output_dir=args.output_dir,
248
+ per_device_train_batch_size=args.batch_size,
249
+ gradient_accumulation_steps=args.grad_accum,
250
+ num_train_epochs=args.epochs,
251
+ max_steps=args.max_steps,
252
+ learning_rate=args.lr,
253
+ lr_scheduler_type="cosine",
254
+ warmup_ratio=0.03,
255
+ weight_decay=0.1,
256
+ max_grad_norm=1.0,
257
+ bf16=True,
258
+ optim="paged_adamw_8bit",
259
+ ddp_find_unused_parameters=False,
260
+ logging_steps=5,
261
+ save_strategy="steps",
262
+ save_steps=args.save_steps,
263
+ save_total_limit=None,
264
+ report_to=report_to,
265
+ remove_unused_columns=False,
266
+ label_names=[],
267
+ )
268
+ trainer = CodiTrainer(
269
+ model=model, args=targs, train_dataset=ds,
270
+ data_collator=lambda b: {"examples": b},
271
+ )
272
+ trainer.tok = tok
273
+ # Native checkpoints (CodiModel wrapper + optimizer) auto-resume if interrupted.
274
+ ckpt = get_last_checkpoint(args.output_dir) if os.path.isdir(args.output_dir) else None
275
+ trainer.train(resume_from_checkpoint=ckpt)
276
+ trainer._save_checkpoint(trainer.model, trial=None) # final step as a resumable, eval-loadable checkpoint-<step>
277
+
278
+
279
+ if __name__ == "__main__":
280
+ main()
code/wb.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """wandb: default-on, offline (compute nodes have no internet -> `wandb sync` later),
2
+ never blocks training. Returns report_to for TrainingArguments."""
3
+
4
+ import os
5
+
6
+
7
def wandb_init(args, stage):
    """Start a wandb run (offline unless ``WANDB_MODE`` is already set) and return ``report_to``.

    Args:
        args: argparse-style namespace; ``args.output_dir`` is used for the run name and
            as wandb's ``dir``, and ``vars(args)`` is recorded as the run config.
        stage: short label placed in front of the run name (``"<stage>-<dir name>"``).

    Returns:
        ``["wandb"]`` if the run was started; ``[]`` on a non-zero ``RANK`` or when
        importing/initialising wandb fails (the failure is printed, never raised).
    """
    if int(os.environ.get("RANK", "0")) != 0:  # rank0 only under DDP
        return []
    try:
        import wandb
        os.environ.setdefault("WANDB_MODE", "offline")
        # normpath strips a trailing separator first; basename("runs/exp/") alone is "",
        # which would name the run just "<stage>-".
        run_dir = os.path.basename(os.path.normpath(args.output_dir))
        wandb.init(project="codi_trace", name=f"{stage}-{run_dir}",
                   dir=args.output_dir, config=vars(args))
        return ["wandb"]
    except Exception as e:  # deliberate best-effort: logging must never block training
        print(f"wandb disabled: {e}")
        return []