"""
Generate CoTs for MATH-500 (training set for direction extraction).
Also prepares test sets: MATH-500 holdout, AIME-24, and optional GPQA-D.
Supports --resume: skip already-generated samples by idx.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import torch
from tqdm import tqdm
from configs.paths import (
ensure_dirs, LOGS_DIR, RAW_COTS_PATH,
TEST_MATH_PATH, TEST_AIME_PATH, TEST_GPQA_PATH,
)
from configs.model import GEN_CONFIG
from src.utils import setup_logger, append_jsonl, compute_completed_ids, cleanup_memory
from src.model_io import load_model_and_tokenizer, build_thinking_prompt, generate
def load_math500(n_train: int, n_holdout: int):
    """Split MATH-500 into a leading training slice and a trailing holdout slice.

    The first ``n_train`` rows (capped at the dataset size) become the training
    set. The holdout is taken from the tail of the dataset, but its start is
    never earlier than ``n_train``, so the two slices cannot share a row; the
    holdout may therefore contain fewer than ``n_holdout`` rows.

    Each row is copied into a plain dict and tagged with its dataset position
    under the ``"idx"`` key.

    Returns:
        A ``(train, holdout)`` tuple of lists of dicts.
    """
    from datasets import load_dataset

    dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
    size = len(dataset)

    def tagged(position):
        # Copy the row and record where it came from.
        return dict(dataset[position], idx=position)

    train_end = min(n_train, size)
    holdout_begin = max(n_train, size - n_holdout)

    train_rows = [tagged(position) for position in range(train_end)]
    holdout_rows = [tagged(position) for position in range(holdout_begin, size)]
    return train_rows, holdout_rows
def load_aime24(n: int):
    """Load up to ``n`` AIME 2024 problems from the first candidate repo that works.

    Candidate repositories are tried in order. For the first one that loads,
    each row is copied into a dict, tagged with ``idx = "aime24_<i>"``, and
    given lowercase ``"problem"`` / ``"answer"`` keys when only the capitalized
    ``"Problem"`` / ``"Answer"`` variants are present.

    A candidate that raises for any reason is reported with a ``[warn]`` line
    (same style as ``load_gpqa``) and the next candidate is tried, so a failure
    is never silent.

    Returns:
        A list of row dicts, or an empty list if every candidate failed.
    """
    from datasets import load_dataset

    # (repo id, split override). ``None`` means the default "train" split;
    # any other value is passed to ``load_dataset`` as the split name.
    candidates = [
        ("HuggingFaceH4/aime_2024", None),
        ("Maxwell-Jia/AIME_2024", None),
    ]
    for repo, cfg in candidates:
        try:
            ds = load_dataset(repo, split="train" if cfg is None else cfg)
            rows = []
            for i in range(min(n, len(ds))):
                row = dict(ds[i])
                row["idx"] = f"aime24_{i}"
                # Normalize field names across the candidate repos.
                if "problem" not in row and "Problem" in row:
                    row["problem"] = row["Problem"]
                if "answer" not in row and "Answer" in row:
                    row["answer"] = row["Answer"]
                rows.append(row)
            return rows
        except Exception as e:
            # Best-effort fallback: report the failure, then try the next repo.
            print(f"[warn] AIME-24 load from {repo} failed: {e}")
            continue
    return []
def load_gpqa(n: int):
    """Load up to ``n`` rows of the GPQA diamond subset.

    Each row is copied into a dict and tagged with ``idx = "gpqa_<i>"``. When a
    row has no ``"problem"`` / ``"answer"`` key, it is filled from the
    ``"Question"`` / ``"Correct Answer"`` fields (empty string if those are
    missing too).

    Returns:
        A list of row dicts, or an empty list (after printing a warning) if
        loading or row preparation raises.
    """
    from datasets import load_dataset

    try:
        dataset = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train")
        limit = min(n, len(dataset))
        records = []
        for position in range(limit):
            record = dict(dataset[position])
            record["idx"] = f"gpqa_{position}"
            # GPQA fields: Question, Correct Answer
            record.setdefault("problem", record.get("Question", ""))
            record.setdefault("answer", record.get("Correct Answer", ""))
            records.append(record)
    except Exception as e:
        print(f"[warn] GPQA load failed: {e}")
        return []
    return records
def generate_one(model, tokenizer, problem, max_new_tokens):
    """Generate one chain-of-thought for ``problem`` and package it as a record.

    The prompt is built from ``problem["problem"]`` with thinking enabled, then
    passed to ``generate`` with the given ``max_new_tokens``.

    Returns:
        A dict with the problem's ``idx``, ``problem``, ``answer``, ``subject``
        and ``level`` (defaulting to ``""``, ``""`` and ``-1`` when absent),
        the ``prompt``, the generated ``cot``, and ``cot_len_tokens``: the
        length of the tokenizer's encoding of the CoT without special tokens.
    """
    prompt_text = build_thinking_prompt(
        tokenizer, problem["problem"], enable_thinking=True
    )
    completion = generate(
        model, tokenizer, prompt_text, max_new_tokens=max_new_tokens
    )
    return dict(
        idx=problem["idx"],
        problem=problem["problem"],
        answer=problem.get("answer", ""),
        subject=problem.get("subject", ""),
        level=problem.get("level", -1),
        prompt=prompt_text,
        cot=completion,
        cot_len_tokens=len(tokenizer.encode(completion, add_special_tokens=False)),
    )
def main():
    """Command-line entry point: generate training CoTs and save test problem sets.

    Steps, in order:
      1. Parse CLI arguments and set up directories and the logger.
      2. Load the MATH-500 train/holdout split and, unless ``--skip_test`` is
         given, the AIME-24 and GPQA-D problem lists. This happens before the
         model is loaded so dataset failures surface early.
      3. Load the model and tokenizer.
      4. Generate one CoT per training problem and append it to
         ``RAW_COTS_PATH``. With ``--resume``, problems whose ``idx`` is
         already recorded there are skipped. A failure on one problem is
         logged and does not stop the run.
      5. Unless ``--skip_test`` is given, write each non-empty test problem
         list to its path. With ``--resume``, an existing file is left alone.
         No test CoTs are generated here.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_train", type=int, default=150,
                        help="MATH-500 training CoTs (for direction extraction)")
    parser.add_argument("--n_math_test", type=int, default=50)
    parser.add_argument("--n_aime", type=int, default=30)
    parser.add_argument("--n_gpqa", type=int, default=20)
    parser.add_argument("--max_new_tokens", type=int, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--skip_test", action="store_true",
                        help="only generate training CoTs, no test CoT generation")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("02_generate", LOGS_DIR / "02_generate.log")
    # A missing (or zero) CLI value falls back to the configured default.
    mnt = args.max_new_tokens or GEN_CONFIG["max_new_tokens"]
    log.info(f"n_train={args.n_train} n_math_test={args.n_math_test} "
             f"n_aime={args.n_aime} n_gpqa={args.n_gpqa} max_new_tokens={mnt}")

    # Load datasets (prepare splits before loading the model — save failure time)
    train, math_test = load_math500(args.n_train, args.n_math_test)
    log.info(f"MATH-500: train={len(train)}, holdout_test={len(math_test)}")
    if not args.skip_test:
        aime = load_aime24(args.n_aime)
        gpqa = load_gpqa(args.n_gpqa)
        log.info(f"AIME-24: {len(aime)} problems")
        log.info(f"GPQA-D: {len(gpqa)} problems")
    else:
        aime = []
        gpqa = []

    # Load model
    log.info("Loading model (this will take a few minutes)...")
    model, tokenizer = load_model_and_tokenizer()
    log.info("Model loaded.")

    # ========== Generate training CoTs ==========
    done = compute_completed_ids(RAW_COTS_PATH) if args.resume else set()
    # Count against the actual training slice: the resume file may hold ids
    # that are not part of this run's slice, so len(train) - len(done) could
    # be wrong or even negative.
    pending = [prob for prob in train if prob["idx"] not in done]
    log.info(f"Training set: {len(train) - len(pending)} already completed, "
             f"{len(pending)} remaining")
    for prob in tqdm(pending, desc="train CoTs"):
        try:
            rec = generate_one(model, tokenizer, prob, mnt)
            append_jsonl(rec, RAW_COTS_PATH)
        except Exception as e:
            # Keep going: one bad sample must not abort a long generation run.
            log.error(f"idx={prob['idx']}: {e}")
        # NOTE(review): the reviewed source had lost its indentation, so the
        # original nesting of this call is unknown. Cleaning up after every
        # sample covers both the failure path and the end of the loop — confirm.
        cleanup_memory()

    # ========== Save test problem sets ==========
    # Test CoTs are generated at downstream eval time (with steering). Here we only
    # save the raw problems for scripts 09/12 to consume.
    if not args.skip_test:
        # Imported once here rather than on every loop iteration.
        from src.utils import write_jsonl

        test_pairs = [
            (math_test, TEST_MATH_PATH, "MATH-500-holdout"),
            (aime, TEST_AIME_PATH, "AIME-24"),
            (gpqa, TEST_GPQA_PATH, "GPQA-D"),
        ]
        for ds, path, name in test_pairs:
            if not ds:
                continue
            if args.resume and Path(path).exists():
                log.info(f"{name}: test set already saved at {path}")
                continue
            # Just save the problem list; CoT generation happens during eval
            write_jsonl(ds, path)
            log.info(f"{name}: {len(ds)} problems saved to {path}")

    log.info("=" * 60)
    log.info(f"Done. Training CoTs saved to {RAW_COTS_PATH}")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|