""" Generate CoTs for MATH-500 (training set for direction extraction). Also prepares test sets: MATH-500 holdout, AIME-24, and optional GPQA-D. Supports --resume: skip already-generated samples by idx. """ import sys import argparse from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) import torch from tqdm import tqdm from configs.paths import ( ensure_dirs, LOGS_DIR, RAW_COTS_PATH, TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH, ) from configs.model import GEN_CONFIG from src.utils import setup_logger, append_jsonl, compute_completed_ids, cleanup_memory from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate def load_math500(n_train: int, n_holdout: int): from datasets import load_dataset ds = load_dataset("HuggingFaceH4/MATH-500", split="test") total = len(ds) train = [dict(ds[i], idx=i) for i in range(min(n_train, total))] holdout_start = max(n_train, total - n_holdout) holdout = [dict(ds[i], idx=i) for i in range(holdout_start, total)] return train, holdout def load_aime24(n: int): """Load AIME 2024 problems. Try HuggingFaceH4 or fallback.""" from datasets import load_dataset candidates = [ ("HuggingFaceH4/aime_2024", None), ("Maxwell-Jia/AIME_2024", None), ] for repo, cfg in candidates: try: ds = load_dataset(repo, split="train" if cfg is None else cfg) rows = [] for i in range(min(n, len(ds))): row = dict(ds[i]) row["idx"] = f"aime24_{i}" # Normalize field name if "problem" not in row and "Problem" in row: row["problem"] = row["Problem"] if "answer" not in row and "Answer" in row: row["answer"] = row["Answer"] rows.append(row) return rows except Exception: continue return [] def load_gpqa(n: int): """Load GPQA diamond subset.""" from datasets import load_dataset try: ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train") rows = [] for i in range(min(n, len(ds))): row = dict(ds[i]) row["idx"] = f"gpqa_{i}" # GPQA fields: Question, Correct Answer if "problem" not in row: row["problem"] = row.get("Question", "") if "answer" not in row: row["answer"] = row.get("Correct Answer", "") rows.append(row) return rows except Exception as e: print(f"[warn] GPQA load failed: {e}") return [] def generate_one(model, tokenizer, problem, max_new_tokens): prompt = build_thinking_prompt(tokenizer, problem["problem"], enable_thinking=True) cot = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens) return { "idx": problem["idx"], "problem": problem["problem"], "answer": problem.get("answer", ""), "subject": problem.get("subject", ""), "level": problem.get("level", -1), "prompt": prompt, "cot": cot, "cot_len_tokens": len(tokenizer.encode(cot, add_special_tokens=False)), } def main(): parser = argparse.ArgumentParser() parser.add_argument("--n_train", type=int, default=150, help="MATH-500 training CoTs (for direction extraction)") parser.add_argument("--n_math_test", type=int, default=50) parser.add_argument("--n_aime", type=int, default=30) parser.add_argument("--n_gpqa", type=int, default=20) parser.add_argument("--max_new_tokens", type=int, default=None) parser.add_argument("--resume", action="store_true") parser.add_argument("--skip_test", action="store_true", help="only generate training CoTs, no test CoT generation") args = parser.parse_args() ensure_dirs() log = setup_logger("02_generate", LOGS_DIR / "02_generate.log") mnt = args.max_new_tokens or GEN_CONFIG["max_new_tokens"] log.info(f"n_train={args.n_train} n_math_test={args.n_math_test} " f"n_aime={args.n_aime} n_gpqa={args.n_gpqa} max_new_tokens={mnt}") # Load datasets (prepare splits before loading the model — save failure time) train, math_test = load_math500(args.n_train, args.n_math_test) log.info(f"MATH-500: train={len(train)}, holdout_test={len(math_test)}") if not args.skip_test: aime = load_aime24(args.n_aime) gpqa = load_gpqa(args.n_gpqa) log.info(f"AIME-24: {len(aime)} problems") log.info(f"GPQA-D: {len(gpqa)} problems") else: aime = [] gpqa = [] # Load model log.info("Loading model (this will take a few minutes)...") model, tokenizer = load_model_and_tokenizer() log.info("Model loaded.") # ========== Generate training CoTs ========== done = compute_completed_ids(RAW_COTS_PATH) if args.resume else set() log.info(f"Training set: {len(done)} already completed, " f"{len(train) - len(done)} remaining") for prob in tqdm(train, desc="train CoTs"): if prob["idx"] in done: continue try: rec = generate_one(model, tokenizer, prob, mnt) append_jsonl(rec, RAW_COTS_PATH) except Exception as e: log.error(f"idx={prob['idx']}: {e}") cleanup_memory() # ========== Generate test CoTs (for later downstream eval baseline) ========== # Test CoTs are generated at downstream eval time (with steering). Here we only # save the raw problems for scripts 09/12 to consume. if not args.skip_test: test_pairs = [ (math_test, TEST_MATH_PATH, "MATH-500-holdout"), (aime, TEST_AIME_PATH, "AIME-24"), (gpqa, TEST_GPQA_PATH, "GPQA-D"), ] for ds, path, name in test_pairs: if not ds: continue if args.resume and Path(path).exists(): log.info(f"{name}: test set already saved at {path}") continue # Just save the problem list; CoT generation happens during eval from src.utils import write_jsonl write_jsonl(ds, path) log.info(f"{name}: {len(ds)} problems saved to {path}") log.info("=" * 60) log.info(f"Done. Training CoTs saved to {RAW_COTS_PATH}") if __name__ == "__main__": main()