Spaces:
Sleeping
Sleeping
Phase 6: Problems pool (500 GSM8K-style problems) and sampling API per PROJECT.md Section 12
0738d13 | """Generate data/problems_pool.json from GSM8K and MATH per PROJECT.md Section 12.3. | |
| Target: 500 problems (300 GSM8K + 200 MATH), integer-answer only, deterministic. | |
| Usage:: | |
| python scripts/generate_problems_pool.py | |
| Re-runs are reproducible because we seed the source-side shuffle with | |
| ``SOURCE_SEED`` (see below). Rerun whenever Section 12.3 changes or HF | |
| dataset splits shift. | |
| Dependencies: | |
| pip install datasets | |
| Answer extraction | |
| ----------------- | |
| * GSM8K: answers end with ``#### N`` — extract the final integer. | |
| * MATH: answers live inside ``\\boxed{...}`` — try ``int()`` on the contents. | |
| Filter: reject anything whose extracted answer cannot be parsed as ``int`` | |
| (fractions, decimals, non-numeric, multi-part). | |
| Difficulty heuristic (per the Phase 6 task spec — PROJECT.md Section 12 itself | |
| does not specify one):: | |
| easy if answer <= 100 and len(problem) < 150 | |
| hard if answer > 1000 or len(problem) > 300 | |
| medium otherwise | |
| If HF access fails or fewer than 500 valid problems can be produced, the script stops | |
| WITHOUT writing output and surfaces the shortfall — do NOT hand-roll a | |
| substitute. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import random | |
| import re | |
| import sys | |
| from pathlib import Path | |
# Fixed seed for the source-side index shuffles, so re-runs select the same rows.
SOURCE_SEED = 20_260_425
# Per-source quotas: 300 GSM8K + 200 MATH = 500 problems in total.
GSM8K_TARGET, MATH_TARGET = 300, 200
# data/problems_pool.json one directory above this script's folder
# (presumably the repo root, given the ``scripts/`` layout in the usage note).
OUTPUT_PATH = Path(__file__).resolve().parents[1].joinpath("data", "problems_pool.json")
| def _classify_difficulty(problem: str, answer: int) -> str: | |
| if answer <= 100 and len(problem) < 150: | |
| return "easy" | |
| if answer > 1000 or len(problem) > 300: | |
| return "hard" | |
| return "medium" | |
| _GSM8K_ANSWER_RE = re.compile(r"####\s*(-?\d+)") | |
| _MATH_BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}") | |
| def _extract_gsm8k_answer(raw_answer: str) -> int | None: | |
| """Pull the trailing ``#### N`` integer out of a GSM8K answer string.""" | |
| m = _GSM8K_ANSWER_RE.search(raw_answer) | |
| if m is None: | |
| return None | |
| try: | |
| return int(m.group(1)) | |
| except (TypeError, ValueError): | |
| return None | |
| def _extract_math_answer(raw_solution: str) -> int | None: | |
| """Pull an integer out of the final ``\\boxed{...}`` of a MATH solution.""" | |
| matches = _MATH_BOXED_RE.findall(raw_solution) | |
| if not matches: | |
| return None | |
| candidate = matches[-1].strip() | |
| try: | |
| return int(candidate) | |
| except (TypeError, ValueError): | |
| return None | |
def _collect_gsm8k(target: int) -> list[dict]:
    """Return up to ``target`` GSM8K problems whose answer parses as ``int``.

    Reads the ``openai/gsm8k`` dataset (``main`` config, ``train`` split) and
    visits rows in an order shuffled with ``SOURCE_SEED``, so re-runs pick
    the same rows. Rows without an integer ``#### N`` answer are skipped.
    Each result is a dict with ``problem``, ``answer`` and
    ``source="gsm8k"``. Fewer than ``target`` items are returned if the
    split runs out first.
    """
    from datasets import load_dataset

    dataset = load_dataset("openai/gsm8k", "main", split="train")

    # Shuffle indices (not rows) with a fixed seed for reproducibility.
    order = list(range(len(dataset)))
    shuffler = random.Random(SOURCE_SEED)
    shuffler.shuffle(order)

    picked: list[dict] = []
    for row_index in order:
        if len(picked) >= target:
            break
        record = dataset[row_index]
        text = record["question"].strip()
        value = _extract_gsm8k_answer(record["answer"])
        if value is not None:
            picked.append({"problem": text, "answer": value, "source": "gsm8k"})
    return picked
