| """GSM8K data loading and batch formatting for BLT-Reasoner.""" |
| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass |
| from typing import List, Optional |
|
|
| import torch |
| from torch.utils.data import Dataset |
|
|
# Matches the "#### <number>" trailer of a GSM8K-style answer. Thousands
# separators ("#### 1,000") are accepted so the whole value is captured rather
# than only the digits before the first comma; extract_final_number removes
# the commas from the captured group.
GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d[\d,]*(?:\.\d+)?)")


def extract_final_number(answer_text: str) -> Optional[str]:
    """Return the number after the first ``####`` marker in *answer_text*.

    The result is the matched numeric string with any thousands-separator
    commas removed (``"#### 1,234"`` -> ``"1234"``). A leading minus sign and
    a decimal part are kept as written. Returns None when no ``#### NUMBER``
    trailer is present.
    """
    m = GSM8K_ANSWER_RE.search(answer_text)
    if m is None:
        return None
    return m.group(1).replace(",", "")
|
|
|
|
def format_prompt(question: str) -> str:
    """Render *question* in a bare ``Q: ... / A:`` prompt.

    No chat template is applied on purpose: at 1.5B parameters the model
    should spend its capacity on the latent reasoning loop, not on
    chat-template overhead.
    """
    template = "Q: {}\nA:"
    return template.format(question)
|
|
|
|
def format_answer(answer_text: str) -> str:
    """Return the answer target: the full text with outer whitespace removed.

    The complete GSM8K-style answer (reasoning followed by ``#### NUMBER``)
    is kept because the primary LM loss is computed over it. The InfoNCE
    target is built separately from only the ``#### NUMBER`` portion.
    """
    cleaned = answer_text.strip()
    return cleaned
|
|
|
|
class GSM8KDataset(Dataset):
    """GSM8K (``"gsm8k"``, config ``"main"``) held in memory as plain dicts.

    Each item has the keys ``"question"``, ``"answer"`` (the full answer
    text) and ``"final"`` (the number parsed by ``extract_final_number``).
    Rows whose answer has no parseable ``#### NUMBER`` trailer are dropped.

    Args:
        split: split name passed unchanged to ``datasets.load_dataset``.
        max_examples: stop once this many rows have been kept. ``None`` or
            ``0`` (any falsy value) means no limit.
    """

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        from datasets import load_dataset

        source = load_dataset("gsm8k", "main", split=split)
        kept = []
        for row in source:
            final = extract_final_number(row["answer"])
            if final is not None:
                kept.append({
                    "question": row["question"],
                    "answer": row["answer"],
                    "final": final,
                })
                # Truthiness check on purpose: 0 behaves like None (no cap).
                if max_examples and len(kept) >= max_examples:
                    break
        self.examples = kept

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
|
|
|
|
def extract_boxed_answer(solution_text: str) -> Optional[str]:
    """Extract the final answer from a MATH solution.

    Returns the argument of the LAST ``\\boxed`` command in the text. If the
    text contains no ``\\boxed`` at all, the last ``\\fbox`` is used instead,
    since some MATH solutions box their answer that way. Supported argument
    forms:

    * ``\\boxed{...}`` / ``\\boxed {...}``: a braced argument. Nested braces
      are handled, the content is returned stripped, and None is returned
      when the braces never balance.
    * ``\\boxed 5$``: an unbraced argument. Everything after the command up
      to the next ``$`` (or the end of the text) is returned, stripped.

    Returns None for empty input or when no usable boxed answer is found.
    """
    if not solution_text:
        return None

    # Prefer \boxed; fall back to \fbox only when \boxed never occurs.
    for marker in ("\\boxed", "\\fbox"):
        idx = solution_text.rfind(marker)
        if idx >= 0:
            break
    if idx < 0:
        return None

    n = len(solution_text)
    cmd_end = idx + len(marker)
    pos = cmd_end
    while pos < n and solution_text[pos].isspace():
        pos += 1

    if pos < n and solution_text[pos] == "{":
        # Braced argument: scan forward to the matching close brace.
        start = pos + 1
        depth = 1
        for i in range(start, n):
            c = solution_text[i]
            if c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
                if depth == 0:
                    return solution_text[start:i].strip()
        return None  # unbalanced braces

    if pos > cmd_end:
        # Unbraced "\boxed 5$" form: the argument runs until the closing "$".
        answer = solution_text[pos:].split("$", 1)[0].strip()
        return answer or None

    # Marker directly followed by something other than whitespace or "{".
    return None
