# File: neuralese_temp/src/hackable/data_plugins.py
# Committed by: psidharth567
# Commit message: Export neuralese codebase (cache and .env excluded).
# Commit: dbc69f3
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import asdict
from pathlib import Path
from .interfaces import TrainingSample
from .registry import register_data_provider
@contextmanager
def _serialized_hf_dataset_download():
"""
Serialize Hugging Face ``datasets`` downloads/prepare across processes.
Multi-GPU ``accelerate launch`` otherwise races on the same ``cache_dir`` and can
leave a half-written tree (e.g. missing ``dataset_info.json``).
"""
root = Path.home() / ".cache" / "neuralese"
root.mkdir(parents=True, exist_ok=True)
lock_path = root / "hf_dataset_download.lock"
try:
from filelock import FileLock
with FileLock(str(lock_path), timeout=7200):
yield
except ImportError:
yield
def _load_hf_split(
    path: str,
    split: str,
    cache_dir: str | None,
    config_name: str | None = None,
):
    """Load one split of a Hugging Face dataset, with a default-cache fallback.

    The whole attempt runs under ``_serialized_hf_dataset_download``. The split
    is first loaded with ``cache_dir``. If that raises ``FileNotFoundError`` and a
    non-empty ``cache_dir`` was supplied, the load is retried once with
    ``cache_dir=None``. ``config_name``, when given, is passed to ``load_dataset``
    as the second positional argument.

    Raises:
        RuntimeError: wrapping the ``FileNotFoundError`` when no explicit cache
            was given, or when the default-cache retry also fails.
    """
    from datasets import load_dataset

    dataset_args = (path,) if config_name is None else (path, config_name)

    with _serialized_hf_dataset_download():
        try:
            return load_dataset(*dataset_args, split=split, cache_dir=cache_dir)
        except FileNotFoundError as primary_error:
            if not cache_dir:
                raise RuntimeError(
                    "Hugging Face dataset files are missing from the default cache. "
                    "Run once with HF_DATASETS_OFFLINE=0 (or download the dataset), "
                    "or set HF_HOME / HF_DATASETS_CACHE to a populated cache."
                ) from primary_error
            # The retry stays inside the handler so the exception chaining is
            # unchanged.
            try:
                return load_dataset(*dataset_args, split=split, cache_dir=None)
            except FileNotFoundError as fallback_error:
                raise RuntimeError(
                    "Could not load the dataset from the experiment cache_dir or the default HF cache. "
                    "Seed ~/.cache/huggingface/datasets (or your HF_HOME) with HF_DATASETS_OFFLINE=0, "
                    "or point storage.cache_dir at a shared cache that already contains the dataset."
                ) from fallback_error
# Instruction header shared by every math prompt. ``_build_math_prompt`` places
# "Question: <problem>" directly after it, so the trailing blank line
# separates the instructions from the question.
PROMPT_PREFIX = (
    "Solve the following math problem.\n"
    "Think step-by-step inside <think>...</think> tags.\n"
    "Then output only the final answer in LaTeX boxed format.\n"
    "Do not include any words or explanations outside the tags/boxed answer.\n"
    "Output format must be exactly:\n"
    "<think>your reasoning</think>\n"
    "\\boxed{your_final_answer}\n\n"
)
def _build_math_prompt(question: str) -> str:
    """Wrap *question* in the shared math instructions as a chat-style prompt.

    The returned text is a single ``user:`` turn (``PROMPT_PREFIX`` followed by
    ``Question: <question>``), then a newline and ``assistant:``. Decoding
    therefore starts right after the assistant tag.
    """
    user_turn = f"{PROMPT_PREFIX}Question: {question}"
    return "user: " + user_turn + "\nassistant:"
