File size: 3,258 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Label training CoTs with three decision-point classes: plan / mon / exec.
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from tqdm import tqdm
from transformers import AutoTokenizer

from configs.paths import ensure_dirs, LOGS_DIR, RAW_COTS_PATH, LABELED_COTS_PATH
from configs.model import MODEL_CONFIG
from src.utils import setup_logger, read_jsonl, write_jsonl
from src.labeling import label_cot_decision_points
from src.detectors import BehaviorDetector


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resume", action="store_true")
    args = parser.parse_args()

    ensure_dirs()
    log = setup_logger("03_label", LOGS_DIR / "03_label.log")

    if args.resume and LABELED_COTS_PATH.exists():
        log.info(f"Labeled file exists: {LABELED_COTS_PATH}. Skipping.")
        return

    log.info(f"Loading tokenizer: {MODEL_CONFIG['local_dir']}")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CONFIG["local_dir"], trust_remote_code=True
    )

    log.info(f"Reading CoTs: {RAW_COTS_PATH}")
    cots = read_jsonl(RAW_COTS_PATH)
    log.info(f"Got {len(cots)} CoTs")

    mon_det = BehaviorDetector("monitoring")
    plan_det = BehaviorDetector("planning")

    total_plan = total_mon = total_exec = total_newlines = 0
    labeled = []
    for r in tqdm(cots, desc="labeling"):
        text = r["cot"]
        lab = label_cot_decision_points(text, tokenizer)

        # Sanity check with behavior detectors
        mon_cnt = mon_det.detect(text)["total"]
        plan_cnt = plan_det.detect(text)["total"]

        rec = {
            "idx": r["idx"],
            "problem": r["problem"],
            "cot": text,
            "cot_len_tokens": r.get("cot_len_tokens", len(lab["token_ids"])),
            "token_ids": lab["token_ids"],
            "plan_decision_tis": lab["plan_decision_tis"],
            "mon_decision_tis":  lab["mon_decision_tis"],
            "exec_decision_tis": lab["exec_decision_tis"],
            "all_newline_tis":   lab["all_newline_tis"],
            "n_plan_dp": lab["n_plan"],
            "n_mon_dp":  lab["n_mon"],
            "n_exec_dp": lab["n_exec"],
            "n_newlines_total": lab["n_newlines_total"],
            "detector_plan_total": plan_cnt,
            "detector_mon_total":  mon_cnt,
        }
        labeled.append(rec)
        total_plan += lab["n_plan"]
        total_mon += lab["n_mon"]
        total_exec += lab["n_exec"]
        total_newlines += lab["n_newlines_total"]

    write_jsonl(labeled, LABELED_COTS_PATH)

    log.info("=" * 60)
    log.info("LABELING SUMMARY")
    log.info(f"  N CoTs: {len(labeled)}")
    log.info(f"  Total plan decision points: {total_plan}")
    log.info(f"  Total mon  decision points: {total_mon}")
    log.info(f"  Total exec decision points: {total_exec}")
    log.info(f"  Total newlines overall: {total_newlines}")
    log.info(f"  Saved -> {LABELED_COTS_PATH}")

    if total_plan < 100:
        log.warning(f"Only {total_plan} planning decision points — probes may be weak")
    if total_mon < 100:
        log.warning(f"Only {total_mon} monitoring decision points — probes may be weak")


if __name__ == "__main__":
    main()