"""
Label training CoTs with three decision-point classes: plan / mon / exec.

Reads the raw CoT records, labels each one with
``label_cot_decision_points``, attaches the behavior-detector totals, and
writes the enriched records to ``LABELED_COTS_PATH``.
"""
import sys
import argparse
from pathlib import Path

# Prepend the directory two levels above this file to sys.path; the
# `configs` and `src` packages imported below are presumably located there.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from tqdm import tqdm
from transformers import AutoTokenizer

from configs.paths import ensure_dirs, LOGS_DIR, RAW_COTS_PATH, LABELED_COTS_PATH
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, write_jsonl
from src.labeling import label_cot_decision_points
from src.detectors import BehaviorDetector


def _assemble_record(raw, lab, plan_cnt, mon_cnt):
    """Build one output row from a raw CoT record and its labeling result.

    Args:
        raw: Input record; must provide ``idx``, ``problem`` and ``cot``.
            ``cot_len_tokens`` is optional and falls back to the number of
            token ids in ``lab``.
        lab: Mapping returned by ``label_cot_decision_points`` for this CoT.
        plan_cnt: ``"total"`` reported by the planning detector.
        mon_cnt: ``"total"`` reported by the monitoring detector.

    Returns:
        A dict holding the fields written to the labeled JSONL file.
    """
    return {
        "idx": raw["idx"],
        "problem": raw["problem"],
        "cot": raw["cot"],
        "cot_len_tokens": raw.get("cot_len_tokens", len(lab["token_ids"])),
        "token_ids": lab["token_ids"],
        "plan_decision_tis": lab["plan_decision_tis"],
        "mon_decision_tis": lab["mon_decision_tis"],
        "exec_decision_tis": lab["exec_decision_tis"],
        "all_newline_tis": lab["all_newline_tis"],
        "n_plan_dp": lab["n_plan"],
        "n_mon_dp": lab["n_mon"],
        "n_exec_dp": lab["n_exec"],
        "n_newlines_total": lab["n_newlines_total"],
        "detector_plan_total": plan_cnt,
        "detector_mon_total": mon_cnt,
    }


def main():
    """Label every raw CoT and write the results, logging a summary.

    With ``--resume``, the step is skipped entirely when the labeled output
    file already exists. A warning is logged when fewer than 100 planning or
    monitoring decision points were found in total.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--resume", action="store_true")
    opts = cli.parse_args()

    ensure_dirs()
    log = setup_logger("03_label", LOGS_DIR / "03_label.log")

    # Resume mode: an existing output file means this step is already done.
    already_done = opts.resume and LABELED_COTS_PATH.exists()
    if already_done:
        log.info(f"Labeled file exists: {LABELED_COTS_PATH}. Skipping.")
        return

    model_dir = MODEL_CONFIG["local_dir"]
    log.info(f"Loading tokenizer: {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

    log.info(f"Reading CoTs: {RAW_COTS_PATH}")
    cots = read_jsonl(RAW_COTS_PATH)
    log.info(f"Got {len(cots)} CoTs")

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    # Running sums, keyed by the corresponding field of the labeling result.
    totals = dict.fromkeys(("n_plan", "n_mon", "n_exec", "n_newlines_total"), 0)
    labeled = []

    for raw in tqdm(cots, desc="labeling"):
        cot_text = raw["cot"]
        lab = label_cot_decision_points(cot_text, tokenizer)

        # Sanity check with behavior detectors: their totals are stored
        # next to the decision-point counts in each output record.
        mon_cnt = mon_det.detect(cot_text)["total"]
        plan_cnt = plan_det.detect(cot_text)["total"]

        labeled.append(_assemble_record(raw, lab, plan_cnt, mon_cnt))

        for key in totals:
            totals[key] += lab[key]

    write_jsonl(labeled, LABELED_COTS_PATH)

    total_plan = totals["n_plan"]
    total_mon = totals["n_mon"]
    total_exec = totals["n_exec"]
    total_newlines = totals["n_newlines_total"]

    summary_lines = (
        "=" * 60,
        "LABELING SUMMARY",
        f" N CoTs: {len(labeled)}",
        f" Total plan decision points: {total_plan}",
        f" Total mon decision points: {total_mon}",
        f" Total exec decision points: {total_exec}",
        f" Total newlines overall: {total_newlines}",
        f" Saved -> {LABELED_COTS_PATH}",
    )
    for line in summary_lines:
        log.info(line)

    # Flag classes with fewer than 100 decision points in total.
    for count, kind in ((total_plan, "planning"), (total_mon, "monitoring")):
        if count < 100:
            log.warning(f"Only {count} {kind} decision points — probes may be weak")


if __name__ == "__main__":
    main()