"""GSM8K data loading and batch formatting for BLT-Reasoner.""" from __future__ import annotations import re from dataclasses import dataclass from typing import List, Optional import torch from torch.utils.data import Dataset GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)") def extract_final_number(answer_text: str) -> Optional[str]: m = GSM8K_ANSWER_RE.search(answer_text) return m.group(1) if m else None def format_prompt(question: str) -> str: """Minimal instruction format. We do NOT add a Chat template because at 1.5B we want the model to think with the latent loop, not chat-template overhead. """ return f"Q: {question}\nA:" def format_answer(answer_text: str) -> str: """Keep the full GSM8K-style answer (reasoning + #### NUMBER). The primary LM loss is computed on this. The InfoNCE target is encoded separately from only the #### NUMBER portion. """ return answer_text.strip() class GSM8KDataset(Dataset): def __init__(self, split: str = "train", max_examples: Optional[int] = None): from datasets import load_dataset ds = load_dataset("gsm8k", "main", split=split) self.examples = [] for ex in ds: num = extract_final_number(ex["answer"]) if num is None: continue self.examples.append({ "question": ex["question"], "answer": ex["answer"], "final": num, }) if max_examples and len(self.examples) >= max_examples: break def __len__(self): return len(self.examples) def __getitem__(self, idx): return self.examples[idx] def extract_boxed_answer(solution_text: str) -> Optional[str]: """Extract the final answer from a MATH solution: content of the LAST ``\\boxed{...}`` in the text. Handles nested braces. """ if not solution_text: return None idx = solution_text.rfind("\\boxed{") if idx < 0: # Some answers use \boxed { ... } with a space; also try that idx = solution_text.rfind("\\boxed ") if idx < 0: return None start = solution_text.find("{", idx) if start < 0: return None start += 1 depth = 1 i = start while i < len(solution_text) and depth > 0: c = solution_text[i] if c == "{": depth += 1 elif c == "}": depth -= 1 if depth == 0: return solution_text[start:i].strip() i += 1 return None class MATHDataset(Dataset): """Hendrycks et al. MATH dataset. Loads via HF datasets. The original `hendrycks/competition_math` has been removed from the Hub; we use `qwedsacf/competition_math` (12500 rows, all in a single `train` split) and DETERMINISTICALLY split into our own train/test partition. Split = 11000 train / 1500 test, by row index. This is not the official Hendrycks split (which is 7500/5000); using it means our test problems may overlap with the original train. For relative in-recipe comparisons that's fine, but absolute numbers cannot be compared directly to published results. Schema after adaptation matches GSM8KDataset: {"question": problem, "answer": full_solution_text, "final": boxed_value} final is the content of the LAST ``\\boxed{...}`` in the solution. """ _N_TRAIN = 11000 _N_TEST = 1500 _SRC = "qwedsacf/competition_math" def __init__(self, split: str = "train", max_examples: Optional[int] = None): from datasets import load_dataset ds = load_dataset(self._SRC, split="train", trust_remote_code=True) # Deterministic partition by row index (no shuffle — fixed across calls). if split == "train": indices = range(0, self._N_TRAIN) elif split == "test": indices = range(self._N_TRAIN, self._N_TRAIN + self._N_TEST) else: raise ValueError(f"unknown split: {split}") self.examples = [] for i in indices: if i >= len(ds): break ex = ds[int(i)] sol = ex.get("solution") or ex.get("answer") or "" prob = ex.get("problem") or ex.get("question") or "" if not (sol and prob): continue boxed = extract_boxed_answer(sol) if boxed is None: continue self.examples.append({ "question": prob, "answer": sol, "final": boxed, }) if max_examples and len(self.examples) >= max_examples: break def __len__(self): return len(self.examples) def __getitem__(self, idx): return self.examples[idx] @dataclass class BatchTensors: x_ids: torch.Tensor # [B, P] left-padded x_attn: torch.Tensor # [B, P] y_ids: torch.Tensor # [B, L_y] right-padded y_attn: torch.Tensor # [B, L_y] final_strs: List[str] # used to encode InfoNCE positives (just "#### N") full_answer_strs: List[str] = None # full y text (reasoning + answer) for richer InfoNCE y_chunk_strs: List[List[str]] = None # [B][K] — y split into K chunks for per-slot InfoNCE def split_y_into_chunks(answer_text: str, K: int) -> List[str]: """Split y into K substrings of (roughly) equal character length. Used for per-CoT-step InfoNCE targets, where each latent slot k is asked to encode its corresponding chunk_k of the gold y. Character-based splitting is dataset-agnostic and avoids parsing CoT-format issues; the chunks are not perfect sentence boundaries but they're close enough on GSM8K (where answers are typically 100-200 chars) and the frozen-base encoder used for the target produces a smooth gradient regardless of exact split point. """ n = len(answer_text) if n == 0: return [""] * K if K <= 0: return [] base = n // K rem = n % K out: List[str] = [] pos = 0 for k in range(K): # First `rem` chunks get one extra char so we cover the whole string. sz = base + (1 if k < rem else 0) chunk = answer_text[pos : pos + sz] out.append(chunk if chunk else "") pos += sz return out def collate_batch(batch, tokenizer, max_prompt_len: int = 256, max_answer_len: int = 256) -> BatchTensors: prompts = [format_prompt(b["question"]) for b in batch] answers = [format_answer(b["answer"]) for b in batch] finals = [f"#### {b['final']}" for b in batch] full_answers = [a for a in answers] # full GSM8K answer text (reasoning + "#### N") # Left-pad prompts so last position is always real (needed for the latent loop) enc_p = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_prompt_len, padding_side="left", add_special_tokens=False) enc_y = tokenizer(answers, return_tensors="pt", padding=True, truncation=True, max_length=max_answer_len, padding_side="right", add_special_tokens=False) # Append EOS to the right of y so the model learns to stop eos = tokenizer.eos_token_id L_y = enc_y["input_ids"].size(1) if L_y > 0: # replace the first pad in each row with EOS (or extend by 1 if no pad) y_ids = enc_y["input_ids"] y_attn = enc_y["attention_mask"] # Find each row's last real position; place EOS at first pad after it. lengths = y_attn.sum(dim=1) B = y_ids.size(0) # Pad to L_y+1 to guarantee room for EOS y_ids = torch.cat([y_ids, torch.full((B, 1), tokenizer.pad_token_id, dtype=y_ids.dtype)], dim=1) y_attn = torch.cat([y_attn, torch.zeros(B, 1, dtype=y_attn.dtype)], dim=1) for b in range(B): pos = int(lengths[b].item()) y_ids[b, pos] = eos y_attn[b, pos] = 1 else: y_ids = enc_y["input_ids"] y_attn = enc_y["attention_mask"] return BatchTensors( x_ids=enc_p["input_ids"], x_attn=enc_p["attention_mask"], y_ids=y_ids, y_attn=y_attn, final_strs=finals, full_answer_strs=full_answers, y_chunk_strs=None, # filled in lazily by callers that need per-slot InfoNCE )