| """ |
| Label training CoTs with three decision-point classes: plan / mon / exec. |
| """ |
| import sys |
| import argparse |
| from pathlib import Path |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from tqdm import tqdm |
| from transformers import AutoTokenizer |
|
|
| from configs.paths import ensure_dirs, LOGS_DIR, RAW_COTS_PATH, LABELED_COTS_PATH |
| from configs.model import MODEL_CONFIG |
| from src.utils import setup_logger, read_jsonl, write_jsonl |
| from src.labeling import label_cot_decision_points |
| from src.detectors import BehaviorDetector |
|
|
|
|
def main():
    """Label every raw CoT with plan / mon / exec decision points and save the result.

    Steps, in order:
      1. Parse the ``--resume`` flag; when it is set and ``LABELED_COTS_PATH``
         already exists, log that and return without doing any work.
      2. Load the tokenizer from ``MODEL_CONFIG["local_dir"]`` and read the raw
         records from ``RAW_COTS_PATH``.
      3. For each record, call ``label_cot_decision_points`` on its ``"cot"``
         text and run the "monitoring" and "planning" ``BehaviorDetector``s on
         the same text.
      4. Write one output record per input record to ``LABELED_COTS_PATH`` and
         log a summary, warning when fewer than 100 planning or monitoring
         decision points were found in total.

    Each input record must provide ``"idx"``, ``"problem"`` and ``"cot"``;
    ``"cot_len_tokens"`` is optional and falls back to the number of token ids
    returned by the labeler.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--resume", action="store_true")
    opts = cli.parse_args()

    ensure_dirs()
    log = setup_logger("03_label", LOGS_DIR / "03_label.log")

    # With --resume, an existing output file is treated as finished work.
    if opts.resume and LABELED_COTS_PATH.exists():
        log.info(f"Labeled file exists: {LABELED_COTS_PATH}. Skipping.")
        return

    model_dir = MODEL_CONFIG["local_dir"]
    log.info(f"Loading tokenizer: {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

    log.info(f"Reading CoTs: {RAW_COTS_PATH}")
    cots = read_jsonl(RAW_COTS_PATH)
    log.info(f"Got {len(cots)} CoTs")

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    records = []
    for raw in tqdm(cots, desc="labeling"):
        cot_text = raw["cot"]
        labels = label_cot_decision_points(cot_text, tokenizer)

        # Detector totals are stored next to the decision-point labels.
        monitoring_hits = mon_det.detect(cot_text)["total"]
        planning_hits = plan_det.detect(cot_text)["total"]

        records.append(
            {
                "idx": raw["idx"],
                "problem": raw["problem"],
                "cot": cot_text,
                # Keep the upstream token count when the raw record has one.
                "cot_len_tokens": raw.get(
                    "cot_len_tokens", len(labels["token_ids"])
                ),
                "token_ids": labels["token_ids"],
                "plan_decision_tis": labels["plan_decision_tis"],
                "mon_decision_tis": labels["mon_decision_tis"],
                "exec_decision_tis": labels["exec_decision_tis"],
                "all_newline_tis": labels["all_newline_tis"],
                "n_plan_dp": labels["n_plan"],
                "n_mon_dp": labels["n_mon"],
                "n_exec_dp": labels["n_exec"],
                "n_newlines_total": labels["n_newlines_total"],
                "detector_plan_total": planning_hits,
                "detector_mon_total": monitoring_hits,
            }
        )

    write_jsonl(records, LABELED_COTS_PATH)

    # Corpus-wide totals are derived from the per-record counts just written.
    total_plan = sum(rec["n_plan_dp"] for rec in records)
    total_mon = sum(rec["n_mon_dp"] for rec in records)
    total_exec = sum(rec["n_exec_dp"] for rec in records)
    total_newlines = sum(rec["n_newlines_total"] for rec in records)

    log.info("=" * 60)
    log.info("LABELING SUMMARY")
    log.info(f" N CoTs: {len(records)}")
    log.info(f" Total plan decision points: {total_plan}")
    log.info(f" Total mon decision points: {total_mon}")
    log.info(f" Total exec decision points: {total_exec}")
    log.info(f" Total newlines overall: {total_newlines}")
    log.info(f" Saved -> {LABELED_COTS_PATH}")

    # Below this many labeled points the run is flagged in the log.
    min_points = 100
    if total_plan < min_points:
        log.warning(f"Only {total_plan} planning decision points — probes may be weak")
    if total_mon < min_points:
        log.warning(f"Only {total_mon} monitoring decision points — probes may be weak")
|
|
|
|
# Run the labeling step only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|