""" Token-level labeling of CoT text with three decision-point classes: S_plan: token immediately before a planning trigger (in the window before trigger) S_mon: token immediately before a monitoring trigger S_exec: random \n\n-following token that is NEITHER plan nor mon The "decision point" = first token AFTER a \n\n delimiter (step boundary). We label decision points, not whole trigger spans. This gives us WRITE-direction samples. """ import random import re from typing import List, Tuple, Dict, Set from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS, flatten_patterns def find_newline_decision_points(offset_mapping: List[Tuple[int, int]], text: str) -> List[int]: """ Return token indices that are the first non-empty token AFTER a \n\n in text. These are the "decision points" for our analysis. """ decision_tis = [] for ti in range(1, len(offset_mapping)): ts_prev, te_prev = offset_mapping[ti - 1] ts, te = offset_mapping[ti] # Check if the text ending at ti-1 or between ti-1 and ti contains \n\n if ts_prev >= len(text) or te_prev > len(text): continue # Look at the text slice just before this token # Simpler: check if the previous token contains a \n\n boundary prev_text = text[ts_prev:te_prev] if "\n\n" in prev_text: decision_tis.append(ti) continue # Also check the gap between tokens (tokenizers sometimes split on newline) gap_start = te_prev gap_end = ts if gap_end > gap_start and "\n\n" in text[gap_start:gap_end]: decision_tis.append(ti) return decision_tis def find_trigger_char_positions(text: str, patterns: List[str]) -> List[int]: """Find char-start positions of all regex matches.""" starts = set() for pat in patterns: for m in re.finditer(pat, text): starts.add(m.start()) return sorted(starts) def char_to_token_index(char_pos: int, offsets: List[Tuple[int, int]]) -> int: """Map a char position to the token index containing it. Returns -1 if not found.""" for ti, (ts, te) in enumerate(offsets): if ts <= char_pos < te or (ts == te == char_pos): return ti return -1 def label_cot_decision_points( text: str, tokenizer, plan_window_before: int = 0, plan_window_after: int = 1, mon_window_before: int = 0, mon_window_after: int = 1, exec_sample_ratio: float = 1.0, # sample at most this ratio of available exec points rng_seed: int = 42, ) -> Dict: """ Returns a dict with: token_ids: List[int] offset_mapping: List[(int, int)] plan_decision_tis: List[int] # decision points BEFORE planning triggers mon_decision_tis: List[int] exec_decision_tis: List[int] # neutral decision points (newline-following, not plan/mon) all_newline_tis: List[int] # for computing 'general' direction A decision point is classified as: plan if the trigger token is within [ti, ti + plan_window_after) for any plan regex match mon if similar for mon regex exec if ti is a newline-following token AND neither plan nor mon """ enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False, truncation=False) token_ids = enc["input_ids"] offsets = enc["offset_mapping"] # Find all newline-following decision points all_newline_tis = find_newline_decision_points(offsets, text) all_newline_set: Set[int] = set(all_newline_tis) # Find plan/mon trigger token positions plan_flat = flatten_patterns(PLANNING_PATTERNS) mon_flat = flatten_patterns(MONITORING_PATTERNS) plan_char_starts = find_trigger_char_positions(text, plan_flat) mon_char_starts = find_trigger_char_positions(text, mon_flat) plan_trigger_tis = [char_to_token_index(cp, offsets) for cp in plan_char_starts] plan_trigger_tis = [t for t in plan_trigger_tis if t >= 0] mon_trigger_tis = [char_to_token_index(cp, offsets) for cp in mon_char_starts] mon_trigger_tis = [t for t in mon_trigger_tis if t >= 0] # For each newline decision point, classify plan_dp: List[int] = [] mon_dp: List[int] = [] # A decision point counts as plan if any plan trigger falls in [ti, ti + some_threshold] # We use: trigger must be within next 10 tokens after the decision point # (plan/mon triggers usually follow the newline within a few words) lookahead = 10 plan_trigger_set = set(plan_trigger_tis) mon_trigger_set = set(mon_trigger_tis) for ti in all_newline_tis: is_plan = any((ti <= tt < ti + lookahead) for tt in plan_trigger_set) is_mon = any((ti <= tt < ti + lookahead) for tt in mon_trigger_set) if is_plan and not is_mon: plan_dp.append(ti) elif is_mon and not is_plan: mon_dp.append(ti) elif is_plan and is_mon: # If both match, assign to whichever is closer plan_min = min((tt - ti) for tt in plan_trigger_set if ti <= tt < ti + lookahead) mon_min = min((tt - ti) for tt in mon_trigger_set if ti <= tt < ti + lookahead) if plan_min <= mon_min: plan_dp.append(ti) else: mon_dp.append(ti) # exec = newline DP but not plan/mon classified = set(plan_dp) | set(mon_dp) exec_candidates = [ti for ti in all_newline_tis if ti not in classified] # Down-sample exec_candidates to balance class sizes rng = random.Random(rng_seed) rng.shuffle(exec_candidates) # Keep at most: 3 × max(len(plan), len(mon)) target_exec = max(10, 3 * max(len(plan_dp), len(mon_dp))) exec_dp = exec_candidates[:target_exec] return { "token_ids": token_ids, "offset_mapping": offsets, "plan_decision_tis": plan_dp, "mon_decision_tis": mon_dp, "exec_decision_tis": exec_dp, "all_newline_tis": all_newline_tis, "n_plan": len(plan_dp), "n_mon": len(mon_dp), "n_exec": len(exec_dp), "n_newlines_total": len(all_newline_tis), }