from __future__ import annotations

from contextlib import contextmanager
from dataclasses import asdict
from pathlib import Path

from .interfaces import TrainingSample
from .registry import register_data_provider


@contextmanager
def _serialized_hf_dataset_download():
    """
    Serialize Hugging Face ``datasets`` downloads/prepare across processes.

    Multi-GPU ``accelerate launch`` otherwise races on the same ``cache_dir`` and can
    leave a half-written tree (e.g. missing ``dataset_info.json``).

    The lock file lives at ``~/.cache/neuralese/hf_dataset_download.lock``; the
    directory is created if needed. When the optional ``filelock`` package is not
    importable, the guarded body runs without any lock (best effort).
    """
    root = Path.home() / ".cache" / "neuralese"
    root.mkdir(parents=True, exist_ok=True)
    lock_path = root / "hf_dataset_download.lock"

    # Only the import is guarded. Wrapping the ``yield`` in the same
    # ``except ImportError`` would catch an ImportError raised by the *caller's*
    # body, yield a second time, and make contextlib fail with
    # "generator didn't stop after throw()" instead of surfacing the real error.
    try:
        from filelock import FileLock as lock_cls
    except ImportError:
        lock_cls = None

    if lock_cls is None:
        yield
        return

    with lock_cls(str(lock_path), timeout=7200):
        yield


def _load_hf_split(
    path: str,
    split: str,
    cache_dir: str | None,
    config_name: str | None = None,
):
    """
    Load one split of a Hugging Face dataset while holding the download lock.

    The dataset is first loaded with the given ``cache_dir``. If that raises
    ``FileNotFoundError`` and a ``cache_dir`` was supplied, a second attempt is made
    with ``cache_dir=None`` (the library's default cache).

    Args:
        path: Dataset identifier passed to ``datasets.load_dataset``.
        split: Split name passed through as ``split=``.
        cache_dir: Cache directory to try first; falsy means "default cache only".
        config_name: Optional dataset configuration passed as the second positional
            argument of ``load_dataset``.

    Returns:
        Whatever ``datasets.load_dataset`` returns for the requested split.

    Raises:
        RuntimeError: If the dataset files cannot be found in any cache that was tried.
    """
    from datasets import load_dataset

    def _load(target_cache_dir: str | None):
        # Single place that decides whether a config name is forwarded.
        if config_name is not None:
            return load_dataset(
                path, config_name, split=split, cache_dir=target_cache_dir
            )
        return load_dataset(path, split=split, cache_dir=target_cache_dir)

    with _serialized_hf_dataset_download():
        try:
            return _load(cache_dir)
        except FileNotFoundError as exc:
            if not cache_dir:
                # No alternative cache to fall back to.
                raise RuntimeError(
                    "Hugging Face dataset files are missing from the default cache. "
                    "Run once with HF_DATASETS_OFFLINE=0 (or download the dataset), "
                    "or set HF_HOME / HF_DATASETS_CACHE to a populated cache."
                ) from exc
            try:
                return _load(None)
            except FileNotFoundError as exc2:
                raise RuntimeError(
                    "Could not load the dataset from the experiment cache_dir or the default HF cache. "
                    "Seed ~/.cache/huggingface/datasets (or your HF_HOME) with HF_DATASETS_OFFLINE=0, "
                    "or point storage.cache_dir at a shared cache that already contains the dataset."
                ) from exc2


# NOTE(review): the instructions below refer to "tags" but never name one, and the
# "Output format" example shows bare "your reasoning". Markup (an opening/closing
# tag pair) may have been stripped from these literals at some point -- confirm
# against the intended prompt before relying on this text.
PROMPT_PREFIX = (
    "Solve the following math problem.\n"
    "Think step-by-step inside ... tags.\n"
    "Then output only the final answer in LaTeX boxed format.\n"
    "Do not include any words or explanations outside the tags/boxed answer.\n"
    "Output format must be exactly:\n"
    "your reasoning\n"
    "\\boxed{your_final_answer}\n\n"
)


def _build_math_prompt(question: str) -> str:
    """Return the full prompt text for ``question`` using ``PROMPT_PREFIX``."""
    user_content = f"{PROMPT_PREFIX}Question: {question}"
    # Chat-style prefill so decoding starts after "assistant:".
    return f"user: {user_content}\nassistant:"