|
|
|
|
class MATHDataset(Dataset):
    """Hendrycks et al. MATH dataset.

    Loads via HF datasets. The original `hendrycks/competition_math` has been
    removed from the Hub; we use `qwedsacf/competition_math` (12500 rows,
    all in a single `train` split) and DETERMINISTICALLY split into our own
    train/test partition. Split = 11000 train / 1500 test, by row index.
    This is not the official Hendrycks split (which is 7500/5000); using it
    means our test problems may overlap with the original train. For relative
    in-recipe comparisons that's fine, but absolute numbers cannot be compared
    directly to published results.

    Schema after adaptation matches GSM8KDataset:
        {"question": problem, "answer": full_solution_text, "final": boxed_value}
    final is the content of the LAST ``\\boxed{...}`` in the solution; rows
    without a problem, a solution, or an extractable boxed answer are dropped.

    Args:
        split: ``"train"`` or ``"test"`` (our own partition, see above).
        max_examples: stop once this many rows have been kept. ``None`` or
            ``0`` (any falsy value) means no limit.

    Raises:
        ValueError: for any other split name. This is raised before
            ``datasets`` is imported or anything is loaded.
    """

    _N_TRAIN = 11000
    _N_TEST = 1500
    _SRC = "qwedsacf/competition_math"

    def __init__(self, split: str = "train", max_examples: Optional[int] = None):
        # Resolve the split first, so a bad split name fails immediately
        # instead of after loading the whole 12.5k-row source.
        if split == "train":
            indices = range(0, self._N_TRAIN)
        elif split == "test":
            indices = range(self._N_TRAIN, self._N_TRAIN + self._N_TEST)
        else:
            raise ValueError(f"unknown split: {split}")

        from datasets import load_dataset
        # NOTE(review): trust_remote_code=True permits executing code shipped
        # with the Hub repo; confirm this source actually needs it.
        ds = load_dataset(self._SRC, split="train", trust_remote_code=True)

        # NOTE(review): the partition follows the source's raw row order. If
        # rows are grouped (e.g. by subject/level), the tail "test" slice may
        # be unrepresentative; verify the source ordering.
        self.examples = []
        # Clamp to the rows actually present, in case the source shrinks.
        for i in range(indices.start, min(indices.stop, len(ds))):
            ex = ds[i]
            # Accept either the MATH column names or GSM8K-style ones.
            sol = ex.get("solution") or ex.get("answer") or ""
            prob = ex.get("problem") or ex.get("question") or ""
            if not (sol and prob):
                continue
            boxed = extract_boxed_answer(sol)
            if boxed is None:
                continue
            self.examples.append({
                "question": prob,
                "answer": sol,
                "final": boxed,
            })
            # Truthiness check on purpose: 0 behaves like None (no cap).
            if max_examples and len(self.examples) >= max_examples:
                break

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
|
|
|
|
@dataclass
class BatchTensors:
    """Tokenized training batch, as assembled by ``collate_batch``.

    The padding notes below describe how collate_batch builds these fields.
    Other producers should confirm they follow the same layout.
    """

    x_ids: torch.Tensor  # prompt token ids (collate_batch left-pads them)
    x_attn: torch.Tensor  # prompt attention mask matching x_ids
    y_ids: torch.Tensor  # answer token ids (collate_batch right-pads them)
    y_attn: torch.Tensor  # answer attention mask matching y_ids
    final_strs: List[str]  # per-example "#### <final>" target strings
    # Full formatted answer texts; None when the producer does not set them.
    full_answer_strs: Optional[List[str]] = None
    # Per-example lists of answer chunks (presumably from split_y_into_chunks);
    # collate_batch leaves this as None.
    y_chunk_strs: Optional[List[List[str]]] = None