def _interleave_samples(
left: list[TrainingSample], right: list[TrainingSample]
) -> list[TrainingSample]:
output: list[TrainingSample] = []
width = max(len(left), len(right))
for idx in range(width):
if idx < len(left):
output.append(left[idx])
if idx < len(right):
output.append(right[idx])
return output
def _slice_if_needed(
samples: list[TrainingSample], max_samples: int | None
) -> list[TrainingSample]:
if max_samples is None:
return samples
return samples[: max(0, max_samples)]
class _MathProviderBase:
dataset_name = "EleutherAI/hendrycks_math"
dataset_configs = (
"algebra",
"counting_and_probability",
"geometry",
"intermediate_algebra",
"number_theory",
"prealgebra",
"precalculus",
)
def __init__(self, levels: tuple[str, ...]):
self.levels = levels
def load(
self,
split: str,
max_samples: int | None = None,
cache_dir: str | None = None,
) -> list[TrainingSample]:
try:
import datasets # noqa: F401
except Exception as exc:
raise RuntimeError(
"datasets is required for Hendrycks MATH providers. Install dependencies first."
) from exc
level_set = {level.strip() for level in self.levels}
output: list[TrainingSample] = []
for config_name in self.dataset_configs:
rows = _load_hf_split(
self.dataset_name,
split,
cache_dir,
config_name=config_name,
)
for row in rows:
level = str(row.get("level", "")).strip()
if level not in level_set:
continue
question = str(row.get("problem", ""))
target = str(row.get("solution", ""))
output.append(
TrainingSample(
prompt=_build_math_prompt(question),
target=target,
metadata={
"dataset": "hendrycks_math",
"subject": config_name,
"level": level,
},
)
)
if max_samples is not None and len(output) >= max_samples:
return output
return output
@register_data_provider("gsm8k")
class GSM8KProvider:
    """Training samples built from GSM8K rows.

    Each row's ``question`` is wrapped with ``_build_math_prompt``, and its
    ``answer`` value is used as the target unchanged.
    """

    def __init__(self, dataset_name: str = "openai/gsm8k", subset: str = "main"):
        # Dataset id and config name handed to ``_load_hf_split``.
        self.dataset_name = dataset_name
        self.subset = subset

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Load ``split`` as a list of ``TrainingSample``.

        Args:
            split: Split name forwarded to ``_load_hf_split``; also recorded
                (as ``str``) in each sample's metadata.
            max_samples: Keep only the first ``max_samples`` rows; ``None``
                keeps all rows.
            cache_dir: Optional cache directory forwarded to ``_load_hf_split``.

        Raises:
            RuntimeError: if the ``datasets`` package cannot be imported.
        """
        try:
            import datasets  # noqa: F401
        except Exception as exc:
            raise RuntimeError(
                "datasets is required for GSM8K provider. Install dependencies first."
            ) from exc

        dataset = _load_hf_split(
            self.dataset_name, split, cache_dir, config_name=self.subset
        )
        if max_samples is not None:
            keep = min(max_samples, len(dataset))
            dataset = dataset.select(range(keep))

        split_label = str(split)
        return [
            TrainingSample(
                prompt=_build_math_prompt(str(row["question"])),
                target=row["answer"],
                metadata={
                    "dataset": "gsm8k",
                    "sample_index": int(position),
                    "split": split_label,
                },
            )
            for position, row in enumerate(dataset)
        ]
@register_data_provider("math_level_1")
class MathLevel1Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 1``."""

    def __init__(self):
        selected = ("Level 1",)
        super().__init__(selected)
@register_data_provider("math_level_2")
class MathLevel2Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 2``."""

    def __init__(self):
        selected = ("Level 2",)
        super().__init__(selected)
@register_data_provider("math_level_3")
class MathLevel3Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 3``."""

    def __init__(self):
        selected = ("Level 3",)
        super().__init__(selected)
@register_data_provider("math_level_4")
class MathLevel4Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 4``."""

    def __init__(self):
        selected = ("Level 4",)
        super().__init__(selected)
@register_data_provider("math_level_5")
class MathLevel5Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 5``."""

    def __init__(self):
        selected = ("Level 5",)
        super().__init__(selected)
