File size: 8,418 Bytes
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9477b5c
 
 
 
 
 
bc7101b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9477b5c
 
 
 
 
 
bc7101b
9477b5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc7101b
 
9477b5c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""GSM8K data loading and batch formatting for BLT-Reasoner."""
from __future__ import annotations

import re
from dataclasses import dataclass
from typing import List, Optional

import torch
from torch.utils.data import Dataset

GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")


def extract_final_number(answer_text: str) -> Optional[str]:
    """Return the number after the first ``####`` marker, or None.

    The number may be negative and may carry a decimal part; it is returned
    as the exact matched text (no numeric conversion).
    """
    match = GSM8K_ANSWER_RE.search(answer_text)
    if match is None:
        return None
    return match.group(1)


def format_prompt(question: str) -> str:
    """Wrap a question in the bare ``Q: ...`` / ``A:`` prompt.

    No chat template is applied on purpose: at 1.5B scale the intent is for
    the model to reason through the latent loop rather than spend capacity on
    chat-template overhead.
    """
    return "Q: {}\nA:".format(question)


def format_answer(answer_text: str) -> str:
    """Return the full answer text with surrounding whitespace removed.

    The whole GSM8K-style solution (reasoning plus the ``#### NUMBER`` line)
    is kept because the primary LM loss is computed over it; the InfoNCE
    target is encoded separately from only the ``#### NUMBER`` portion.
    """
    cleaned = answer_text.strip()
    return cleaned


class GSM8KDataset(Dataset):
    """GSM8K ("main" config) examples that carry a parseable final number.

    Each item is a dict with keys ``"question"``, ``"answer"`` (the raw
    solution text) and ``"final"`` (the number after ``####``, as a string).
    Rows whose answer has no ``#### <number>`` marker are skipped.

    Args:
        split: dataset split name passed to ``load_dataset``.
        max_examples: stop after this many kept examples; a falsy value
            (None or 0) means no limit.
    """

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        from datasets import load_dataset

        rows = load_dataset("gsm8k", "main", split=split)
        self.examples = []
        for row in rows:
            final = extract_final_number(row["answer"])
            if final is not None:
                self.examples.append(
                    {"question": row["question"], "answer": row["answer"], "final": final}
                )
                if max_examples and len(self.examples) >= max_examples:
                    break

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


def extract_boxed_answer(solution_text: str) -> Optional[str]:
    """Extract the final answer from a MATH solution.

    Returns the content of the LAST ``\\boxed{...}`` (or ``\\boxed {...}``) in
    the text, handling nested braces. Also supports:

    * the brace-less ``\\boxed 5$`` form: the answer runs up to the next ``$``
      (or end of text);
    * ``\\fbox{...}`` as a fallback when no ``\\boxed`` is present.

    Returns None for empty input, when no marker is found, or when the braces
    after the marker are unbalanced. An empty ``\\boxed{}`` yields ``""``.
    """
    if not solution_text:
        return None
    # Choose whichever \boxed form occurs LAST. Checking the brace form first
    # and only falling back to the spaced form would return an earlier answer
    # whenever a later one is written as "\boxed {..}" or "\boxed 5".
    idx = max(solution_text.rfind("\\boxed{"), solution_text.rfind("\\boxed "))
    keyword_len = len("\\boxed")
    if idx < 0:
        idx = solution_text.rfind("\\fbox{")
        keyword_len = len("\\fbox")
        if idx < 0:
            return None

    n = len(solution_text)
    pos = idx + keyword_len
    # Only whitespace may sit between the keyword and its argument; this keeps
    # us from grabbing some unrelated "{...}" further along in the text.
    while pos < n and solution_text[pos].isspace():
        pos += 1
    if pos >= n:
        return None
    if solution_text[pos] != "{":
        # Brace-less form, e.g. "$\boxed 5$": the answer ends at the next "$".
        value = solution_text[pos:].split("$", 1)[0].strip()
        return value or None

    start = pos + 1
    depth = 1
    for i in range(start, n):
        c = solution_text[i]
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                return solution_text[start:i].strip()
    # Unbalanced braces: no well-formed answer.
    return None


class MATHDataset(Dataset):
    """Hendrycks et al. MATH dataset.

    Loads via HF datasets. The original `hendrycks/competition_math` has been
    removed from the Hub; we use `qwedsacf/competition_math` (12500 rows,
    all in a single `train` split) and DETERMINISTICALLY split into our own
    train/test partition. Split = 11000 train / 1500 test, by row index.
    This is not the official Hendrycks split (which is 7500/5000); using it
    means our test problems may overlap with the original train. For relative
    in-recipe comparisons that's fine, but absolute numbers cannot be compared
    directly to published results.

    Schema after adaptation matches GSM8KDataset:
        {"question": problem, "answer": full_solution_text, "final": boxed_value}
    final is the content of the LAST ``\\boxed{...}`` in the solution. Rows
    missing a problem, a solution, or an extractable boxed answer are skipped.

    Args:
        split: ``"train"`` or ``"test"`` (our own partition, see above).
        max_examples: stop after this many kept examples; a falsy value
            (None or 0) means no limit.

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"test"``. This is
            checked before anything is downloaded.
    """

    _N_TRAIN = 11000
    _N_TEST = 1500
    _SRC = "qwedsacf/competition_math"

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        # Resolve the partition BEFORE touching the Hub so a bad `split`
        # fails fast instead of after a full dataset download.
        # Deterministic partition by row index (no shuffle — fixed across calls).
        if split == "train":
            indices = range(0, self._N_TRAIN)
        elif split == "test":
            indices = range(self._N_TRAIN, self._N_TRAIN + self._N_TEST)
        else:
            raise ValueError(f"unknown split: {split}")

        from datasets import load_dataset
        ds = load_dataset(self._SRC, split="train", trust_remote_code=True)
        n_rows = len(ds)

        self.examples = []
        for i in indices:
            if i >= n_rows:
                break
            ex = ds[int(i)]
            # Accept either column naming convention for the source rows.
            sol = ex.get("solution") or ex.get("answer") or ""
            prob = ex.get("problem") or ex.get("question") or ""
            if not (sol and prob):
                continue
            boxed = extract_boxed_answer(sol)
            if boxed is None:
                continue
            self.examples.append({
                "question": prob,
                "answer": sol,
                "final": boxed,
            })
            if max_examples and len(self.examples) >= max_examples:
                break

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


@dataclass
class BatchTensors:
    """One collated batch: token tensors plus the strings used for InfoNCE.

    Shape letters: B = batch size, P = padded prompt length,
    L_y = padded answer length.
    """

    x_ids: torch.Tensor      # [B, P]  left-padded
    x_attn: torch.Tensor     # [B, P]
    y_ids: torch.Tensor      # [B, L_y]  right-padded
    y_attn: torch.Tensor     # [B, L_y]
    final_strs: List[str]    # used to encode InfoNCE positives (just "#### N")
    # The two fields below default to None when not supplied, hence Optional.
    full_answer_strs: Optional[List[str]] = None   # full y text (reasoning + answer) for richer InfoNCE
    y_chunk_strs: Optional[List[List[str]]] = None  # [B][K] — y split into K chunks for per-slot InfoNCE


def split_y_into_chunks(answer_text: str, K: int) -> List[str]:
    """Cut ``answer_text`` into ``K`` contiguous pieces of near-equal length.

    Each latent slot k is paired with piece k of the gold y as its
    per-CoT-step InfoNCE target. The first ``len(answer_text) % K`` pieces
    are one character longer, so together the pieces cover the whole string.

    Splitting by characters is dataset-agnostic and avoids parsing the CoT
    format. Pieces may cut through words or sentences; the original author
    considers this acceptable for short GSM8K answers.

    Any piece that would be empty (empty text, or more slots than characters)
    is replaced by the literal ``"<pad>"``. A non-positive ``K`` yields ``[]``.
    """
    n_chars = len(answer_text)
    if n_chars == 0:
        return ["<pad>"] * K
    if K <= 0:
        return []
    base, extra = divmod(n_chars, K)
    # Start offset of piece k: k full-size pieces plus one extra char for
    # each of the first `extra` pieces already emitted.
    bounds = [k * base + min(k, extra) for k in range(K + 1)]
    return [answer_text[lo:hi] or "<pad>" for lo, hi in zip(bounds, bounds[1:])]


def collate_batch(batch, tokenizer, max_prompt_len: int = 256, max_answer_len: int = 256) -> BatchTensors:
    """Tokenize a list of examples into a :class:`BatchTensors`.

    Args:
        batch: sequence of dicts with ``"question"``, ``"answer"`` and
            ``"final"`` keys (the schema of GSM8KDataset / MATHDataset items).
        tokenizer: callable tokenizer accepting ``padding_side`` and
            ``add_special_tokens`` kwargs and exposing ``eos_token_id`` /
            ``pad_token_id`` (presumably a Hugging Face tokenizer — confirm).
        max_prompt_len: truncation length for prompt tokens.
        max_answer_len: truncation length for answer tokens, applied BEFORE the
            EOS token is appended, so ``y_ids`` may be one column longer.

    Returns:
        BatchTensors with left-padded prompts and right-padded answers. When
        the answer tensor is non-empty, each answer row ends in EOS.
        ``y_chunk_strs`` is None; callers needing per-slot InfoNCE fill it in.
    """
    prompts = [format_prompt(b["question"]) for b in batch]
    answers = [format_answer(b["answer"]) for b in batch]
    finals = [f"#### {b['final']}" for b in batch]
    full_answers = list(answers)  # full answer text (reasoning + "#### N"), separate list object

    # Left-pad prompts so last position is always real (needed for the latent loop).
    # NOTE(review): truncation uses the tokenizer's own truncation_side; if that
    # is "right", an over-long prompt loses its trailing "\nA:" — confirm.
    enc_p = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_prompt_len, padding_side="left",
                      add_special_tokens=False)
    enc_y = tokenizer(answers, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_answer_len, padding_side="right",
                      add_special_tokens=False)

    # Append EOS to the right of y so the model learns to stop.
    # NOTE(review): EOS is also appended after answers that were truncated,
    # which teaches stopping mid-solution for over-long answers — confirm intended.
    eos = tokenizer.eos_token_id
    y_ids = enc_y["input_ids"]
    y_attn = enc_y["attention_mask"]
    if y_ids.size(1) > 0:
        B = y_ids.size(0)
        # With right padding, a row's real-token count is the index of its
        # first pad slot, which is where that row's EOS goes.
        lengths = y_attn.sum(dim=1).long()
        # Widen by one column so even the longest row has room for its EOS.
        y_ids = torch.cat([y_ids, torch.full((B, 1), tokenizer.pad_token_id, dtype=y_ids.dtype)], dim=1)
        y_attn = torch.cat([y_attn, torch.zeros(B, 1, dtype=y_attn.dtype)], dim=1)
        # One vectorized write per tensor instead of a Python loop over rows.
        rows = torch.arange(B)
        y_ids[rows, lengths] = eos
        y_attn[rows, lengths] = 1

    return BatchTensors(
        x_ids=enc_p["input_ids"],
        x_attn=enc_p["attention_mask"],
        y_ids=y_ids,
        y_attn=y_attn,
        final_strs=finals,
        full_answer_strs=full_answers,
        y_chunk_strs=None,  # filled in lazily by callers that need per-slot InfoNCE
    )