File size: 7,268 Bytes
aedd6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Copyright (c) Meta Platforms, Inc. and affiliates.

"""
CRUXEval-O dataset: deterministic train/val split + ground-truth execution
traces -> (input_ids, labels) for teacher-forcing CODI.

Neutral data layer shared by training (``cwm.training.data``) and eval
(``evals.cruxeval.run_eval_codi``); depends on nothing in either, so the
split and trace format never drift. Thin HuggingFace-tokenizer wrapper over
the verbatim Table 9 trace generator (``.ground_truth`` / ``.trace_format``):
build the seeded prompt, tokenize ``prompt + render_frames_to_generation(frames)``,
and mask the prompt out of the labels (teacher-forced, so labels == input_ids
with the prompt prefix set to ``-100``).
"""

from __future__ import annotations

from .ground_truth import ground_truth_trace, make_trace_context
from .trace_format import (
    ACTION_SEP,
    LINE_SEP,
    TraceEvent,
    render_frames_to_generation,
)

IGNORE_INDEX = -100
def _prompt_str(code: str, input_str: str) -> str:
    """Return the seeded prompt text that precedes the generated trace.

    The prompt is the trace context built by ``make_trace_context`` wrapped
    in ``<|trace_context_start|>`` ... ``<|frame_sep|>``, followed by an
    initial ``<|call_sep|>`` frame carrying a literal ``{}`` payload and the
    source line ``def main():``.
    """
    context = make_trace_context(code, input_str)
    context_part = f"<|trace_context_start|>{context}<|frame_sep|>"
    # First frame of the trace is fixed, so it is part of the prompt.
    call_frame = "<|call_sep|>{}<|action_sep|>def main():\n<|frame_sep|>"
    return context_part + call_frame


def _tokenize_trace(code, input_str, tokenizer, *, max_seq_len, max_frames):
    """Trace and tokenize one program; ``(prompt_ids, trace_ids, spans)`` or None.

    Returns None (caller should skip the example) when any of these holds:

    * tracing produced no frames, or reported ``"frames_exceeded"``;
    * the last frame's event is neither RETURN nor EXCEPTION;
    * ``len(prompt_ids) + len(trace_ids)`` exceeds ``max_seq_len``;
    * the tokenized trace contains no complete LINE span.

    Each span ``(i, j)`` indexes into ``trace_ids``: ``trace_ids[i]`` is the
    ``LINE_SEP`` token, ``trace_ids[j]`` is the first ``ACTION_SEP`` token after
    it, and ``trace_ids[i+1:j]`` is the content in between (the part a CODI
    student replaces with a latent block). Both the SFT and CODI builders go
    through this function, so they accept and reject the same examples.
    """
    frames, error = ground_truth_trace(code, input_str, align_to_prompt=True, max_frames=max_frames)
    if not frames or error == "frames_exceeded":
        return None
    terminal_events = (TraceEvent.RETURN, TraceEvent.EXCEPTION)
    if frames[-1].event not in terminal_events:
        return None

    # Qwen has no BOS (bos_token_id is None); CWM did. Prepend only if present.
    bos_prefix = [] if tokenizer.bos_token_id is None else [tokenizer.bos_token_id]
    prompt_ids = bos_prefix + tokenizer.encode(_prompt_str(code, input_str), add_special_tokens=False)
    trace_ids = tokenizer.encode(render_frames_to_generation(frames), add_special_tokens=False)
    if len(prompt_ids) + len(trace_ids) > max_seq_len:
        return None

    line_sep_id = tokenizer.convert_tokens_to_ids(LINE_SEP)
    action_sep_id = tokenizer.convert_tokens_to_ids(ACTION_SEP)

    # Single pass: ``span_start`` is the index of a LINE_SEP still waiting for
    # its ACTION_SEP. While a span is open every token except ACTION_SEP is
    # skipped (including further LINE_SEPs); a span left open at the end of the
    # trace is dropped.
    spans = []
    span_start = None
    for pos, token in enumerate(trace_ids):
        if span_start is None:
            if token == line_sep_id:
                span_start = pos
        elif token == action_sep_id:
            spans.append((span_start, pos))
            span_start = None

    if not spans:
        return None
    return prompt_ids, trace_ids, spans


