| from __future__ import annotations |
|
|
| from contextlib import contextmanager |
| from dataclasses import asdict |
| from pathlib import Path |
|
|
| from .interfaces import TrainingSample |
| from .registry import register_data_provider |
|
|
|
|
| @contextmanager |
| def _serialized_hf_dataset_download(): |
| """ |
| Serialize Hugging Face ``datasets`` downloads/prepare across processes. |
| |
| Multi-GPU ``accelerate launch`` otherwise races on the same ``cache_dir`` and can |
| leave a half-written tree (e.g. missing ``dataset_info.json``). |
| """ |
| root = Path.home() / ".cache" / "neuralese" |
| root.mkdir(parents=True, exist_ok=True) |
| lock_path = root / "hf_dataset_download.lock" |
| try: |
| from filelock import FileLock |
|
|
| with FileLock(str(lock_path), timeout=7200): |
| yield |
| except ImportError: |
| yield |
|
|
|
|
def _load_hf_split(
    path: str,
    split: str,
    cache_dir: str | None,
    config_name: str | None = None,
):
    """Load one split of a Hugging Face dataset under the cross-process download lock.

    The experiment ``cache_dir`` is tried first. If that raises
    ``FileNotFoundError`` and a ``cache_dir`` was given, the default Hugging Face
    cache (``cache_dir=None``) is tried once as a fallback.

    Args:
        path: Dataset id or path passed to ``load_dataset``.
        split: Split name, e.g. ``"train"``.
        cache_dir: Experiment cache directory; falsy means the default cache.
        config_name: Optional dataset config, passed positionally after ``path``.

    Raises:
        RuntimeError: if the dataset files cannot be found in any cache that was
            tried. The underlying ``FileNotFoundError`` is chained as the cause.
    """
    from datasets import load_dataset

    # load_dataset takes the config name as its second positional argument.
    positional = (path,) if config_name is None else (path, config_name)

    def _fetch(directory):
        return load_dataset(*positional, split=split, cache_dir=directory)

    with _serialized_hf_dataset_download():
        try:
            return _fetch(cache_dir)
        except FileNotFoundError as first_error:
            if not cache_dir:
                raise RuntimeError(
                    "Hugging Face dataset files are missing from the default cache. "
                    "Run once with HF_DATASETS_OFFLINE=0 (or download the dataset), "
                    "or set HF_HOME / HF_DATASETS_CACHE to a populated cache."
                ) from first_error
            # Fall back to the default HF cache.
            try:
                return _fetch(None)
            except FileNotFoundError as second_error:
                raise RuntimeError(
                    "Could not load the dataset from the experiment cache_dir or the default HF cache. "
                    "Seed ~/.cache/huggingface/datasets (or your HF_HOME) with HF_DATASETS_OFFLINE=0, "
                    "or point storage.cache_dir at a shared cache that already contains the dataset."
                ) from second_error
|
|
|
|
# Instruction header prepended to every math question. The model is asked to
# reason inside <think> tags and then emit only a \boxed{...} answer. The
# prefix ends with a blank line so the question starts on its own paragraph.
PROMPT_PREFIX = "\n".join(
    [
        "Solve the following math problem.",
        "Think step-by-step inside <think>...</think> tags.",
        "Then output only the final answer in LaTeX boxed format.",
        "Do not include any words or explanations outside the tags/boxed answer.",
        "Output format must be exactly:",
        "<think>your reasoning</think>",
        "\\boxed{your_final_answer}",
    ]
) + "\n\n"
|
|
|
|
def _build_math_prompt(question: str) -> str:
    """Wrap *question* in the plain-text chat template used by every provider.

    Layout: ``"user: "``, then ``PROMPT_PREFIX``, then ``"Question: <question>"``,
    followed by a newline and the ``"assistant:"`` cue that opens the model's turn.
    """
    return "".join(("user: ", PROMPT_PREFIX, f"Question: {question}", "\nassistant:"))
