File size: 6,182 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Token-level labeling of CoT text with three decision-point classes:

    S_plan: token immediately before a planning trigger (in the window before trigger)
    S_mon:  token immediately before a monitoring trigger
    S_exec: random \n\n-following token that is NEITHER plan nor mon

The "decision point" = first token AFTER a \n\n delimiter (step boundary).
We label decision points, not whole trigger spans. This gives us WRITE-direction samples.
"""
import random
import re
from typing import List, Tuple, Dict, Set
from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS, flatten_patterns


def find_newline_decision_points(offset_mapping: List[Tuple[int, int]], text: str) -> List[int]:
    """
    Identify "decision point" token indices: tokens that sit right after a
    blank-line ("\\n\\n") step boundary in `text`.

    A token index ti qualifies when either the previous token's text contains
    a "\n\n", or the character gap between token ti-1 and token ti does
    (some tokenizers exclude whitespace from their offset spans).
    """
    points: List[int] = []
    n_chars = len(text)
    for ti, (ts, _te) in enumerate(offset_mapping):
        if ti == 0:
            continue
        prev_start, prev_end = offset_mapping[ti - 1]
        # Defensive: skip tokens whose offsets run past the end of the text
        # (malformed offset mappings).
        if prev_start >= n_chars or prev_end > n_chars:
            continue
        boundary_hit = "\n\n" in text[prev_start:prev_end]
        if not boundary_hit and ts > prev_end:
            # Check the inter-token gap that the offsets skipped over.
            boundary_hit = "\n\n" in text[prev_end:ts]
        if boundary_hit:
            points.append(ti)
    return points


def find_trigger_char_positions(text: str, patterns: List[str]) -> List[int]:
    """Return the sorted, de-duplicated start offsets of every regex match in `text`."""
    seen: Set[int] = set()
    for pattern in patterns:
        seen.update(m.start() for m in re.finditer(pattern, text))
    return sorted(seen)


def char_to_token_index(char_pos: int, offsets: List[Tuple[int, int]]) -> int:
    """
    Locate the token whose character span contains `char_pos`.

    Zero-width tokens (start == end) match only when char_pos equals that
    exact position. Returns -1 when no token covers the character.
    """
    for index, (start, end) in enumerate(offsets):
        if start == end:
            if char_pos == start:
                return index
        elif start <= char_pos < end:
            return index
    return -1


def label_cot_decision_points(
    text: str,
    tokenizer,
    plan_window_before: int = 0,
    plan_window_after: int = 1,
    mon_window_before: int = 0,
    mon_window_after: int = 1,
    exec_sample_ratio: float = 1.0,   # sample at most this ratio of available exec points
    rng_seed: int = 42,
) -> Dict:
    """
    Tokenize `text` and classify its \\n\\n-following decision points as
    plan / mon / exec.

    Args:
        text: CoT text to label.
        tokenizer: HF-style tokenizer supporting `return_offsets_mapping`.
        plan_window_before, plan_window_after, mon_window_before,
        mon_window_after: accepted for API compatibility but currently
            UNUSED — classification uses a fixed 10-token lookahead below.
            TODO(review): wire these up or drop them from the signature.
        exec_sample_ratio: keep at most this fraction of the available exec
            (neutral) decision points, on top of the 3x-class-size cap.
            The default of 1.0 applies no extra cap.
        rng_seed: seed for the exec down-sampling shuffle (deterministic).

    Returns a dict with:
        token_ids: List[int]
        offset_mapping: List[(int, int)]
        plan_decision_tis: List[int]   # decision points BEFORE planning triggers
        mon_decision_tis:  List[int]
        exec_decision_tis: List[int]   # neutral decision points (newline-following, not plan/mon)
        all_newline_tis:   List[int]   # for computing 'general' direction
        n_plan / n_mon / n_exec / n_newlines_total: class counts

    A decision point is classified as:
        plan  if a plan-trigger token falls within the next `lookahead` tokens
        mon   similarly for a monitoring trigger
        exec  if it is a newline-following token AND neither plan nor mon
    Ties (both plan and mon in the window) go to whichever trigger is closer,
    with plan winning exact ties.
    """
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False, truncation=False)
    token_ids = enc["input_ids"]
    offsets = enc["offset_mapping"]

    # Find all newline-following decision points.
    all_newline_tis = find_newline_decision_points(offsets, text)
    all_newline_set: Set[int] = set(all_newline_tis)

    # Map plan/mon trigger regex matches to token positions.
    plan_flat = flatten_patterns(PLANNING_PATTERNS)
    mon_flat = flatten_patterns(MONITORING_PATTERNS)
    plan_char_starts = find_trigger_char_positions(text, plan_flat)
    mon_char_starts = find_trigger_char_positions(text, mon_flat)

    # char_to_token_index returns -1 for unmappable positions; drop those.
    plan_trigger_tis = [char_to_token_index(cp, offsets) for cp in plan_char_starts]
    plan_trigger_tis = [t for t in plan_trigger_tis if t >= 0]
    mon_trigger_tis = [char_to_token_index(cp, offsets) for cp in mon_char_starts]
    mon_trigger_tis = [t for t in mon_trigger_tis if t >= 0]

    plan_dp: List[int] = []
    mon_dp: List[int] = []

    # A decision point counts as plan if any plan trigger falls within the next
    # `lookahead` tokens (plan/mon triggers usually follow the newline within a
    # few words). NOTE: this fixed constant shadows the *_window_* parameters.
    lookahead = 10
    plan_trigger_set = set(plan_trigger_tis)
    mon_trigger_set = set(mon_trigger_tis)

    for ti in all_newline_tis:
        is_plan = any((ti <= tt < ti + lookahead) for tt in plan_trigger_set)
        is_mon = any((ti <= tt < ti + lookahead) for tt in mon_trigger_set)
        if is_plan and not is_mon:
            plan_dp.append(ti)
        elif is_mon and not is_plan:
            mon_dp.append(ti)
        elif is_plan and is_mon:
            # Both match: assign to whichever trigger is closer (plan wins ties).
            plan_min = min((tt - ti) for tt in plan_trigger_set if ti <= tt < ti + lookahead)
            mon_min = min((tt - ti) for tt in mon_trigger_set if ti <= tt < ti + lookahead)
            if plan_min <= mon_min:
                plan_dp.append(ti)
            else:
                mon_dp.append(ti)

    # exec = newline decision point that is neither plan nor mon.
    classified = set(plan_dp) | set(mon_dp)
    exec_candidates = [ti for ti in all_newline_tis if ti not in classified]

    # Down-sample exec_candidates to balance class sizes (seeded => deterministic).
    rng = random.Random(rng_seed)
    rng.shuffle(exec_candidates)
    # Keep at most: 3 x max(len(plan), len(mon)), floor of 10.
    target_exec = max(10, 3 * max(len(plan_dp), len(mon_dp)))
    # Apply exec_sample_ratio as an additional cap. With the default 1.0 this
    # is a no-op (slicing already caps at len(exec_candidates)), so existing
    # callers see identical behavior.
    if exec_sample_ratio < 1.0:
        target_exec = min(target_exec, int(exec_sample_ratio * len(exec_candidates)))
    exec_dp = exec_candidates[:target_exec]

    return {
        "token_ids": token_ids,
        "offset_mapping": offsets,
        "plan_decision_tis": plan_dp,
        "mon_decision_tis":  mon_dp,
        "exec_decision_tis": exec_dp,
        "all_newline_tis":   all_newline_tis,
        "n_plan": len(plan_dp),
        "n_mon":  len(mon_dp),
        "n_exec": len(exec_dp),
        "n_newlines_total": len(all_newline_tis),
    }