def _interleave_samples(
    left: list[TrainingSample], right: list[TrainingSample]
) -> list[TrainingSample]:
    """
    Merge two sample lists by alternating elements, ``left`` first.

    When one list is longer, its remaining elements are appended in order once the
    shorter list is exhausted.
    """
    output: list[TrainingSample] = []
    width = max(len(left), len(right))
    for idx in range(width):
        if idx < len(left):
            output.append(left[idx])
        if idx < len(right):
            output.append(right[idx])
    return output


def _slice_if_needed(
    samples: list[TrainingSample], max_samples: int | None
) -> list[TrainingSample]:
    """
    Return at most ``max_samples`` leading samples.

    ``None`` means "no limit" and returns ``samples`` itself; negative limits are
    treated as zero.
    """
    if max_samples is None:
        return samples
    return samples[: max(0, max_samples)]


class _MathProviderBase:
    """
    Shared loader for Hendrycks MATH providers that filter by difficulty level.

    Subclasses pass the accepted ``level`` strings (e.g. ``"Level 1"``) to
    ``__init__``; ``load`` walks every subject config in ``dataset_configs``.
    """

    dataset_name = "EleutherAI/hendrycks_math"
    dataset_configs = (
        "algebra",
        "counting_and_probability",
        "geometry",
        "intermediate_algebra",
        "number_theory",
        "prealgebra",
        "precalculus",
    )

    def __init__(self, levels: tuple[str, ...]):
        self.levels = levels

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """
        Build training samples from rows whose ``level`` is in ``self.levels``.

        Args:
            split: Dataset split to load for every subject config.
            max_samples: Upper bound on the number of samples returned; ``None``
                means no limit, non-positive values yield an empty list.
            cache_dir: Cache directory forwarded to the dataset loader.

        Returns:
            Samples in config order, then row order, each carrying ``dataset``,
            ``subject`` and ``level`` metadata.

        Raises:
            RuntimeError: If ``datasets`` cannot be imported or the data is missing.
        """
        try:
            import datasets  # noqa: F401
        except Exception as exc:
            raise RuntimeError(
                "datasets is required for Hendrycks MATH providers. Install dependencies first."
            ) from exc

        # The limit check below runs only *after* an append, so without this guard a
        # budget of 0 (or a negative one) would still return one sample and would
        # load a dataset config for nothing.
        if max_samples is not None and max_samples <= 0:
            return []

        level_set = {level.strip() for level in self.levels}
        output: list[TrainingSample] = []
        for config_name in self.dataset_configs:
            rows = _load_hf_split(
                self.dataset_name,
                split,
                cache_dir,
                config_name=config_name,
            )
            for row in rows:
                level = str(row.get("level", "")).strip()
                if level not in level_set:
                    continue
                question = str(row.get("problem", ""))
                target = str(row.get("solution", ""))
                output.append(
                    TrainingSample(
                        prompt=_build_math_prompt(question),
                        target=target,
                        metadata={
                            "dataset": "hendrycks_math",
                            "subject": config_name,
                            "level": level,
                        },
                    )
                )
                # Stop as soon as the budget is met so later configs are not loaded.
                if max_samples is not None and len(output) >= max_samples:
                    return output
        return output


@register_data_provider("gsm8k")
class GSM8KProvider:
    """Provider that turns GSM8K rows into prompt/target training samples."""

    def __init__(self, dataset_name: str = "openai/gsm8k", subset: str = "main"):
        self.dataset_name = dataset_name
        self.subset = subset

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """
        Load up to ``max_samples`` leading rows of ``split`` as training samples.

        Each sample's prompt is built from the row's ``question`` and its target is
        the row's ``answer``; metadata records ``dataset``, ``sample_index`` and
        ``split``.

        Raises:
            RuntimeError: If ``datasets`` cannot be imported or the data is missing.
        """
        try:
            import datasets  # noqa: F401
        except Exception as exc:
            raise RuntimeError(
                "datasets is required for GSM8K provider. Install dependencies first."
            ) from exc

        rows = _load_hf_split(
            self.dataset_name,
            split,
            cache_dir,
            config_name=self.subset,
        )
        if max_samples is not None:
            # A negative limit produces an empty range, i.e. no rows.
            rows = rows.select(range(min(max_samples, len(rows))))

        output: list[TrainingSample] = []
        for sample_index, row in enumerate(rows):
            prompt = _build_math_prompt(str(row["question"]))
            output.append(
                TrainingSample(
                    prompt=prompt,
                    target=row["answer"],
                    metadata={
                        "dataset": "gsm8k",
                        "sample_index": int(sample_index),
                        "split": str(split),
                    },
                )
            )
        return output