def build_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one teacher-forced SFT pair ``(input_ids, labels)``; None to skip.

    ``input_ids`` is the prompt followed by the trace. ``labels`` has the same
    length, with every prompt position set to ``IGNORE_INDEX`` and the trace
    positions equal to the trace token ids. Returns None whenever
    ``_tokenize_trace`` rejects the example.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    prompt_ids, trace_ids, _spans = tokenized
    input_ids = prompt_ids + trace_ids
    prompt_mask = [IGNORE_INDEX] * len(prompt_ids)
    labels = prompt_mask + trace_ids
    return input_ids, labels


def build_codi_example(code, input_str, tokenizer, *, max_seq_len, max_frames=-1):
    """Build one multi-span CODI example; None to skip.

    Returns a dict with keys ``prompt_ids``, ``trace_ids`` and ``spans``, holding
    the three values produced by ``_tokenize_trace`` unchanged, or None when
    ``_tokenize_trace`` rejects the example.
    """
    tokenized = _tokenize_trace(
        code, input_str, tokenizer, max_seq_len=max_seq_len, max_frames=max_frames
    )
    if tokenized is None:
        return None
    # _tokenize_trace returns exactly (prompt_ids, trace_ids, spans), in this order.
    return dict(zip(("prompt_ids", "trace_ids", "spans"), tokenized))


def _load_cache(cache_dir, n_samples):
    """Load precomputed tokenized examples (precompute.py) from ``cache_dir``.

    Args:
        cache_dir: Directory passed to ``datasets.load_from_disk``.
        n_samples: If positive, return only the first ``n_samples`` examples;
            otherwise return every example.

    Returns:
        A list of the loaded examples, in stored order.
    """
    from itertools import islice

    from datasets import load_from_disk

    dataset = load_from_disk(cache_dir)
    if n_samples > 0:
        # Stop iterating after the requested prefix instead of materializing
        # the whole cache and slicing it afterwards.
        return list(islice(dataset, n_samples))
    return list(dataset)