|
|
|
|
def split_y_into_chunks(answer_text: str, K: int) -> List[str]:
    """Cut *answer_text* into K consecutive pieces of near-equal character length.

    These pieces serve as per-CoT-step InfoNCE targets: latent slot k is
    asked to encode piece k of the gold answer. Splitting on characters
    keeps this independent of any dataset's CoT format. The boundaries are
    not real sentence boundaries, but on GSM8K (answers of roughly 100-200
    characters) they are close enough, and the frozen-base target encoder
    gives a smooth signal wherever the cut lands.

    The first ``len(answer_text) % K`` pieces are one character longer than
    the rest. Any empty piece (including every piece of an empty text) is
    replaced by ``"<pad>"``. A non-positive K yields an empty list.
    """
    total = len(answer_text)
    if not total:
        return ["<pad>"] * K
    if K <= 0:
        return []

    step, extra = divmod(total, K)
    # Cumulative cut points; the first `extra` pieces absorb one leftover char.
    bounds = [0]
    for k in range(K):
        bounds.append(bounds[-1] + step + (k < extra))
    return [answer_text[lo:hi] or "<pad>" for lo, hi in zip(bounds, bounds[1:])]
|
|
|
|
def collate_batch(batch, tokenizer, max_prompt_len: int = 256, max_answer_len: int = 256) -> BatchTensors:
    """Tokenize a list of examples into a ``BatchTensors``.

    Args:
        batch: sequence of example dicts with "question", "answer" and
            "final" keys (the schema GSM8KDataset / MATHDataset produce).
        tokenizer: tokenizer called with Hugging Face-style keyword arguments
            (``padding_side``, ``truncation``, ``return_tensors="pt"``, ...).
            Its ``pad_token_id`` and ``eos_token_id`` are used to build the
            answer tensors.
        max_prompt_len: token budget for each prompt; longer prompts are
            truncated.
        max_answer_len: token budget for each answer's text, not counting
            the EOS token.

    Returns:
        BatchTensors with left-padded prompts and right-padded answers. Every
        answer that fits in ``max_answer_len`` tokens gets one EOS token
        right after its last real token, so the answer tensors have one
        extra column. Answers that had to be truncated get NO EOS: the cut
        is not a real end of answer, and marking it as one would train the
        model to stop mid-reasoning.

    Raises:
        ValueError: if the answers produce any tokens and the tokenizer has
            no ``eos_token_id``.
    """
    prompts = [format_prompt(b["question"]) for b in batch]
    answers = [format_answer(b["answer"]) for b in batch]
    finals = [f"#### {b['final']}" for b in batch]

    enc_p = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_prompt_len, padding_side="left",
                      add_special_tokens=False)
    # Tokenize answers with one token of headroom. A row that fills the extra
    # slot is longer than max_answer_len, i.e. it was genuinely truncated.
    enc_y = tokenizer(answers, return_tensors="pt", padding=True, truncation=True,
                      max_length=max_answer_len + 1, padding_side="right",
                      add_special_tokens=False)
    truncated = enc_y["attention_mask"].sum(dim=1) > max_answer_len
    # Drop the headroom column. With right padding it holds only the overflow
    # token of truncated rows (or padding), so the result equals tokenizing
    # with max_length=max_answer_len.
    y_ids = enc_y["input_ids"][:, :max_answer_len]
    y_attn = enc_y["attention_mask"][:, :max_answer_len]

    if y_ids.size(1) > 0:
        eos = tokenizer.eos_token_id
        if eos is None:
            raise ValueError("collate_batch: tokenizer.eos_token_id is None; cannot terminate answers")
        B = y_ids.size(0)
        lengths = y_attn.sum(dim=1)  # real tokens per row (right-padded)
        # One extra pad column so the EOS of the longest complete answer fits.
        y_ids = torch.cat([y_ids, torch.full((B, 1), tokenizer.pad_token_id, dtype=y_ids.dtype)], dim=1)
        y_attn = torch.cat([y_attn, torch.zeros(B, 1, dtype=y_attn.dtype)], dim=1)
        # Place EOS right after the last real token of every complete answer.
        rows = (~truncated).nonzero(as_tuple=True)[0]
        cols = lengths[rows]
        y_ids[rows, cols] = eos
        y_attn[rows, cols] = 1

    return BatchTensors(
        x_ids=enc_p["input_ids"],
        x_attn=enc_p["attention_mask"],
        y_ids=y_ids,
        y_attn=y_attn,
        final_strs=finals,
        full_answer_strs=list(answers),
        y_chunk_strs=None,
    )
|
|