File size: 8,418 Bytes
9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c bc7101b 9477b5c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """GSM8K data loading and batch formatting for BLT-Reasoner."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import List, Optional
import torch
from torch.utils.data import Dataset
GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
# Same as GSM8K_ANSWER_RE but also accepts thousands separators after the
# first digit ("#### 1,000"); GSM8K_ANSWER_RE would silently capture "1".
_GSM8K_ANSWER_COMMA_RE = re.compile(r"####\s*(-?\d[\d,]*(?:\.\d+)?)")


def extract_final_number(answer_text: str) -> Optional[str]:
    """Return the number following the first ``####`` marker in *answer_text*.

    The number may be negative and may have a decimal part. Commas used as
    thousands separators are removed, so ``"#### 1,000"`` yields ``"1000"``
    rather than a truncated ``"1"``.

    Returns:
        The number as a string, or ``None`` if no ``#### <number>`` is found.
    """
    m = _GSM8K_ANSWER_COMMA_RE.search(answer_text)
    if m is None:
        return None
    # Strip separators (and any trailing comma swallowed by [\d,]*).
    return m.group(1).replace(",", "")
def format_prompt(question: str) -> str:
    """Wrap *question* in the bare ``Q: <question>\\nA:`` prompt format.

    No chat template is applied on purpose: at 1.5B the model should spend
    its capacity thinking with the latent loop, not on chat-template overhead.
    """
    prompt = f"Q: {question}\nA:"
    return prompt
def format_answer(answer_text: str) -> str:
    """Return the full gold answer (reasoning plus ``#### NUMBER``), trimmed.

    Only leading/trailing whitespace is removed; the reasoning is kept intact
    because the primary LM loss is computed on this text. The InfoNCE target
    is encoded separately from just the ``#### NUMBER`` part.
    """
    cleaned = answer_text.strip()
    return cleaned
class GSM8KDataset(Dataset):
    """GSM8K (``gsm8k`` / ``main`` config) examples held in memory as dicts.

    Each item is ``{"question": str, "answer": str, "final": str}``, where
    ``final`` is the number parsed from the ``#### N`` line of ``answer``.
    Rows whose answer has no parseable ``#### N`` are dropped.

    Args:
        split: split name passed straight to ``load_dataset``.
        max_examples: stop after keeping this many examples; ``None`` or
            ``0`` means no limit.
    """

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        # Local import: ``datasets`` is only required when a dataset is built.
        from datasets import load_dataset

        kept = []
        for row in load_dataset("gsm8k", "main", split=split):
            final = extract_final_number(row["answer"])
            if final is not None:
                kept.append(dict(question=row["question"], answer=row["answer"], final=final))
                # Falsy max_examples (None or 0) disables the cap.
                if max_examples and len(kept) >= max_examples:
                    break
        self.examples = kept

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
# Opening of a \boxed{...} group, tolerating whitespace before the brace
# ("\boxed {x}" occurs in some solutions).
_BOXED_OPEN_RE = re.compile(r"\\boxed\s*\{")


def extract_boxed_answer(solution_text: str) -> Optional[str]:
    r"""Extract the final answer from a MATH solution.

    Returns the (stripped) content of the LAST ``\boxed{...}`` in the text,
    where ``\boxed {...}`` with whitespace before the brace also counts.
    Nested braces inside the box are handled by depth counting.

    Returns ``None`` for empty input, when no ``\boxed{`` opening exists, or
    when the last box is never closed. A ``\boxed`` that is not followed by a
    brace (e.g. ``\boxed 5``) is not treated as a box; previously the next
    unrelated ``{`` later in the text could be picked up by mistake.
    """
    if not solution_text:
        return None
    # Take the last opening across both "\boxed{" and "\boxed {" forms, so a
    # later spaced box is not shadowed by an earlier unspaced one.
    openings = list(_BOXED_OPEN_RE.finditer(solution_text))
    if not openings:
        return None
    start = openings[-1].end()  # first char after the opening "{"
    depth = 1
    for i in range(start, len(solution_text)):
        c = solution_text[i]
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                return solution_text[start:i].strip()
    # Unbalanced: the box was opened but never closed.
    return None
class MATHDataset(Dataset):
    """Hendrycks et al. MATH dataset.

    Loads via HF datasets. The original `hendrycks/competition_math` has been
    removed from the Hub; we use `qwedsacf/competition_math` (12500 rows,
    all in a single `train` split) and DETERMINISTICALLY split into our own
    train/test partition. Split = 11000 train / 1500 test, by row index.

    This is not the official Hendrycks split (which is 7500/5000); using it
    means our test problems may overlap with the original train. For relative
    in-recipe comparisons that's fine, but absolute numbers cannot be compared
    directly to published results.

    Schema after adaptation matches GSM8KDataset:
        {"question": problem, "answer": full_solution_text, "final": boxed_value}
    final is the content of the LAST ``\\boxed{...}`` in the solution.

    Rows missing a problem or solution, or without a boxed answer, are
    skipped, so a split may hold fewer examples than its nominal size.

    Args:
        split: ``"train"`` or ``"test"``; anything else raises ``ValueError``.
        max_examples: stop after keeping this many examples; ``None`` or
            ``0`` means no limit.
    """

    _N_TRAIN = 11000
    _N_TEST = 1500
    _SRC = "qwedsacf/competition_math"

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        # Validate the split BEFORE importing/loading the dataset so a typo
        # fails immediately instead of after the load.
        # Deterministic partition by row index (no shuffle — fixed across calls).
        if split == "train":
            indices = range(0, self._N_TRAIN)
        elif split == "test":
            indices = range(self._N_TRAIN, self._N_TRAIN + self._N_TEST)
        else:
            raise ValueError(f"unknown split: {split}")

        from datasets import load_dataset

        # The source only has a "train" split; our train/test are carved out
        # of it by the index ranges above.
        ds = load_dataset(self._SRC, split="train", trust_remote_code=True)
        n_rows = len(ds)

        self.examples = []
        for i in indices:
            if i >= n_rows:
                break
            ex = ds[i]
            # Accept either column naming for the solution/problem text.
            sol = ex.get("solution") or ex.get("answer") or ""
            prob = ex.get("problem") or ex.get("question") or ""
            if not (sol and prob):
                continue
            boxed = extract_boxed_answer(sol)
            if boxed is None:
                continue
            self.examples.append({
                "question": prob,
                "answer": sol,
                "final": boxed,
            })
            # Falsy max_examples (None or 0) disables the cap.
            if max_examples and len(self.examples) >= max_examples:
                break

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
@dataclass
class BatchTensors:
    """One collated batch: prompt/answer token tensors plus InfoNCE strings.

    The two trailing string fields are optional and default to ``None``; their
    annotations are ``Optional`` to match.
    """

    x_ids: torch.Tensor  # [B, P] left-padded
    x_attn: torch.Tensor  # [B, P]
    y_ids: torch.Tensor  # [B, L_y] right-padded
    y_attn: torch.Tensor  # [B, L_y]
    final_strs: List[str]  # used to encode InfoNCE positives (just "#### N")
    # full y text (reasoning + answer) for richer InfoNCE
    full_answer_strs: Optional[List[str]] = None
    # [B][K] — y split into K chunks for per-slot InfoNCE
    y_chunk_strs: Optional[List[List[str]]] = None
def split_y_into_chunks(answer_text: str, K: int) -> List[str]:
    """Cut *answer_text* into K contiguous pieces of near-equal character length.

    Used for per-CoT-step InfoNCE targets: latent slot k is asked to encode
    piece k of the gold y. Splitting on characters keeps this independent of
    any CoT formatting. The cut points will not fall on sentence boundaries,
    which the original author judged close enough for short GSM8K answers.

    Length rules: the first ``len(answer_text) % K`` pieces are one character
    longer than the rest, so together the pieces cover the whole string. Any
    empty piece (text shorter than K) is replaced by ``"<pad>"``. Empty
    *answer_text* yields ``["<pad>"] * K``; otherwise ``K <= 0`` yields ``[]``.
    """
    total = len(answer_text)
    if total == 0:
        return ["<pad>"] * K
    if K <= 0:
        return []
    step, extra = divmod(total, K)
    pieces: List[str] = []
    cursor = 0
    for slot in range(K):
        stop = cursor + step + (1 if slot < extra else 0)
        pieces.append(answer_text[cursor:stop] or "<pad>")
        cursor = stop
    return pieces
def collate_batch(batch, tokenizer, max_prompt_len: int = 256, max_answer_len: int = 256) -> BatchTensors:
    """Tokenize a list of examples into a ``BatchTensors``.

    Args:
        batch: sequence of dicts with ``"question"``, ``"answer"`` and
            ``"final"`` keys.
        tokenizer: presumably a Hugging Face tokenizer — it is called with
            ``return_tensors``/``padding``/``truncation``/``padding_side`` and
            must expose ``pad_token_id`` and ``eos_token_id``.
        max_prompt_len: truncation length for prompt tokens.
        max_answer_len: truncation length for answer tokens (before EOS).

    Returns:
        ``BatchTensors`` with left-padded prompts and right-padded answers.
        If the answers produced at least one token column, every answer row
        gets an EOS right after its last real token, so ``y_ids``/``y_attn``
        are one column wider than the tokenized answers. ``y_chunk_strs`` is
        left ``None``.

    Raises:
        ValueError: if EOS must be appended but ``tokenizer.eos_token_id`` is None.
    """
    prompts = [format_prompt(b["question"]) for b in batch]
    answers = [format_answer(b["answer"]) for b in batch]
    finals = [f"#### {b['final']}" for b in batch]
    full_answers = list(answers)  # full answer text (reasoning + "#### N")

    # Left-pad prompts so last position is always real (needed for the latent loop)
    enc_p = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_prompt_len, padding_side="left",
                      add_special_tokens=False)
    enc_y = tokenizer(answers, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_answer_len, padding_side="right",
                      add_special_tokens=False)
    y_ids = enc_y["input_ids"]
    y_attn = enc_y["attention_mask"]

    # Append EOS to the right of y so the model learns to stop. Skipped when
    # the tokenized answers have zero columns.
    if y_ids.size(1) > 0:
        eos = tokenizer.eos_token_id
        if eos is None:
            raise ValueError("tokenizer.eos_token_id is None; cannot append EOS to answers")
        B = y_ids.size(0)
        # y is right-padded, so each row's real-token count is also the index
        # of its first pad slot.
        lengths = y_attn.sum(dim=1)
        # Widen by one column so rows truncated to the full width still have
        # room for EOS.
        y_ids = torch.cat([y_ids, torch.full((B, 1), tokenizer.pad_token_id, dtype=y_ids.dtype)], dim=1)
        y_attn = torch.cat([y_attn, torch.zeros(B, 1, dtype=y_attn.dtype)], dim=1)
        # Vectorized equivalent of writing EOS at (b, lengths[b]) for each row.
        rows = torch.arange(B, device=y_ids.device)
        y_ids[rows, lengths] = eos
        y_attn[rows, lengths] = 1

    return BatchTensors(
        x_ids=enc_p["input_ids"],
        x_attn=enc_p["attention_mask"],
        y_ids=y_ids,
        y_attn=y_attn,
        final_strs=finals,
        full_answer_strs=full_answers,
        y_chunk_strs=None,  # filled in lazily by callers that need per-slot InfoNCE
    )
|