# v2/scripts/02_generate_cots.py
"""
Generate CoTs for MATH-500 (training set for direction extraction).
Also prepares test sets: MATH-500 holdout, AIME-24, and optional GPQA-D.
Supports --resume: skip already-generated samples by idx.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, RAW_COTS_PATH,
TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH,
)
from configs.model import GEN_CONFIG
from src.utils import (
    setup_logger, append_jsonl, write_jsonl, compute_completed_ids, cleanup_memory,
)
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
def load_math500(n_train: int, n_holdout: int):
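    """Split MATH-500 into the first n_train problems (training set for direction
    extraction) and the last n_holdout problems (held-out test set); the max() below
    clips the holdout so it never overlaps the training slice."""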
from datasets import load_dataset
ds = load_dataset("HuggingFaceH4/MATH-500", split="test")
total = len(ds)
train = [dict(ds[i], idx=i) for i in range(min(n_train, total))]
holdout_start = max(n_train, total - n_holdout)
holdout = [dict(ds[i], idx=i) for i in range(holdout_start, total)]
return train, holdout
def load_aime24(n: int):
"""Load AIME 2024 problems. Try HuggingFaceH4 or fallback."""
from datasets import load_dataset
candidates = [
("HuggingFaceH4/aime_2024", None),
("Maxwell-Jia/AIME_2024", None),
]
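    # Tried in order; the first repo that loads successfully wins. The second tuple
    # element is an optional split override (None falls back to "train").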
for repo, cfg in candidates:
try:
ds = load_dataset(repo, split="train" if cfg is None else cfg)
rows = []
for i in range(min(n, len(ds))):
row = dict(ds[i])
row["idx"] = f"aime24_{i}"
# Normalize field name
if "problem" not in row and "Problem" in row:
row["problem"] = row["Problem"]
if "answer" not in row and "Answer" in row:
row["answer"] = row["Answer"]
rows.append(row)
return rows
except Exception:
continue
return []
def load_gpqa(n: int):
"""Load GPQA diamond subset."""
from datasets import load_dataset
try:
ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
rows = []
for i in range(min(n, len(ds))):
row = dict(ds[i])
row["idx"] = f"gpqa_{i}"
# GPQA fields: Question, Correct Answer
if "problem" not in row:
row["problem"] = row.get("Question", "")
if "answer" not in row:
row["answer"] = row.get("Correct Answer", "")
rows.append(row)
return rows
except Exception as e:
print(f"[warn] GPQA load failed: {e}")
return []
def generate_one(model, tokenizer, problem, max_new_tokens):
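    """Generate one thinking-mode CoT for a problem and return a JSONL-ready record
    (problem text, gold answer, prompt, raw CoT, and CoT length in tokens)."""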
prompt = build_thinking_prompt(tokenizer, problem["problem"], enable_thinking=True)
cot = generate(model, tokenizer, prompt, max_new_tokens=max_new_tokens)
return {
"idx": problem["idx"],
"problem": problem["problem"],
"answer": problem.get("answer", ""),
"subject": problem.get("subject", ""),
"level": problem.get("level", -1),
"prompt": prompt,
"cot": cot,
"cot_len_tokens": len(tokenizer.encode(cot, add_special_tokens=False)),
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--n_train", type=int, default=150,
help="MATH-500 training CoTs (for direction extraction)")
parser.add_argument("--n_math_test", type=int, default=50)
parser.add_argument("--n_aime", type=int, default=30)
parser.add_argument("--n_gpqa", type=int, default=20)
parser.add_argument("--max_new_tokens", type=int, default=None)
parser.add_argument("--resume", action="store_true")
parser.add_argument("--skip_test", action="store_true",
help="only generate training CoTs, no test CoT generation")
args = parser.parse_args()
ensure_dirs()
log = setup_logger("02_generate", LOGS_DIR / "02_generate.log")
mnt = args.max_new_tokens or GEN_CONFIG["max_new_tokens"]
log.info(f"n_train={args.n_train} n_math_test={args.n_math_test} "
f"n_aime={args.n_aime} n_gpqa={args.n_gpqa} max_new_tokens={mnt}")
    # Load the datasets before the model so any dataset failure surfaces early
train, math_test = load_math500(args.n_train, args.n_math_test)
log.info(f"MATH-500: train={len(train)}, holdout_test={len(math_test)}")
if not args.skip_test:
aime = load_aime24(args.n_aime)
gpqa = load_gpqa(args.n_gpqa)
log.info(f"AIME-24: {len(aime)} problems")
log.info(f"GPQA-D: {len(gpqa)} problems")
else:
aime = []
gpqa = []
# Load model
log.info("Loading model (this will take a few minutes)...")
model, tokenizer = load_model_and_tokenizer()
log.info("Model loaded.")
# ========== Generate training CoTs ==========
done = compute_completed_ids(RAW_COTS_PATH) if args.resume else set()
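    # With --resume, any problem whose idx already appears in RAW_COTS_PATH is skipped.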
log.info(f"Training set: {len(done)} already completed, "
f"{len(train) - len(done)} remaining")
for prob in tqdm(train, desc="train CoTs"):
if prob["idx"] in done:
continue
try:
rec = generate_one(model, tokenizer, prob, mnt)
append_jsonl(rec, RAW_COTS_PATH)
except Exception as e:
log.error(f"idx={prob['idx']}: {e}")
cleanup_memory()
    # ========== Save test problem sets ==========
    # Test CoTs are generated at downstream eval time (with steering). Here we only
    # save the raw problems for scripts 09/12 to consume.
if not args.skip_test:
test_pairs = [
(math_test, TEST_MATH_PATH, "MATH-500-holdout"),
(aime, TEST_AIME_PATH, "AIME-24"),
(gpqa, TEST_GPQA_PATH, "GPQA-D"),
]
for ds, path, name in test_pairs:
if not ds:
continue
if args.resume and Path(path).exists():
log.info(f"{name}: test set already saved at {path}")
continue
            # Just save the problem list; CoT generation happens during eval
            write_jsonl(ds, path)
log.info(f"{name}: {len(ds)} problems saved to {path}")
log.info("=" * 60)
log.info(f"Done. Training CoTs saved to {RAW_COTS_PATH}")
if __name__ == "__main__":
main()