"""
Token-level labeling of CoT text with three decision-point classes:
S_plan: token immediately before a planning trigger (in the window before trigger)
S_mon: token immediately before a monitoring trigger
S_exec: random \n\n-following token that is NEITHER plan nor mon
The "decision point" = first token AFTER a \n\n delimiter (step boundary).
We label decision points, not whole trigger spans. This gives us WRITE-direction samples.
"""
import random
import re
from typing import List, Tuple, Dict, Set
from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS, flatten_patterns
def find_newline_decision_points(offset_mapping: List[Tuple[int, int]], text: str) -> List[int]:
    """
    Return token indices that are the first non-empty token AFTER a \n\n in text.
    These are the "decision points" for our analysis.
    """
    n_chars = len(text)
    points: List[int] = []
    pairs = zip(offset_mapping, offset_mapping[1:])
    for idx, ((prev_start, prev_end), (cur_start, _cur_end)) in enumerate(pairs, start=1):
        # Defensive guard: skip offsets that point past the end of the text.
        if prev_start >= n_chars or prev_end > n_chars:
            continue
        # Case 1: the previous token's own text spans a \n\n boundary.
        if "\n\n" in text[prev_start:prev_end]:
            points.append(idx)
        # Case 2: the \n\n sits in the gap between consecutive tokens
        # (tokenizers sometimes drop whitespace from offsets entirely).
        elif cur_start > prev_end and "\n\n" in text[prev_end:cur_start]:
            points.append(idx)
    return points
def find_trigger_char_positions(text: str, patterns: List[str]) -> List[int]:
    """Find char-start positions of all regex matches."""
    positions: Set[int] = set()
    for pattern in patterns:
        # Dedupe across patterns: two regexes may match at the same offset.
        positions.update(m.start() for m in re.finditer(pattern, text))
    return sorted(positions)
def char_to_token_index(char_pos: int, offsets: List[Tuple[int, int]]) -> int:
    """Map a char position to the token index containing it. Returns -1 if not found."""
    for index, (start, end) in enumerate(offsets):
        # A zero-width token (start == end) still "contains" the position
        # equal to both endpoints.
        inside = start <= char_pos < end
        if inside or (start == end == char_pos):
            return index
    return -1
def label_cot_decision_points(
    text: str,
    tokenizer,
    plan_window_before: int = 0,
    plan_window_after: int = 1,
    mon_window_before: int = 0,
    mon_window_after: int = 1,
    exec_sample_ratio: float = 1.0,  # sample at most this ratio of available exec points
    rng_seed: int = 42,
) -> Dict:
    """
    Tokenize `text` and classify its newline-following decision points.

    Returns a dict with:
        token_ids: List[int]
        offset_mapping: List[(int, int)]
        plan_decision_tis: List[int]   # decision points BEFORE planning triggers
        mon_decision_tis: List[int]
        exec_decision_tis: List[int]   # neutral decision points (newline-following, not plan/mon)
        all_newline_tis: List[int]     # for computing 'general' direction
        n_plan, n_mon, n_exec, n_newlines_total: class counts
    A decision point is classified as:
        plan if a plan-trigger token starts within `lookahead` tokens after it
        mon  if similar for the mon regexes
        exec if ti is a newline-following token AND neither plan nor mon
    If a point matches both plan and mon, it is assigned to the class whose
    trigger is closer (ties go to plan).

    NOTE(review): plan_window_before/after and mon_window_before/after are
    accepted for interface compatibility but are NOT currently applied — the
    classification window is the hard-coded `lookahead` below. Confirm intended
    semantics before wiring them in, since changing the effective window would
    alter existing labelings.
    """
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False, truncation=False)
    token_ids = enc["input_ids"]
    offsets = enc["offset_mapping"]

    # All tokens that immediately follow a "\n\n" step boundary.
    all_newline_tis = find_newline_decision_points(offsets, text)

    # Token indices at which plan / mon trigger regexes start (drop matches
    # whose char position maps to no token).
    plan_flat = flatten_patterns(PLANNING_PATTERNS)
    mon_flat = flatten_patterns(MONITORING_PATTERNS)
    plan_trigger_set: Set[int] = {
        t
        for t in (char_to_token_index(cp, offsets) for cp in find_trigger_char_positions(text, plan_flat))
        if t >= 0
    }
    mon_trigger_set: Set[int] = {
        t
        for t in (char_to_token_index(cp, offsets) for cp in find_trigger_char_positions(text, mon_flat))
        if t >= 0
    }

    # A decision point counts as plan/mon if a trigger starts within the next
    # `lookahead` tokens (plan/mon triggers usually follow the newline within
    # a few words).
    lookahead = 10
    plan_dp: List[int] = []
    mon_dp: List[int] = []
    for ti in all_newline_tis:
        # Distances (in tokens) from the decision point to each in-window trigger.
        plan_hits = [tt - ti for tt in plan_trigger_set if ti <= tt < ti + lookahead]
        mon_hits = [tt - ti for tt in mon_trigger_set if ti <= tt < ti + lookahead]
        if plan_hits and not mon_hits:
            plan_dp.append(ti)
        elif mon_hits and not plan_hits:
            mon_dp.append(ti)
        elif plan_hits and mon_hits:
            # Both classes matched: assign to the closer trigger (ties -> plan).
            if min(plan_hits) <= min(mon_hits):
                plan_dp.append(ti)
            else:
                mon_dp.append(ti)

    # exec = newline decision point that is neither plan nor mon.
    classified = set(plan_dp) | set(mon_dp)
    exec_candidates = [ti for ti in all_newline_tis if ti not in classified]

    # Down-sample exec candidates to balance class sizes: deterministic shuffle,
    # then keep at most 3 x the larger labelled class (floor of 10), further
    # capped by exec_sample_ratio.
    # Fix: exec_sample_ratio was previously accepted but ignored; the default
    # of 1.0 makes ratio_cap == len(exec_candidates), reproducing the old
    # behavior exactly.
    rng = random.Random(rng_seed)
    rng.shuffle(exec_candidates)
    target_exec = max(10, 3 * max(len(plan_dp), len(mon_dp)))
    ratio_cap = int(len(exec_candidates) * exec_sample_ratio)
    exec_dp = exec_candidates[: min(target_exec, ratio_cap)]

    return {
        "token_ids": token_ids,
        "offset_mapping": offsets,
        "plan_decision_tis": plan_dp,
        "mon_decision_tis": mon_dp,
        "exec_decision_tis": exec_dp,
        "all_newline_tis": all_newline_tis,
        "n_plan": len(plan_dp),
        "n_mon": len(mon_dp),
        "n_exec": len(exec_dp),
        "n_newlines_total": len(all_newline_tis),
    }