@register_data_provider("math_levels_12")
class MathLevels12Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 1`` or ``Level 2``."""

    def __init__(self):
        selected = ("Level 1", "Level 2")
        super().__init__(selected)
@register_data_provider("math_levels_345")
class MathLevels345Provider(_MathProviderBase):
    """Hendrycks MATH problems labelled ``Level 3``, ``Level 4`` or ``Level 5``."""

    def __init__(self):
        selected = ("Level 3", "Level 4", "Level 5")
        super().__init__(selected)
@register_data_provider("gsm8k_math_stage12")
class GSM8KMathStage12Provider:
    """GSM8K mixed with Hendrycks MATH Levels 1-2, alternating one from each."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return GSM8K and MATH Levels 1-2 samples interleaved, GSM8K first.

        Args:
            split: Split name forwarded to both source providers.
            max_samples: Optional cap on the mixed result (negative acts as 0).
            cache_dir: Optional cache directory forwarded to both providers.

        Returns:
            The interleaved list, truncated to ``max_samples`` when given.
        """
        # The first k interleaved items only ever come from the first k items
        # of either source. Capping each loader at the (clamped) budget
        # therefore gives exactly the same result as interleaving everything
        # and slicing, without materializing both full datasets.
        per_source = None if max_samples is None else max(0, max_samples)
        gsm = GSM8KProvider().load(
            split=split, max_samples=per_source, cache_dir=cache_dir
        )
        math12 = MathLevels12Provider().load(
            split=split, max_samples=per_source, cache_dir=cache_dir
        )
        mixed = _interleave_samples(gsm, math12)
        # Final slice enforces the cap even if a source returns extra rows.
        return _slice_if_needed(mixed, max_samples)
@register_data_provider("gsm8k_math_curriculum")
class GSM8KMathCurriculumProvider:
    """Curriculum: the GSM8K + MATH Levels 1-2 mix, then MATH Levels 3-5."""

    def load(
        self,
        split: str,
        max_samples: int | None = None,
        cache_dir: str | None = None,
    ) -> list[TrainingSample]:
        """Return the easier stage followed by the harder stage.

        Args:
            split: Split name forwarded to both stages.
            max_samples: Optional total cap (negative acts as 0). The easier
                stage gets ``ceil(max_samples / 2)`` and the harder stage
                ``floor(max_samples / 2)``.
            cache_dir: Optional cache directory forwarded to both stages.

        Returns:
            Stage-1/2 samples followed by Levels 3-5 samples. The result never
            holds more than ``max_samples`` items when a cap is given.
        """
        if max_samples is None:
            stage12_budget = None
            stage345_budget = None
        else:
            total = max(0, max_samples)
            stage12_budget = (total + 1) // 2
            stage345_budget = total // 2

        # Do not rely on the sub-loaders to honour a zero budget: a stage with
        # no budget is skipped outright. Each stage's result is also trimmed to
        # its own budget, so the combined list cannot exceed ``max_samples``.
        stage12: list[TrainingSample] = []
        if stage12_budget != 0:
            stage12 = _slice_if_needed(
                GSM8KMathStage12Provider().load(
                    split=split, max_samples=stage12_budget, cache_dir=cache_dir
                ),
                stage12_budget,
            )
        stage345: list[TrainingSample] = []
        if stage345_budget != 0:
            stage345 = _slice_if_needed(
                MathLevels345Provider().load(
                    split=split, max_samples=stage345_budget, cache_dir=cache_dir
                ),
                stage345_budget,
            )
        # Curriculum order: first easier mixed set, then harder levels.
        return stage12 + stage345
def to_dataset_rows(samples: list[TrainingSample]) -> list[dict]:
    """Convert each dataclass sample to a plain dict via ``dataclasses.asdict``."""
    return list(map(asdict, samples))