LauraGG's picture
Refresh code/ with latest BLT-Reasoner sources (post-campaign)
bc7101b verified
"""GSM8K data loading and batch formatting for BLT-Reasoner."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import List, Optional
import torch
from torch.utils.data import Dataset
# Strict pattern for the "#### N" final-answer marker (plain digits only).
# Kept unchanged for backward compatibility with code that imports it.
GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
# Same marker, but the integer part may contain thousands separators
# ("#### 1,000"); extract_final_number strips them after matching.
_GSM8K_ANSWER_COMMA_RE = re.compile(r"####\s*(-?\d[\d,]*(?:\.\d+)?)")


def extract_final_number(answer_text: str) -> Optional[str]:
    """Return the number after the first ``####`` marker in *answer_text*.

    Thousands separators are removed, so ``"#### 1,000"`` yields ``"1000"``
    (the strict pattern alone would silently return ``"1"``). Strings without
    commas are parsed exactly as with ``GSM8K_ANSWER_RE``.

    Returns:
        The number as a string (may carry a leading ``-`` and a decimal part),
        or None when no ``#### <number>`` marker is present.
    """
    m = _GSM8K_ANSWER_COMMA_RE.search(answer_text)
    if m is None:
        return None
    return m.group(1).replace(",", "")
def format_prompt(question: str) -> str:
    """Wrap *question* in the plain ``Q: ...`` / ``A:`` prompt.

    No chat template is applied on purpose: at 1.5B parameters the model
    should spend its effort in the latent loop rather than on chat-template
    overhead.
    """
    template = "Q: {}\nA:"
    return template.format(question)
def format_answer(answer_text: str) -> str:
    """Return the gold answer text with surrounding whitespace removed.

    The full GSM8K-style answer (reasoning followed by ``#### NUMBER``) is
    kept intact, because the primary LM loss is computed on it. The InfoNCE
    target is encoded separately from only the ``#### NUMBER`` portion.
    """
    cleaned = answer_text.strip()
    return cleaned
class GSM8KDataset(Dataset):
    """GSM8K word problems from the HuggingFace ``gsm8k`` dataset ("main" config).

    Each item is a dict ``{"question", "answer", "final"}``, where ``final``
    is the number parsed from the ``#### N`` line of the gold answer. Rows
    whose answer has no parseable ``#### N`` are skipped.

    Args:
        split: split name forwarded to ``datasets.load_dataset``.
        max_examples: stop once this many examples have been kept; ``None``
            (or any falsy value such as 0) means no limit.
    """

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        from datasets import load_dataset

        source = load_dataset("gsm8k", "main", split=split)
        limit = max_examples or 0  # falsy -> no cap, same as the plain truthiness test
        kept = []
        for row in source:
            final = extract_final_number(row["answer"])
            if final is None:
                continue
            kept.append({"question": row["question"], "answer": row["answer"], "final": final})
            if limit and len(kept) >= limit:
                break
        self.examples = kept

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
def extract_boxed_answer(solution_text: str) -> Optional[str]:
    """Extract the final answer from a MATH solution.

    Returns the argument of the LAST ``\\boxed`` command in the text, falling
    back to the last ``\\fbox`` when no ``\\boxed`` is present. Two argument
    forms are handled:

    * braced: ``\\boxed{...}``, optionally with whitespace before the ``{``.
      Nested braces are balanced, e.g. ``\\boxed{\\frac{1}{2}}`` gives
      ``\\frac{1}{2}``.
    * bare: ``\\boxed 5$``, which some MATH solutions use. The answer runs to
      the next ``$`` (or to the end of the text).

    Returns None for empty input, when no box command exists, when the braces
    never close, or when a bare argument is empty.
    """
    if not solution_text:
        return None

    idx = -1
    command = ""
    for command in ("\\boxed", "\\fbox"):
        idx = solution_text.rfind(command)
        if idx >= 0:
            break
    if idx < 0:
        return None

    n = len(solution_text)
    arg_start = idx + len(command)
    pos = arg_start
    # Allow whitespace (e.g. "\boxed {x}" or a newline) between the command
    # and its argument.
    while pos < n and solution_text[pos].isspace():
        pos += 1

    if pos < n and solution_text[pos] == "{":
        start = pos + 1
        depth = 1
        for i in range(start, n):
            c = solution_text[i]
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
                if depth == 0:
                    return solution_text[start:i].strip()
        return None  # braces never closed

    if pos == arg_start:
        # Neither whitespace nor "{" follows the command, so there is no
        # usable argument.
        return None

    # Bare form "\boxed 5$": the answer ends at the next "$" (or end of text).
    # Searching only up to "$" avoids grabbing an unrelated "{" later in the
    # solution.
    end = solution_text.find("$", pos)
    if end < 0:
        end = n
    answer = solution_text[pos:end].strip()
    return answer or None
class MATHDataset(Dataset):
    """Hendrycks et al. MATH competition problems.

    The original ``hendrycks/competition_math`` repo has been removed from the
    HuggingFace Hub, so this loads ``qwedsacf/competition_math``. That source
    has all 12500 rows in a single ``train`` split. We carve our own
    deterministic partition by row index: rows [0, 11000) are "train" and
    rows [11000, 12500) are "test".

    This is NOT the official 7500/5000 Hendrycks split, so our test problems
    may overlap with the official train set. Relative in-recipe comparisons
    are fine, but absolute scores are not directly comparable to published
    numbers.

    Items use the same schema as GSM8KDataset:
        {"question": problem, "answer": full_solution_text, "final": boxed_value}
    where ``final`` is what ``extract_boxed_answer`` returns for the solution.
    Rows lacking a problem, a solution, or a boxed answer are skipped.

    Args:
        split: "train" or "test"; anything else raises ValueError.
        max_examples: stop once this many examples have been kept; ``None``
            (or any falsy value such as 0) means no limit.
    """

    _N_TRAIN = 11000
    _N_TEST = 1500
    _SRC = "qwedsacf/competition_math"

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        from datasets import load_dataset

        rows = load_dataset(self._SRC, split="train", trust_remote_code=True)

        # Fixed row-index ranges (no shuffling), so every call sees the same partition.
        if split == "train":
            lo, hi = 0, self._N_TRAIN
        elif split == "test":
            lo, hi = self._N_TRAIN, self._N_TRAIN + self._N_TEST
        else:
            raise ValueError(f"unknown split: {split}")
        hi = min(hi, len(rows))  # tolerate a source with fewer rows than expected

        limit = max_examples or 0  # falsy -> no cap
        kept = []
        for row_idx in range(lo, hi):
            row = rows[row_idx]
            solution = row.get("solution") or row.get("answer") or ""
            problem = row.get("problem") or row.get("question") or ""
            if not solution or not problem:
                continue
            boxed = extract_boxed_answer(solution)
            if boxed is None:
                continue
            kept.append({"question": problem, "answer": solution, "final": boxed})
            if limit and len(kept) >= limit:
                break
        self.examples = kept

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
@dataclass
class BatchTensors:
    """Tokenized batch, as produced by ``collate_batch``.

    Tensor fields hold token ids and attention masks. String fields carry the
    raw text used to build InfoNCE targets. The optional fields default to
    None, so their annotations are ``Optional[...]``.
    """

    x_ids: torch.Tensor  # [B, P] prompt ids, left-padded so the last column is a real token
    x_attn: torch.Tensor  # [B, P] prompt attention mask (1 = real token, 0 = padding)
    y_ids: torch.Tensor  # [B, L_y] answer ids, right-padded (collate_batch appends EOS)
    y_attn: torch.Tensor  # [B, L_y] answer attention mask (EOS counted as real)
    final_strs: List[str]  # per-example "#### N" strings used as InfoNCE positives
    full_answer_strs: Optional[List[str]] = None  # full y text (reasoning + answer) for richer InfoNCE
    y_chunk_strs: Optional[List[List[str]]] = None  # [B][K]: y split into K chunks for per-slot InfoNCE
def split_y_into_chunks(answer_text: str, K: int) -> List[str]:
    """Cut *answer_text* into K contiguous pieces of near-equal character length.

    These pieces are the per-CoT-step InfoNCE targets: latent slot k is asked
    to encode piece k of the gold y. Splitting by characters keeps this
    dataset-agnostic (no chain-of-thought parsing is needed), at the cost of
    chunk edges that do not line up with sentence boundaries.

    When len(answer_text) is not divisible by K, the first
    ``len(answer_text) % K`` pieces are one character longer, so the pieces
    exactly cover the text. Any empty piece (fewer characters than K), and
    every piece of an empty text, is replaced by the literal "<pad>". For
    non-empty text, K <= 0 yields [].
    """
    total = len(answer_text)
    if total == 0:
        return ["<pad>"] * K
    if K <= 0:
        return []
    size, extra = divmod(total, K)
    # Piece k starts after k pieces of `size` chars, plus one extra char for
    # each of the first min(k, extra) pieces (those absorb the remainder).
    bounds = [k * size + min(k, extra) for k in range(K + 1)]
    return [answer_text[a:b] or "<pad>" for a, b in zip(bounds, bounds[1:])]
def collate_batch(batch, tokenizer, max_prompt_len: int = 256, max_answer_len: int = 256) -> BatchTensors:
    """Tokenize a list of examples into padded prompt/answer tensors.

    Args:
        batch: sequence of example dicts with "question", "answer" and
            "final" keys (the schema produced by GSM8KDataset / MATHDataset).
        tokenizer: callable tokenizer accepting the keyword arguments used
            below and exposing ``eos_token_id`` / ``pad_token_id``.
            NOTE(review): assumes a HuggingFace tokenizer recent enough to
            honour the per-call ``padding_side`` argument. Older versions may
            ignore it and right-pad the prompts.
        max_prompt_len: token truncation length for the prompt.
        max_answer_len: token truncation length for the answer, applied
            before the EOS token is appended.

    Returns:
        BatchTensors with left-padded prompts ([B, P]) and right-padded
        answers ([B, L_y + 1], where L_y is the padded tokenized answer
        length). Every answer row has an EOS token placed directly after its
        last real token, and that EOS is marked as real in ``y_attn``.
        ``y_chunk_strs`` is left as None for callers to fill in.

    Raises:
        ValueError: if the tokenizer has no ``eos_token_id``.
    """
    eos = tokenizer.eos_token_id
    if eos is None:
        raise ValueError("collate_batch requires a tokenizer with an eos_token_id")

    prompts = [format_prompt(b["question"]) for b in batch]
    answers = [format_answer(b["answer"]) for b in batch]
    finals = [f"#### {b['final']}" for b in batch]
    full_answers = list(answers)  # full answer text (reasoning + "#### N")

    # Left-pad prompts so the last position of every row is a real token
    # (needed for the latent loop).
    # NOTE(review): with the tokenizer's truncation side (commonly "right"),
    # prompts longer than max_prompt_len lose their trailing "\nA:" cue.
    # Confirm max_prompt_len is large enough for the dataset.
    enc_p = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_prompt_len, padding_side="left",
                      add_special_tokens=False)
    enc_y = tokenizer(answers, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_answer_len, padding_side="right",
                      add_special_tokens=False)

    # Cast to long: a batch of empty answers can tokenize to an empty float
    # tensor. For normal batches this is a no-op.
    y_ids = enc_y["input_ids"].long()
    y_attn = enc_y["attention_mask"].long()
    B = y_ids.size(0)
    # Real-token count per row. With right padding this is also the index of
    # the first pad slot, which is where EOS goes.
    lengths = y_attn.sum(dim=1)
    # Widen by one column so even the longest row has room for its EOS. This
    # also covers L_y == 0, so every batch gets an EOS.
    y_ids = torch.cat([y_ids, torch.full((B, 1), tokenizer.pad_token_id, dtype=y_ids.dtype)], dim=1)
    y_attn = torch.cat([y_attn, torch.zeros(B, 1, dtype=y_attn.dtype)], dim=1)
    # Append EOS so the model learns to stop.
    # NOTE(review): answers truncated at max_answer_len also get an EOS,
    # which teaches stopping mid-answer. Keep max_answer_len generous.
    rows = torch.arange(B)
    y_ids[rows, lengths] = eos
    y_attn[rows, lengths] = 1

    return BatchTensors(
        x_ids=enc_p["input_ids"],
        x_attn=enc_p["attention_mask"],
        y_ids=y_ids,
        y_attn=y_attn,
        final_strs=finals,
        full_answer_strs=full_answers,
        y_chunk_strs=None,  # filled in lazily by callers that need per-slot InfoNCE
    )