def _collect_math(target: int) -> list[dict]:
    """Pull up to ``target`` integer-answer problems from the MATH algebra track.

    We try ``EleutherAI/hendrycks_math`` first (the 2024+ maintained mirror
    with per-subject configs); ``hendrycks/competition_math`` was deprecated
    and is only a fallback.

    Rows are visited in an order shuffled with ``SOURCE_SEED + 1``, so
    re-runs pick the same rows. A row is kept only when all of these hold:

    * it has non-empty problem and solution text;
    * its subject is exactly ``algebra`` (case-insensitive; rows with no
      subject field are assumed to be algebra);
    * its boxed answer parses as ``int``.

    Each result is a dict with ``problem``, ``answer`` and
    ``source="math_algebra"``. Fewer than ``target`` items may be returned.

    Raises:
        RuntimeError: if none of the candidate datasets can be loaded. The
            message lists each attempt's error.
    """
    from datasets import load_dataset

    load_errors: list[str] = []
    ds = None
    for source_name, loader in [
        ("EleutherAI/hendrycks_math[algebra]", lambda: load_dataset(
            "EleutherAI/hendrycks_math", "algebra", split="train"
        )),
        ("hendrycks/competition_math", lambda: load_dataset(
            "hendrycks/competition_math", split="train"
        )),
    ]:
        try:
            ds = loader()
            print(f" MATH source: {source_name}")
            break
        except Exception as exc:  # noqa: BLE001 - any failure: try the next mirror
            load_errors.append(f"{source_name}: {exc}")
            continue
    if ds is None:
        raise RuntimeError(
            "Could not load any MATH dataset. Tried:\n "
            + "\n ".join(load_errors)
        )

    # Seed offset from the GSM8K shuffle's SOURCE_SEED.
    indices = list(range(len(ds)))
    random.Random(SOURCE_SEED + 1).shuffle(indices)

    collected: list[dict] = []
    for idx in indices:
        if len(collected) >= target:
            break
        row = ds[idx]
        # Some mirrors use "problem"+"solution", others "question"+"answer".
        question = (row.get("problem") or row.get("question") or "").strip()
        raw_soln = row.get("solution") or row.get("answer") or ""
        if not question or not raw_soln:
            continue
        # The fallback mirror mixes subjects, so filter on "type"/"subject".
        # Use an exact match: a substring test would also admit subjects such
        # as "Prealgebra" and "Intermediate Algebra".
        subject = (row.get("type") or row.get("subject") or "algebra").strip().lower()
        if subject != "algebra":
            continue
        answer = _extract_math_answer(raw_soln)
        if answer is None:
            continue
        collected.append(
            {
                "problem": question,
                "answer": answer,
                "source": "math_algebra",
            }
        )
    return collected
def main() -> int:
    """Build the problem pool and write it to ``OUTPUT_PATH`` as JSON.

    Entries get sequential ``id`` values starting at 1 (GSM8K first, then
    MATH), plus ``problem``, ``answer`` and ``difficulty``. The per-row
    ``source`` tag is dropped.

    Returns:
        Process exit code. Returns 0 after writing the pool. Returns 1
        (nothing written) when fewer than ``GSM8K_TARGET + MATH_TARGET`` valid
        problems were collected. Dataset-loading failures propagate as
        exceptions, so no output is written in that case either.
    """
    print("Collecting GSM8K problems...")
    gsm = _collect_gsm8k(GSM8K_TARGET)
    print(f" GSM8K: {len(gsm)} valid integer-answer problems")
    print("Collecting MATH algebra problems...")
    # Deliberately not named ``math``: that would shadow the stdlib module name.
    math_problems = _collect_math(MATH_TARGET)
    print(f" MATH: {len(math_problems)} valid integer-answer problems")

    total_needed = GSM8K_TARGET + MATH_TARGET
    combined_raw = gsm + math_problems
    # Shortfall check — STOP before writing if we're under target. Each
    # collector stops at its own target, so reaching total_needed means both
    # sources met their quotas.
    if len(combined_raw) < total_needed:
        print(
            f"ERROR: shortfall. Got {len(combined_raw)} valid problems, "
            f"need {total_needed}. Not writing output.",
            file=sys.stderr,
        )
        return 1

    # Assign sequential ids from 1; attach difficulty; drop source.
    pool: list[dict] = []
    for i, entry in enumerate(combined_raw, start=1):
        problem = entry["problem"]
        answer = int(entry["answer"])
        pool.append(
            {
                "id": i,
                "problem": problem,
                "answer": answer,
                "difficulty": _classify_difficulty(problem, answer),
            }
        )

    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    with OUTPUT_PATH.open("w", encoding="utf-8") as fh:
        json.dump(pool, fh, ensure_ascii=False, indent=2)
    print(
        f"Generated {len(pool)} problems: "
        f"{len(gsm)} from GSM8K, {len(math_problems)} from MATH -> {OUTPUT_PATH}"
    )
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    status = main()
    raise SystemExit(status)