Spaces:
Sleeping
Sleeping
File size: 6,912 Bytes
"""Generate data/problems_pool.json from GSM8K and MATH per PROJECT.md Section 12.3.
Target: 500 problems (300 GSM8K + 200 MATH), integer-answer only, deterministic.
Usage::
python scripts/generate_problems_pool.py
Re-runs are reproducible because we seed the source-side shuffle with
``SOURCE_SEED`` (see below). Rerun whenever Section 12.3 changes or HF
dataset splits shift.
Dependencies:
pip install datasets
Answer extraction
-----------------
* GSM8K: answers end with ``#### N`` — extract the final integer.
* MATH: answers live inside ``\\boxed{...}`` — try ``int()`` on the contents.
Filter: reject anything whose extracted answer cannot be parsed as ``int``
(fractions, decimals, non-numeric, multi-part).
Difficulty heuristic (per the Phase 6 task spec — PROJECT.md Section 12 itself
does not specify one)::
easy if answer <= 100 and len(problem) < 150
hard if answer > 1000 or len(problem) > 300
medium otherwise
If HF access fails or fewer than 500 valid problems can be produced, the script
stops WITHOUT writing output and surfaces the shortfall — do NOT hand-roll a
substitute.
"""
from __future__ import annotations
import json
import random
import re
import sys
from pathlib import Path
# Seed for the source-side index shuffles, so repeated runs sample the same rows.
SOURCE_SEED: int = 20260425

# How many problems to request from each source.
GSM8K_TARGET: int = 300
MATH_TARGET: int = 200

# ``data/problems_pool.json`` under the directory one level above this
# script's own directory.
OUTPUT_PATH: Path = Path(__file__).resolve().parents[1].joinpath(
    "data", "problems_pool.json"
)
def _classify_difficulty(problem: str, answer: int) -> str:
    """Bucket a problem as ``"easy"``, ``"medium"`` or ``"hard"``.

    The "easy" rule is checked first, so a short problem with a small answer
    is never reported as hard.

    Args:
        problem: Problem statement; only its character length is used.
        answer: Integer answer to the problem.

    Returns:
        ``"easy"`` when ``answer <= 100`` and the text is under 150 characters,
        ``"hard"`` when ``answer > 1000`` or the text is over 300 characters,
        ``"medium"`` in every other case.
    """
    text_length = len(problem)
    if text_length < 150 and answer <= 100:
        return "easy"
    looks_hard = text_length > 300 or answer > 1000
    return "hard" if looks_hard else "medium"
# Whole whitespace-delimited token that follows a GSM8K ``####`` marker.
# Capturing the full token (rather than just a leading digit run) lets the
# extractor tell ``1,200`` / ``12.5`` / ``3/4`` apart from a plain integer.
_GSM8K_ANSWER_RE = re.compile(r"####\s*(\S+)")
# Brace-free contents of a ``\boxed{...}`` in a MATH solution.
_MATH_BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")


def _extract_gsm8k_answer(raw_answer: str) -> int | None:
    """Pull the trailing ``#### N`` integer out of a GSM8K answer string.

    The token after the last ``####`` marker must be a complete integer,
    optionally negative and optionally written with thousands separators
    (``1,200``). Anything else (decimals, fractions, units glued to the
    number) is rejected rather than truncated to its leading digits.

    Args:
        raw_answer: Full GSM8K answer text, rationale included.

    Returns:
        The integer answer, or ``None`` when there is no ``####`` marker or
        the token after it is not a plain integer.
    """
    tokens = _GSM8K_ANSWER_RE.findall(raw_answer)
    if not tokens:
        return None
    # Use the last marker: the final answer is the one that ends the string.
    token = tokens[-1]
    # Accept bare digits or well-formed comma grouping; ASCII digits only.
    if re.fullmatch(r"-?(?:[0-9]+|[0-9]{1,3}(?:,[0-9]{3})+)", token) is None:
        return None
    return int(token.replace(",", ""))