|
|
|
|
| def _interleave_samples( |
| left: list[TrainingSample], right: list[TrainingSample] |
| ) -> list[TrainingSample]: |
| output: list[TrainingSample] = [] |
| width = max(len(left), len(right)) |
| for idx in range(width): |
| if idx < len(left): |
| output.append(left[idx]) |
| if idx < len(right): |
| output.append(right[idx]) |
| return output |
|
|
|
|
| def _slice_if_needed( |
| samples: list[TrainingSample], max_samples: int | None |
| ) -> list[TrainingSample]: |
| if max_samples is None: |
| return samples |
| return samples[: max(0, max_samples)] |
|
|
|
|
class _MathProviderBase:
    """Shared loader for Hendrycks MATH providers filtered by difficulty level.

    Every subject config in ``dataset_configs`` is scanned in order. Only rows
    whose ``level`` field (whitespace-stripped) is in ``levels`` are kept.
    """

    dataset_name = "EleutherAI/hendrycks_math"
    dataset_configs = (
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    )

    def __init__(self, levels: tuple[str, ...]):
        # Level labels to keep, e.g. ("Level 1",); compared after stripping.
        self.levels = levels

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Load MATH problems of the configured levels as training samples.

        Args:
            split: Split name passed to every subject config.
            max_samples: Upper bound on returned samples. ``None`` means no limit;
                zero or a negative value returns an empty list.
            cache_dir: Experiment cache directory forwarded to the HF loader.

        Returns:
            Samples in subject order (as listed in ``dataset_configs``), then row
            order. Each carries ``dataset``, ``subject`` and ``level`` metadata.
            Loading stops as soon as the limit is reached, so later subject
            configs are not loaded.

        Raises:
            RuntimeError: if the ``datasets`` package cannot be imported, or if the
                shared loader cannot find the dataset files.
        """
        try:
            import datasets  # noqa: F401  (availability check only)
        except ImportError as exc:
            raise RuntimeError(
                "datasets is required for Hendrycks MATH providers. Install dependencies first."
            ) from exc

        # The limit is otherwise only checked after an append. Without this
        # guard a non-positive budget would still return one sample, unlike
        # the GSM8K provider and _slice_if_needed.
        if max_samples is not None and max_samples <= 0:
            return []

        level_set = {level.strip() for level in self.levels}

        output: list[TrainingSample] = []
        for config_name in self.dataset_configs:
            rows = _load_hf_split(
                self.dataset_name,
                split,
                cache_dir,
                config_name=config_name,
            )
            for row in rows:
                level = str(row.get("level", "")).strip()
                if level not in level_set:
                    continue
                question = str(row.get("problem", ""))
                target = str(row.get("solution", ""))
                output.append(
                    TrainingSample(
                        prompt=_build_math_prompt(question),
                        target=target,
                        metadata={
                            "dataset": "hendrycks_math",
                            "subject": config_name,
                            "level": level,
                        },
                    )
                )
                if max_samples is not None and len(output) >= max_samples:
                    return output
        return output
|
|
|
|
@register_data_provider("gsm8k")
class GSM8KProvider:
    """Grade-school math word problems from GSM8K, wrapped in the math prompt."""

    def __init__(self, dataset_name: str = "openai/gsm8k", subset: str = "main"):
        # Dataset id and config name handed to the shared HF loader.
        self.dataset_name = dataset_name
        self.subset = subset

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Load *split* of GSM8K as training samples.

        Args:
            split: Split name, e.g. ``"train"``.
            max_samples: Keep only the first N rows. ``None`` means all rows;
                zero or a negative value yields an empty list.
            cache_dir: Experiment cache directory forwarded to the HF loader.

        Returns:
            One sample per row. The prompt is built from ``question`` and the
            target is the ``answer`` text. Metadata records the row index within
            the returned selection and the split.

        Raises:
            RuntimeError: if the ``datasets`` package cannot be imported, or if the
                shared loader cannot find the dataset files.
        """
        try:
            import datasets  # noqa: F401  (availability check only)
        except ImportError as exc:
            raise RuntimeError(
                "datasets is required for GSM8K provider. Install dependencies first."
            ) from exc

        rows = _load_hf_split(
            self.dataset_name,
            split,
            cache_dir,
            config_name=self.subset,
        )
        if max_samples is not None:
            # Clamp explicitly so negative limits clearly mean "no rows".
            rows = rows.select(range(min(max(0, max_samples), len(rows))))

        return [
            TrainingSample(
                prompt=_build_math_prompt(str(row["question"])),
                # str() matches the MATH provider, which always stores str targets.
                target=str(row["answer"]),
                metadata={
                    "dataset": "gsm8k",
                    "sample_index": int(sample_index),
                    "split": str(split),
                },
            )
            for sample_index, row in enumerate(rows)
        ]
