| """ |
| Token-level labeling of CoT text with three decision-point classes: |
| |
| S_plan: token immediately before a planning trigger (in the window before trigger) |
| S_mon: token immediately before a monitoring trigger |
| S_exec: random \n\n-following token that is NEITHER plan nor mon |
| |
| The "decision point" = first token AFTER a \n\n delimiter (step boundary). |
| We label decision points, not whole trigger spans. This gives us WRITE-direction samples. |
| """ |
| import random |
| import re |
| from typing import List, Tuple, Dict, Set |
| from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS, flatten_patterns |
|
|
|
|
def find_newline_decision_points(offset_mapping: List[Tuple[int, int]], text: str) -> List[int]:
    """
    Return token indices that are the first token AFTER a "\\n\\n" step delimiter.

    Token ti qualifies when the contiguous character span from the start of
    token ti-1 up to the start of token ti contains "\\n\\n".  That single
    span covers the previous token's text AND the inter-token gap, so a
    delimiter split across the text/gap boundary is caught as well (the
    previous implementation checked the two sub-spans separately and missed
    that straddling case).

    Args:
        offset_mapping: per-token (char_start, char_end) spans into ``text``.
        text: the raw text the offsets index into.

    Returns:
        Ascending list of decision-point token indices.
    """
    decision_tis = []
    for ti in range(1, len(offset_mapping)):
        ts_prev, te_prev = offset_mapping[ti - 1]
        ts, _te = offset_mapping[ti]

        # Skip tokens whose offsets fall outside the text (e.g. special
        # tokens with placeholder offsets).
        if ts_prev >= len(text) or te_prev > len(text):
            continue

        # One contiguous span = previous token text + inter-token gap.
        # max(ts, te_prev) guards against overlapping offsets so the
        # previous token's full text is always included.
        if "\n\n" in text[ts_prev:max(ts, te_prev)]:
            decision_tis.append(ti)
    return decision_tis
|
|
|
|
def find_trigger_char_positions(text: str, patterns: List[str]) -> List[int]:
    """Return the sorted, de-duplicated char-start positions of every regex match."""
    hits: Set[int] = set()
    for pattern in patterns:
        hits.update(m.start() for m in re.finditer(pattern, text))
    return sorted(hits)
|
|
|
|
def char_to_token_index(char_pos: int, offsets: List[Tuple[int, int]]) -> int:
    """
    Locate the token whose (start, end) span contains ``char_pos``.

    A zero-width token (start == end) matches when char_pos equals that
    boundary.  Returns the token index, or -1 when no span contains it.
    """
    candidates = (
        idx
        for idx, (start, end) in enumerate(offsets)
        if start <= char_pos < end or start == end == char_pos
    )
    return next(candidates, -1)
|
|
|
|
def label_cot_decision_points(
    text: str,
    tokenizer,
    plan_window_before: int = 0,
    plan_window_after: int = 1,
    mon_window_before: int = 0,
    mon_window_after: int = 1,
    exec_sample_ratio: float = 1.0,
    rng_seed: int = 42,
    lookahead: int = 10,
) -> Dict:
    """
    Classify "\\n\\n"-following decision-point tokens as plan / mon / exec.

    A decision point ti (the first token after a "\\n\\n" step boundary) is:
      plan  if a planning-trigger token falls within [ti, ti + lookahead)
      mon   if a monitoring-trigger token falls within the same window
      exec  if ti follows a "\\n\\n" but is neither plan nor mon
    When both trigger kinds appear in the window, the nearer trigger wins;
    an exact distance tie goes to plan.

    Args:
        text: raw CoT text to tokenize and label.
        tokenizer: HF-style tokenizer supporting ``return_offsets_mapping``.
        plan_window_before, plan_window_after, mon_window_before,
            mon_window_after: NOTE(review) — kept for interface
            compatibility but currently UNUSED; classification relies on
            ``lookahead`` instead.  TODO: confirm intent and either wire
            these in or deprecate them.
        exec_sample_ratio: scales the number of sampled exec points; the
            default 1.0 reproduces the historical target
            max(10, 3 * max(n_plan, n_mon)).
        rng_seed: seed for the deterministic exec-point subsampling shuffle.
        lookahead: token window after a decision point searched for trigger
            tokens (previously a hard-coded 10; 10 remains the default).

    Returns a dict with:
        token_ids: List[int]
        offset_mapping: List[(int, int)]
        plan_decision_tis / mon_decision_tis / exec_decision_tis: List[int]
        all_newline_tis: List[int]   # every newline-following token
        n_plan / n_mon / n_exec / n_newlines_total: ints
    """
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False, truncation=False)
    token_ids = enc["input_ids"]
    offsets = enc["offset_mapping"]

    # Step boundaries: tokens immediately following a "\n\n".
    all_newline_tis = find_newline_decision_points(offsets, text)

    # Map every regex-trigger char position to its containing token,
    # dropping positions that land in no token (char_to_token_index -> -1).
    plan_trigger_set = {
        t
        for t in (
            char_to_token_index(cp, offsets)
            for cp in find_trigger_char_positions(text, flatten_patterns(PLANNING_PATTERNS))
        )
        if t >= 0
    }
    mon_trigger_set = {
        t
        for t in (
            char_to_token_index(cp, offsets)
            for cp in find_trigger_char_positions(text, flatten_patterns(MONITORING_PATTERNS))
        )
        if t >= 0
    }

    plan_dp: List[int] = []
    mon_dp: List[int] = []
    for ti in all_newline_tis:
        # Distances from the decision point to each trigger in its window.
        plan_dists = [tt - ti for tt in plan_trigger_set if ti <= tt < ti + lookahead]
        mon_dists = [tt - ti for tt in mon_trigger_set if ti <= tt < ti + lookahead]
        if plan_dists and not mon_dists:
            plan_dp.append(ti)
        elif mon_dists and not plan_dists:
            mon_dp.append(ti)
        elif plan_dists and mon_dists:
            # Both kinds present: nearest trigger decides; plan wins ties.
            if min(plan_dists) <= min(mon_dists):
                plan_dp.append(ti)
            else:
                mon_dp.append(ti)

    # Neutral (exec) candidates: newline-following, classified as neither.
    classified = set(plan_dp) | set(mon_dp)
    exec_candidates = [ti for ti in all_newline_tis if ti not in classified]

    # Deterministic subsample of exec points.  exec_sample_ratio was
    # previously ignored; it now scales the trigger-count-based target,
    # and at the default 1.0 the result is byte-identical to the old
    # max(10, 3 * max(n_plan, n_mon)).
    rng = random.Random(rng_seed)
    rng.shuffle(exec_candidates)
    target_exec = max(10, int(round(exec_sample_ratio * 3 * max(len(plan_dp), len(mon_dp)))))
    exec_dp = exec_candidates[:target_exec]

    return {
        "token_ids": token_ids,
        "offset_mapping": offsets,
        "plan_decision_tis": plan_dp,
        "mon_decision_tis": mon_dp,
        "exec_decision_tis": exec_dp,
        "all_newline_tis": all_newline_tis,
        "n_plan": len(plan_dp),
        "n_mon": len(mon_dp),
        "n_exec": len(exec_dp),
        "n_newlines_total": len(all_newline_tis),
    }
|
|