def _extract_math_answer(raw_solution: str) -> int | None:
    r"""Pull an integer out of the final ``\boxed{...}`` of a MATH solution.

    The last ``\boxed{`` is located and its matching closing brace is found by
    counting brace depth, so a final answer such as ``\boxed{\frac{1}{2}}`` is
    seen in full and rejected. A brace-free regex would skip it and fall back
    to an earlier, intermediate ``\boxed{...}``.

    The contents must be a plain ASCII integer with an optional sign.
    ``int()`` alone would also accept underscore-grouped text such as
    ``1011_2`` and read it as ``10112``.

    Args:
        raw_solution: Full worked solution text.

    Returns:
        The boxed integer, or ``None`` when there is no ``\boxed{``, its
        braces are unbalanced, or its contents are not a plain integer.
    """
    marker = "\\boxed{"
    marker_at = raw_solution.rfind(marker)
    if marker_at == -1:
        return None

    content_start = marker_at + len(marker)
    depth = 1  # the opening brace of ``\boxed{`` itself
    content_end = -1
    for pos in range(content_start, len(raw_solution)):
        char = raw_solution[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                content_end = pos
                break
    if content_end == -1:
        # Unbalanced braces: the final box is malformed, so do not guess.
        return None

    candidate = raw_solution[content_start:content_end].strip()
    if re.fullmatch(r"[+-]?[0-9]+", candidate) is None:
        return None
    return int(candidate)
def _collect_gsm8k(target: int) -> list[dict]:
    """Collect up to ``target`` integer-answer problems from GSM8K.

    Loads the ``train`` split of ``openai/gsm8k`` (``main`` config) and walks
    its rows in an order fixed by ``SOURCE_SEED``. Rows whose answer cannot be
    extracted as an integer are skipped.

    Args:
        target: Maximum number of problems to return.

    Returns:
        Dicts with ``problem``, ``answer`` and ``source`` keys; fewer than
        ``target`` if the split runs out of usable rows.
    """
    from datasets import load_dataset

    dataset = load_dataset("openai/gsm8k", "main", split="train")

    # Seeded permutation of row positions keeps re-runs identical.
    order = list(range(len(dataset)))
    random.Random(SOURCE_SEED).shuffle(order)

    def _usable_entries():
        # Lazily yield one entry per row that has an integer answer.
        for position in order:
            record = dataset[position]
            text = record["question"].strip()
            value = _extract_gsm8k_answer(record["answer"])
            if value is not None:
                yield {
                    "problem": text,
                    "answer": value,
                    "source": "gsm8k",
                }

    picked: list[dict] = []
    entries = _usable_entries()
    # Pull only as many entries as needed so no extra rows are read.
    while len(picked) < target:
        entry = next(entries, None)
        if entry is None:
            break
        picked.append(entry)
    return picked
def _collect_math(target: int) -> list[dict]:
    """Pull up to ``target`` integer-answer problems from the MATH algebra track.

    ``EleutherAI/hendrycks_math`` (``algebra`` config) is tried first, with
    ``hendrycks/competition_math`` as the fallback. Rows are visited in an
    order fixed by ``SOURCE_SEED + 1``.

    Only rows whose subject is exactly "algebra" (case-insensitive) are kept.
    A substring test would also admit "Prealgebra" and "Intermediate Algebra"
    when the all-subject fallback is in use, so the fallback would sample a
    different population than the primary algebra config.

    Args:
        target: Maximum number of problems to return.

    Returns:
        Dicts with ``problem``, ``answer`` and ``source`` keys; fewer than
        ``target`` if the dataset runs out of usable rows.

    Raises:
        RuntimeError: If none of the candidate datasets could be loaded.
    """
    from datasets import load_dataset

    load_errors: list[str] = []
    ds = None
    for source_name, loader in [
        ("EleutherAI/hendrycks_math[algebra]", lambda: load_dataset(
            "EleutherAI/hendrycks_math", "algebra", split="train"
        )),
        ("hendrycks/competition_math", lambda: load_dataset(
            "hendrycks/competition_math", split="train"
        )),
    ]:
        try:
            ds = loader()
            print(f" MATH source: {source_name}")
            break
        except Exception as exc:  # noqa: BLE001
            # Best-effort: remember why this source failed and try the next.
            load_errors.append(f"{source_name}: {exc}")
            continue
    if ds is None:
        raise RuntimeError(
            "Could not load any MATH dataset. Tried:\n "
            + "\n ".join(load_errors)
        )

    # Different seed offset from GSM8K so the two shuffles are independent.
    indices = list(range(len(ds)))
    random.Random(SOURCE_SEED + 1).shuffle(indices)

    collected: list[dict] = []
    for idx in indices:
        if len(collected) >= target:
            break
        row = ds[idx]
        # Some mirrors use "problem"+"solution", others "question"+"answer".
        question = (row.get("problem") or row.get("question") or "").strip()
        raw_soln = row.get("solution") or row.get("answer") or ""
        if not question or not raw_soln:
            continue
        # Rows with no subject field are assumed to come from the algebra
        # config. Otherwise require an exact match, not a substring.
        subject = (row.get("type") or row.get("subject") or "algebra")
        if subject.strip().lower() != "algebra":
            continue
        answer = _extract_math_answer(raw_soln)
        if answer is None:
            continue
        collected.append(
            {
                "problem": question,
                "answer": answer,
                "source": "math_algebra",
            }
        )
    return collected
def main() -> int:
    """Collect both sources and write the combined pool to ``OUTPUT_PATH``.

    Returns:
        ``0`` once the JSON file has been written; ``1`` if fewer than
        ``GSM8K_TARGET + MATH_TARGET`` problems were collected, in which case
        nothing is written and the shortfall is reported on stderr.
    """
    print("Collecting GSM8K problems...")
    gsm = _collect_gsm8k(GSM8K_TARGET)
    print(f" GSM8K: {len(gsm)} valid integer-answer problems")

    print("Collecting MATH algebra problems...")
    math = _collect_math(MATH_TARGET)
    print(f" MATH: {len(math)} valid integer-answer problems")

    combined_raw = [*gsm, *math]
    total_needed = GSM8K_TARGET + MATH_TARGET

    # Refuse to write a partial pool.
    have_enough = len(combined_raw) >= total_needed
    if not have_enough:
        print(
            f"ERROR: shortfall. Got {len(combined_raw)} valid problems, "
            f"need {total_needed}. Not writing output.",
            file=sys.stderr,
        )
        return 1

    def _finalise(number: int, item: dict) -> dict:
        # Output record: sequential id plus difficulty; "source" is dropped.
        text = item["problem"]
        value = int(item["answer"])
        return {
            "id": number,
            "problem": text,
            "answer": value,
            "difficulty": _classify_difficulty(text, value),
        }

    pool: list[dict] = [
        _finalise(number, item)
        for number, item in enumerate(combined_raw, start=1)
    ]

    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with OUTPUT_PATH.open("w", encoding="utf-8") as out_file:
        json.dump(pool, out_file, ensure_ascii=False, indent=2)

    print(
        f"Generated {len(pool)} problems: "
        f"{len(gsm)} from GSM8K, {len(math)} from MATH -> {OUTPUT_PATH}"
    )
    return 0
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|