|
|
|
|
@register_data_provider("math_level_1")
class MathLevel1Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 1``."""

    def __init__(self):
        super().__init__(("Level 1",))
|
|
|
|
@register_data_provider("math_level_2")
class MathLevel2Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 2``."""

    def __init__(self):
        super().__init__(("Level 2",))
|
|
|
|
@register_data_provider("math_level_3")
class MathLevel3Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 3``."""

    def __init__(self):
        super().__init__(("Level 3",))
|
|
|
|
@register_data_provider("math_level_4")
class MathLevel4Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 4``."""

    def __init__(self):
        super().__init__(("Level 4",))
|
|
|
|
@register_data_provider("math_level_5")
class MathLevel5Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 5``."""

    def __init__(self):
        super().__init__(("Level 5",))
|
|
|
|
@register_data_provider("math_levels_12")
class MathLevels12Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 1`` or ``Level 2``."""

    def __init__(self):
        super().__init__(("Level 1", "Level 2"))
|
|
|
|
@register_data_provider("math_levels_345")
class MathLevels345Provider(_MathProviderBase):
    """Hendrycks MATH problems tagged ``Level 3``, ``Level 4`` or ``Level 5``."""

    def __init__(self):
        super().__init__(("Level 3", "Level 4", "Level 5"))
|
|
|
|
@register_data_provider("gsm8k_math_stage12")
class GSM8KMathStage12Provider:
    """GSM8K interleaved with MATH levels 1-2 (gsm, math, gsm, math, ...)."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return GSM8K and MATH level 1-2 samples alternated, GSM8K first.

        Args:
            split: Split name forwarded to both sources.
            max_samples: Upper bound on returned samples. ``None`` means no limit;
                non-positive values yield an empty list.
            cache_dir: Experiment cache directory forwarded to both sources.
        """
        # Output position p of the interleaving comes from source index <= p.
        # So the first N results never need more than N items from either
        # source. Capping each source at N keeps the result unchanged while
        # avoiding building (and, for MATH, loading) samples that are discarded.
        per_source = None if max_samples is None else max(0, max_samples)
        gsm = GSM8KProvider().load(
            split=split, max_samples=per_source, cache_dir=cache_dir
        )
        math12 = MathLevels12Provider().load(
            split=split, max_samples=per_source, cache_dir=cache_dir
        )
        mixed = _interleave_samples(gsm, math12)
        # Final slice enforces the exact limit, including when one source is short.
        return _slice_if_needed(mixed, max_samples)
|
|
|
|
@register_data_provider("gsm8k_math_curriculum")
class GSM8KMathCurriculumProvider:
    """Ordered curriculum: GSM8K + MATH levels 1-2 first, then MATH levels 3-5."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return the stage-1/2 mix followed by MATH levels 3-5.

        With a ``max_samples`` budget, the first stage gets ``ceil(budget / 2)``
        and the second stage ``floor(budget / 2)``. Negative budgets are treated
        as zero. A stage that yields fewer samples than its share does not pass
        the remainder to the other stage.

        Args:
            split: Split name forwarded to both stages.
            max_samples: Total budget, or ``None`` for everything.
            cache_dir: Experiment cache directory forwarded to both stages.
        """
        if max_samples is None:
            stage12_budget = None
            stage345_budget = None
        else:
            budget = max(0, max_samples)
            stage12_budget = (budget + 1) // 2
            stage345_budget = budget // 2

        def _load_stage(provider, stage_budget):
            # A zero share needs no data at all. Slicing afterwards guarantees
            # the stage never exceeds its share, whatever the stage provider
            # does with its own max_samples argument.
            if stage_budget == 0:
                return []
            samples = provider.load(
                split=split, max_samples=stage_budget, cache_dir=cache_dir
            )
            return _slice_if_needed(samples, stage_budget)

        stage12 = _load_stage(GSM8KMathStage12Provider(), stage12_budget)
        stage345 = _load_stage(MathLevels345Provider(), stage345_budget)
        return stage12 + stage345
|
|
|
|
def to_dataset_rows(samples: list[TrainingSample]) -> list[dict]:
    """Convert each dataclass sample into a plain dict (recursively, via ``asdict``)."""
    return list(map(asdict, samples))
|
|