| """ |
| Generate CoTs for MATH-500 (training set for direction extraction). |
| Also prepares test sets: MATH-500 holdout, AIME-24, and optional GPQA-D. |
| |
| Supports --resume: skip already-generated samples by idx. |
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| from tqdm import tqdm |
|
|
| from configs.paths import ( |
| ensure_dirs, LOGS_DIR, RAW_COTS_PATH, |
| TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH, |
| ) |
| from configs.model import GEN_CONFIG |
| from src.utils import setup_logger, append_jsonl, compute_completed_ids, cleanup_memory |
| from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate |
|
|
|
|
def load_math500(n_train: int, n_holdout: int):
    """Split MATH-500 into a training prefix and a holdout suffix.

    The first ``n_train`` rows (capped at the dataset size) become the
    training set.  The last ``n_holdout`` rows become the holdout set, clipped
    so that it never starts before index ``n_train``.  Every returned row is a
    plain dict copy of the dataset row with its dataset position stored under
    ``"idx"``.

    Returns:
        A ``(train, holdout)`` tuple of lists of row dicts.
    """
    from datasets import load_dataset

    dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
    size = len(dataset)

    def tagged(position):
        # Copy the row and record which dataset position it came from.
        return dict(dataset[position], idx=position)

    train_end = min(n_train, size)
    holdout_begin = max(n_train, size - n_holdout)

    train_rows = [tagged(position) for position in range(train_end)]
    holdout_rows = [tagged(position) for position in range(holdout_begin, size)]
    return train_rows, holdout_rows
|
|
|
|
def load_aime24(n: int):
    """Load up to ``n`` AIME 2024 problems from the first usable mirror.

    The candidate Hugging Face repos are tried in order and the first one that
    loads successfully is used.  Each returned row is a plain dict with a
    string ``"idx"`` of the form ``aime24_<i>``; capitalised ``"Problem"`` /
    ``"Answer"`` columns are copied to lowercase ``"problem"`` / ``"answer"``
    when the lowercase key is missing.

    Loading is best-effort: a failing candidate is reported on stdout (same
    style as ``load_gpqa``) and the next one is tried.

    Returns:
        A list of row dicts, or ``[]`` if no candidate could be loaded.
    """
    from datasets import load_dataset

    # (repo, split) pairs, in order of preference.
    candidates = [
        ("HuggingFaceH4/aime_2024", "train"),
        ("Maxwell-Jia/AIME_2024", "train"),
    ]
    for repo, split in candidates:
        try:
            ds = load_dataset(repo, split=split)
            rows = []
            for i in range(min(n, len(ds))):
                row = dict(ds[i])
                row["idx"] = f"aime24_{i}"
                # Mirrors disagree on column capitalisation; normalise it.
                if "problem" not in row and "Problem" in row:
                    row["problem"] = row["Problem"]
                if "answer" not in row and "Answer" in row:
                    row["answer"] = row["Answer"]
                rows.append(row)
            return rows
        except Exception as e:
            # Deliberately broad: any failure just means "try the next
            # mirror", but say so instead of failing silently.
            print(f"[warn] AIME-24 load from {repo} failed: {e}")
    print("[warn] AIME-24: no candidate repo could be loaded")
    return []
|
|
|
|
def load_gpqa(n: int):
    """Load up to ``n`` problems from the GPQA diamond subset.

    Each returned row is a plain dict with a string ``"idx"`` of the form
    ``gpqa_<i>``.  Missing ``"problem"`` / ``"answer"`` keys are filled from
    the ``"Question"`` / ``"Correct Answer"`` columns (empty string if those
    are absent too).

    Returns:
        A list of row dicts, or ``[]`` (after printing a warning) if anything
        goes wrong while loading or converting the dataset.
    """
    from datasets import load_dataset

    try:
        dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
        count = min(n, len(dataset))
        problems = []
        for position in range(count):
            record = dict(dataset[position])
            record["idx"] = f"gpqa_{position}"
            # Only fill the normalised keys when the dataset lacks them.
            record.setdefault("problem", record.get("Question", ""))
            record.setdefault("answer", record.get("Correct Answer", ""))
            problems.append(record)
    except Exception as err:
        print(f"[warn] GPQA load failed: {err}")
        return []
    return problems
|
|
|
|
def generate_one(model, tokenizer, problem, max_new_tokens):
    """Generate one chain-of-thought for ``problem`` and package it as a record.

    Builds a thinking-enabled prompt from ``problem["problem"]``, runs
    generation, and returns a dict holding the problem's identifying fields,
    the prompt, the generated text, and the text's token count (encoded
    without special tokens).

    ``"answer"`` and ``"subject"`` default to ``""`` and ``"level"`` to ``-1``
    when the problem dict does not provide them.
    """
    question = problem["problem"]
    prompt = build_thinking_prompt(tokenizer, question, enable_thinking=True)
    cot = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)

    return dict(
        idx=problem["idx"],
        problem=question,
        answer=problem.get("answer", ""),
        subject=problem.get("subject", ""),
        level=problem.get("level", -1),
        prompt=prompt,
        cot=cot,
        cot_len_tokens=len(tokenizer.encode(cot, add_special_tokens=False)),
    )
|
|
|
|
def main():
    """Generate training CoTs for MATH-500 and save the test problem sets.

    Steps:
      1. Parse CLI arguments and set up output directories and the logger.
      2. Load the MATH-500 train/holdout split and, unless ``--skip_test``
         is given, the AIME-24 and GPQA-D problem sets.
      3. Generate one CoT per pending training problem, appending each record
         to ``RAW_COTS_PATH`` as soon as it is produced.  With ``--resume``,
         problems whose ``idx`` is already recorded are skipped.
      4. Unless ``--skip_test`` is given, write each non-empty test set to its
         path (an existing file is left alone under ``--resume``).

    A failure on a single training problem is logged and the run continues.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_train", type=int, default=150,
                        help="MATH-500 training CoTs (for direction extraction)")
    parser.add_argument("--n_math_test", type=int, default=50)
    parser.add_argument("--n_aime", type=int, default=30)
    parser.add_argument("--n_gpqa", type=int, default=20)
    parser.add_argument("--max_new_tokens", type=int, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--skip_test", action="store_true",
                        help="only generate training CoTs, no test CoT generation")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("02_generate", LOGS_DIR / "02_generate.log")

    # CLI value wins; an omitted (or 0) value falls back to the configured
    # default.
    mnt = args.max_new_tokens or GEN_CONFIG["max_new_tokens"]
    log.info(f"n_train={args.n_train} n_math_test={args.n_math_test} "
             f"n_aime={args.n_aime} n_gpqa={args.n_gpqa} max_new_tokens={mnt}")

    # ---- Problem sets ----
    train, math_test = load_math500(args.n_train, args.n_math_test)
    log.info(f"MATH-500: train={len(train)}, holdout_test={len(math_test)}")

    if not args.skip_test:
        aime = load_aime24(args.n_aime)
        gpqa = load_gpqa(args.n_gpqa)
        log.info(f"AIME-24: {len(aime)} problems")
        log.info(f"GPQA-D: {len(gpqa)} problems")
    else:
        aime = []
        gpqa = []

    # ---- Training CoTs ----
    done = compute_completed_ids(RAW_COTS_PATH) if args.resume else set()
    # Count against the current training slice only: the completed-id set may
    # contain ids from an earlier run with different arguments, so
    # len(train) - len(done) could be wrong or even negative.
    pending = [prob for prob in train if prob["idx"] not in done]
    log.info(f"Training set: {len(train) - len(pending)} already completed, "
             f"{len(pending)} remaining")

    if pending:
        # Only pay for the model load when there is something to generate.
        log.info("Loading model (this will take a few minutes)...")
        model, tokenizer = load_model_and_tokenizer()
        log.info("Model loaded.")

        for prob in tqdm(pending, desc="train CoTs"):
            try:
                rec = generate_one(model, tokenizer, prob, mnt)
                # Append immediately so an interrupted run can be resumed.
                append_jsonl(rec, RAW_COTS_PATH)
            except Exception as e:
                # Best-effort: log the failing sample and move on.
                log.error(f"idx={prob['idx']}: {e}")
                # NOTE(review): presumably frees cached memory so the next
                # sample can proceed (e.g. after OOM) -- confirm in src.utils.
                cleanup_memory()
    else:
        log.info("No training CoTs left to generate; skipping model load.")

    # ---- Test sets (problems only; no CoTs are generated for them here) ----
    if not args.skip_test:
        # Imported once here rather than on every loop iteration.
        from src.utils import write_jsonl

        test_pairs = [
            (math_test, TEST_MATH_PATH, "MATH-500-holdout"),
            (aime, TEST_AIME_PATH, "AIME-24"),
            (gpqa, TEST_GPQA_PATH, "GPQA-D"),
        ]
        for ds, path, name in test_pairs:
            if not ds:
                continue
            if args.resume and Path(path).exists():
                log.info(f"{name}: test set already saved at {path}")
                continue
            write_jsonl(ds, path)
            log.info(f"{name}: {len(ds)} problems saved to {path}")

    log.info("=" * 60)
    log.info(f"Done. Training CoTs saved to {RAW_COTS_PATH}")
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|