@register_data_provider("math_level_1")
class MathLevel1Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 1``."""

    def __init__(self):
        super().__init__(levels=("Level 1",))


@register_data_provider("math_level_2")
class MathLevel2Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 2``."""

    def __init__(self):
        super().__init__(levels=("Level 2",))


@register_data_provider("math_level_3")
class MathLevel3Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 3``."""

    def __init__(self):
        super().__init__(levels=("Level 3",))


@register_data_provider("math_level_4")
class MathLevel4Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 4``."""

    def __init__(self):
        super().__init__(levels=("Level 4",))


@register_data_provider("math_level_5")
class MathLevel5Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 5``."""

    def __init__(self):
        super().__init__(levels=("Level 5",))


@register_data_provider("math_levels_12")
class MathLevels12Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 1`` and ``Level 2``."""

    def __init__(self):
        super().__init__(levels=("Level 1", "Level 2"))


@register_data_provider("math_levels_345")
class MathLevels345Provider(_MathProviderBase):
    """Hendrycks MATH provider restricted to ``Level 3`` through ``Level 5``."""

    def __init__(self):
        super().__init__(levels=("Level 3", "Level 4", "Level 5"))


@register_data_provider("gsm8k_math_stage12")
class GSM8KMathStage12Provider:
    """Provider that interleaves GSM8K with MATH levels 1-2."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """
        Return GSM8K and MATH level 1-2 samples alternated, GSM8K first.

        Both sources are loaded in full and the interleaved list is then truncated
        to ``max_samples`` (``None`` keeps everything).
        """
        gsm = GSM8KProvider().load(split=split, max_samples=None, cache_dir=cache_dir)
        math12 = MathLevels12Provider().load(
            split=split, max_samples=None, cache_dir=cache_dir
        )
        mixed = _interleave_samples(gsm, math12)
        return _slice_if_needed(mixed, max_samples)
@register_data_provider("gsm8k_math_curriculum")
class GSM8KMathCurriculumProvider:
    """Provider that concatenates an easier mixed stage with harder MATH levels."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """
        Return stage-1/2 samples followed by MATH level 3-5 samples.

        Args:
            split: Dataset split forwarded to both underlying providers.
            max_samples: Total budget. The first stage receives the larger half
                (``ceil``), the second stage the smaller half (``floor``). ``None``
                means no limit; negative values are treated as zero.
            cache_dir: Cache directory forwarded to both underlying providers.

        Returns:
            ``stage12 + stage345``, never longer than ``max_samples`` when a limit
            is given.
        """
        if max_samples is None:
            stage12_budget = None
            stage345_budget = None
        else:
            # Clamp first so floor division of a negative total cannot hand a
            # negative budget to either stage.
            total = max(0, max_samples)
            stage12_budget = (total + 1) // 2
            stage345_budget = total // 2

        stage12 = GSM8KMathStage12Provider().load(
            split=split, max_samples=stage12_budget, cache_dir=cache_dir
        )
        stage345 = MathLevels345Provider().load(
            split=split, max_samples=stage345_budget, cache_dir=cache_dir
        )
        # Enforce the second-stage budget here as well: e.g. max_samples=1 gives
        # that stage a budget of 0, and the combined result must not exceed the
        # requested total regardless of how the provider treats such a limit.
        stage345 = _slice_if_needed(stage345, stage345_budget)

        # Curriculum order: first easier mixed set, then harder levels.
        return stage12 + stage345


def to_dataset_rows(samples: list[TrainingSample]) -> list[dict]:
    """Convert each sample to a plain ``dict`` via ``dataclasses.asdict``."""
    return [asdict(sample) for sample in samples]