def build_codi_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Multi-span CODI examples over ``sources``, or from a precomputed cache.

    Args:
        tokenizer: Tokenizer forwarded to ``build_codi_example``.
        sources: Source names forwarded to ``rows_for_sources``.
        n_samples: If positive, use only the first ``n_samples`` rows (or
            cached examples); otherwise use all of them.
        max_seq_len: Maximum allowed ``len(prompt_ids) + len(trace_ids)``.
        max_frames: Frame budget forwarded to ``build_codi_example``.
        cache_dir: If set, load examples from this cache instead of tracing;
            cached examples longer than ``max_seq_len`` are filtered out.

    Returns:
        A list of dicts with keys ``prompt_ids``, ``trace_ids`` and ``spans``.
        Rows that ``build_codi_example`` rejects (returns None for) or that
        raise while being built are skipped; raised failures are logged.
    """
    import logging

    if cache_dir:
        cached = _load_cache(cache_dir, n_samples)
        return [e for e in cached if len(e["prompt_ids"]) + len(e["trace_ids"]) <= max_seq_len]

    logger = logging.getLogger(__name__)
    rows = rows_for_sources(sources)
    if n_samples > 0:
        rows = rows[:n_samples]

    examples = []
    failed = 0
    for index, row in enumerate(rows):
        try:
            example = build_codi_example(row["code"], row["input"], tokenizer,
                                         max_seq_len=max_seq_len, max_frames=max_frames)
        except Exception:
            # Best-effort: one bad row must not abort the whole build, but the
            # failure should be visible rather than silently discarded.
            failed += 1
            logger.debug("Skipping row %d: building the CODI example raised", index, exc_info=True)
            continue
        if example is not None:
            examples.append(example)
    if failed:
        logger.warning(
            "build_codi_dataset skipped %d of %d rows because example construction raised",
            failed, len(rows),
        )
    return examples


def build_codi_single_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 4096, max_frames: int = -1, cache_dir: str | None = None
) -> list[dict]:
    """Single-block CODI examples derived from the multi-span ones.

    Every example from ``build_codi_dataset`` (same arguments, so the same
    cache can be reused) has its ``trace_ids`` split at the last
    ``<|return_sep|>`` token: ``reasoning_ids`` is everything before that
    token and ``answer_ids`` is that token through the end of the trace.
    Examples with no ``<|return_sep|>``, or whose only one sits at index 0
    (which would leave the reasoning empty), are dropped.

    Returns a list of ``{prompt_ids, reasoning_ids, answer_ids}`` dicts.
    """
    return_sep_id = tokenizer.convert_tokens_to_ids("<|return_sep|>")
    multi_span_examples = build_codi_dataset(
        tokenizer, sources=sources, n_samples=n_samples,
        max_seq_len=max_seq_len, max_frames=max_frames, cache_dir=cache_dir,
    )
    singles = []
    for example in multi_span_examples:
        trace = example["trace_ids"]
        # Index of the last <|return_sep|>; 0 doubles as "not found" because a
        # split at position 0 is rejected as well.
        cut = max((pos for pos, token in enumerate(trace) if token == return_sep_id), default=0)
        if cut == 0:
            continue
        singles.append({
            "prompt_ids": example["prompt_ids"],
            "reasoning_ids": trace[:cut],
            "answer_ids": trace[cut:],
        })
    return singles


def rows_for_sources(sources):
    """Collect validated ``{id, code, input, output}`` rows from every source.

    All rows of each source are returned, in source order then row order; no
    train/test split happens here (that is decided per dataset, e.g. cruxeval
    is held out for eval).

    Each row yielded by ``load_one(name)`` must contain the keys ``id``,
    ``code``, ``input`` and ``output``, and the last three must be strings.
    The returned rows are shallow ``dict`` copies with ``id`` converted to
    ``str``.

    Raises:
        ValueError: a row lacks one of the required keys.
        TypeError: ``code``, ``input`` or ``output`` is not a string.
    """
    from . import sources as _src

    required_keys = ("id", "code", "input", "output")
    text_keys = ("code", "input", "output")

    collected = []
    for name in sources:
        for i, row in enumerate(_src.load_one(name)):
            missing = [key for key in required_keys if key not in row]
            if missing:
                raise ValueError(f"{name} row {i} missing keys: {missing}")
            for key in text_keys:
                if not isinstance(row[key], str):
                    raise TypeError(f"{name} row {i} must use string code/input/output")
            # Copy so the id normalization never mutates the loader's own row.
            normalized = dict(row)
            normalized["id"] = str(normalized["id"])
            collected.append(normalized)
    return collected


def build_dataset(
    tokenizer, *, sources=("mbpp", "humaneval", "pyx"), n_samples: int = -1,
    max_seq_len: int = 8192, max_frames: int = -1, cache_dir: str | None = None
) -> list[tuple[list[int], list[int]]]:
    """SFT ``(input_ids, labels)`` pairs over ``sources``, or from a cache.

    With ``cache_dir`` set, examples are loaded from the cache and those whose
    ``input_ids`` exceed ``max_seq_len`` are dropped. Otherwise the rows from
    ``rows_for_sources`` (the first ``n_samples`` when it is positive) are each
    passed to ``build_example``, and rows it rejects (returns None for) are
    left out. Exceptions raised by ``build_example`` propagate to the caller.
    """
    if cache_dir:
        return [
            (cached["input_ids"], cached["labels"])
            for cached in _load_cache(cache_dir, n_samples)
            if len(cached["input_ids"]) <= max_seq_len
        ]

    selected_rows = rows_for_sources(sources)
    if n_samples > 0:
        selected_rows = selected_rows[:n_samples]

    pairs = []
    for row in selected_rows:
        pair = build_example(
            row["code"], row["input"], tokenizer,
            max_seq_len=max_seq_len, max_frames=max_frames,
        )
        if pair is not None:
            pairs.append(pair)
    return pairs