# v2/scripts/03_label_cots.py
# Uploaded by JulianHJR via huggingface_hub ("Upload folder using huggingface_hub"),
# commit e53f10b (verified).
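"""
Label training CoTs with three decision-point classes: plan / mon / exec.

Reads raw CoT records from RAW_COTS_PATH and labels the decision points in each
CoT with label_cot_decision_points, using the tokenizer at
MODEL_CONFIG["local_dir"]. It also attaches planning/monitoring
behavior-detector totals as a sanity check. The labeled records are written to
LABELED_COTS_PATH. Pass --resume to skip the step when that file already exists.
"""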
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from tqdm import tqdm
from transformers import AutoTokenizer
from configs.paths import ensure_dirs, LOGS_DIR, RAW_COTS_PATH, LABELED_COTS_PATH
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, write_jsonl
from src.labeling import label_cot_decision_points
from src.detectors import BehaviorDetector
# Below this many labeled planning / monitoring decision points the summary
# warns that probes trained on them may be weak.
_MIN_DECISION_POINTS = 100


def _label_record(r, tokenizer, plan_det, mon_det):
    """Label one raw CoT record and return the output record.

    ``r`` must provide ``"idx"``, ``"problem"`` and ``"cot"``. The optional
    ``"cot_len_tokens"`` falls back to the number of token ids the labeler
    produced.
    """
    text = r["cot"]
    lab = label_cot_decision_points(text, tokenizer)
    # Independent counts from the behavior detectors, stored next to the
    # decision-point labels as a sanity check.
    mon_cnt = mon_det.detect(text)["total"]
    plan_cnt = plan_det.detect(text)["total"]
    return {
        "idx": r["idx"],
        "problem": r["problem"],
        "cot": text,
        "cot_len_tokens": r.get("cot_len_tokens", len(lab["token_ids"])),
        "token_ids": lab["token_ids"],
        "plan_decision_tis": lab["plan_decision_tis"],
        "mon_decision_tis": lab["mon_decision_tis"],
        "exec_decision_tis": lab["exec_decision_tis"],
        "all_newline_tis": lab["all_newline_tis"],
        "n_plan_dp": lab["n_plan"],
        "n_mon_dp": lab["n_mon"],
        "n_exec_dp": lab["n_exec"],
        "n_newlines_total": lab["n_newlines_total"],
        "detector_plan_total": plan_cnt,
        "detector_mon_total": mon_cnt,
    }


def _log_summary(log, labeled):
    """Log aggregate decision-point counts and warn when a class is sparse."""
    total_plan = sum(rec["n_plan_dp"] for rec in labeled)
    total_mon = sum(rec["n_mon_dp"] for rec in labeled)
    total_exec = sum(rec["n_exec_dp"] for rec in labeled)
    total_newlines = sum(rec["n_newlines_total"] for rec in labeled)

    log.info("=" * 60)
    log.info("LABELING SUMMARY")
    log.info(f" N CoTs: {len(labeled)}")
    log.info(f" Total plan decision points: {total_plan}")
    log.info(f" Total mon decision points: {total_mon}")
    log.info(f" Total exec decision points: {total_exec}")
    log.info(f" Total newlines overall: {total_newlines}")
    log.info(f" Saved -> {LABELED_COTS_PATH}")
    if total_plan < _MIN_DECISION_POINTS:
        log.warning(f"Only {total_plan} planning decision points — probes may be weak")
    if total_mon < _MIN_DECISION_POINTS:
        log.warning(f"Only {total_mon} monitoring decision points — probes may be weak")


def main():
    """Label every raw CoT with plan / mon / exec decision points.

    Reads records from ``RAW_COTS_PATH`` and labels each one with
    ``label_cot_decision_points``, using the tokenizer at
    ``MODEL_CONFIG["local_dir"]``. It attaches planning/monitoring
    behavior-detector totals and writes the records to ``LABELED_COTS_PATH``.
    Finally it logs a summary.

    With ``--resume``, the step is skipped when ``LABELED_COTS_PATH`` already
    exists. The output is written atomically, so an existing file is always
    complete.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("03_label", LOGS_DIR / "03_label.log")

    if args.resume and LABELED_COTS_PATH.exists():
        log.info(f"Labeled file exists: {LABELED_COTS_PATH}. Skipping.")
        return

    log.info(f"Loading tokenizer: {MODEL_CONFIG['local_dir']}")
    # NOTE: trust_remote_code lets transformers execute Python shipped with the
    # model directory; only point MODEL_CONFIG["local_dir"] at trusted weights.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CONFIG["local_dir"], trust_remote_code=True
    )

    log.info(f"Reading CoTs: {RAW_COTS_PATH}")
    cots = read_jsonl(RAW_COTS_PATH)
    log.info(f"Got {len(cots)} CoTs")
    if not cots:
        # An empty output would be treated as "done" by a later --resume run.
        log.warning(f"No CoTs found in {RAW_COTS_PATH}; the labeled file will be empty")

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    labeled = [
        _label_record(r, tokenizer, plan_det, mon_det)
        for r in tqdm(cots, desc="labeling")
    ]

    # --resume only checks that the output exists, so never leave a truncated
    # file at the final path. Write a sibling temp file (keeping the original
    # suffix) and rename it into place.
    tmp_path = LABELED_COTS_PATH.with_name(
        f"{LABELED_COTS_PATH.stem}.tmp{LABELED_COTS_PATH.suffix}"
    )
    write_jsonl(labeled, tmp_path)
    tmp_path.replace(LABELED_COTS_PATH)

    _log_summary(log, labeled)
# Run the labeling step only when executed as a script, not on import.
if __name__ == "__main__":
    main()