# v2/src/labeling.py
"""
Token-level labeling of CoT text with three decision-point classes:
S_plan: token immediately before a planning trigger (in the window before trigger)
S_mon: token immediately before a monitoring trigger
S_exec: random \n\n-following token that is NEITHER plan nor mon
The "decision point" = first token AFTER a \n\n delimiter (step boundary).
We label decision points, not whole trigger spans. This gives us WRITE-direction samples.
"""
import random
import re
from typing import List, Tuple, Dict, Set
from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS, flatten_patterns
def find_newline_decision_points(offset_mapping: List[Tuple[int, int]], text: str) -> List[int]:
    """
    Return token indices that are the first token AFTER a "\\n\\n" in text.

    These are the "decision points" for our analysis: token ti is a decision
    point when the characters covered by token ti-1 plus the gap between the
    two tokens contain a "\\n\\n" step boundary.

    Args:
        offset_mapping: per-token (char_start, char_end) spans into ``text``.
        text: the original string the offsets index into.

    Returns:
        Token indices in ascending order (each at most once).
    """
    decision_tis: List[int] = []
    for ti in range(1, len(offset_mapping)):
        ts_prev, te_prev = offset_mapping[ti - 1]
        ts, _te = offset_mapping[ti]
        # Skip malformed offsets (e.g. tokens mapped past the end of text).
        if ts_prev >= len(text) or te_prev > len(text):
            continue
        # Scan the previous token's text AND the inter-token gap as ONE slice.
        # BUG FIX: the previous version checked the two regions separately and
        # therefore missed a "\n\n" straddling the boundary (prev token ending
        # with "\n" while the gap starts with "\n").
        boundary_end = max(te_prev, ts)  # tolerate non-monotonic offsets
        if "\n\n" in text[ts_prev:boundary_end]:
            decision_tis.append(ti)
    return decision_tis
def find_trigger_char_positions(text: str, patterns: List[str]) -> List[int]:
    """Return the sorted, de-duplicated char-start positions of all regex matches.

    Every pattern in ``patterns`` is searched independently over ``text``;
    overlapping matches from different patterns collapse into one position.
    """
    positions: Set[int] = set()
    for pattern in patterns:
        positions.update(match.start() for match in re.finditer(pattern, text))
    return sorted(positions)
def char_to_token_index(char_pos: int, offsets: List[Tuple[int, int]]) -> int:
    """Map a char position to the index of the token whose span contains it.

    A zero-width token (start == end) matches only when ``char_pos`` falls
    exactly on that position. Returns -1 when no token covers ``char_pos``.
    """
    for idx, (start, end) in enumerate(offsets):
        covered = start <= char_pos < end
        zero_width_hit = start == end == char_pos
        if covered or zero_width_hit:
            return idx
    return -1
def label_cot_decision_points(
    text: str,
    tokenizer,
    plan_window_before: int = 0,
    plan_window_after: int = 1,
    mon_window_before: int = 0,
    mon_window_after: int = 1,
    exec_sample_ratio: float = 1.0,  # keep at most this fraction of available exec points
    rng_seed: int = 42,
) -> Dict:
    """
    Tokenize ``text`` and classify every newline-following decision point.

    Returns a dict with:
        token_ids: List[int]
        offset_mapping: List[(int, int)]
        plan_decision_tis: List[int]  # decision points BEFORE planning triggers
        mon_decision_tis: List[int]
        exec_decision_tis: List[int]  # neutral decision points (newline-following, not plan/mon)
        all_newline_tis: List[int]    # for computing 'general' direction
        n_plan / n_mon / n_exec / n_newlines_total: counts of the above

    A decision point ti is classified as:
        plan  if any plan-trigger token falls in [ti, ti + lookahead)
        mon   if any mon-trigger token falls in that window
        exec  if ti is newline-following AND neither plan nor mon
    Ties (both plan and mon in the window) go to the closer trigger, with
    plan winning exact ties.

    NOTE(review): the four *_window_* parameters are accepted for interface
    compatibility but are currently NOT used — classification relies on the
    hard-coded ``lookahead`` below. Wiring them in would change outputs, so
    they are left as-is; confirm intended semantics before using them.
    """
    enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False, truncation=False)
    token_ids = enc["input_ids"]
    offsets = enc["offset_mapping"]

    # All newline-following decision points (candidates for every class).
    all_newline_tis = find_newline_decision_points(offsets, text)
    all_newline_set: Set[int] = set(all_newline_tis)

    # Map plan/mon regex matches to token indices (drop unmappable positions).
    plan_flat = flatten_patterns(PLANNING_PATTERNS)
    mon_flat = flatten_patterns(MONITORING_PATTERNS)
    plan_char_starts = find_trigger_char_positions(text, plan_flat)
    mon_char_starts = find_trigger_char_positions(text, mon_flat)
    plan_trigger_tis = [char_to_token_index(cp, offsets) for cp in plan_char_starts]
    plan_trigger_tis = [t for t in plan_trigger_tis if t >= 0]
    mon_trigger_tis = [char_to_token_index(cp, offsets) for cp in mon_char_starts]
    mon_trigger_tis = [t for t in mon_trigger_tis if t >= 0]

    plan_dp: List[int] = []
    mon_dp: List[int] = []
    # A decision point counts as plan/mon if a trigger token falls within the
    # next `lookahead` tokens — plan/mon triggers usually follow the newline
    # within a few words.
    lookahead = 10
    plan_trigger_set = set(plan_trigger_tis)
    mon_trigger_set = set(mon_trigger_tis)
    for ti in all_newline_tis:
        is_plan = any((ti <= tt < ti + lookahead) for tt in plan_trigger_set)
        is_mon = any((ti <= tt < ti + lookahead) for tt in mon_trigger_set)
        if is_plan and not is_mon:
            plan_dp.append(ti)
        elif is_mon and not is_plan:
            mon_dp.append(ti)
        elif is_plan and is_mon:
            # Both match: assign to whichever trigger is closer (plan on ties).
            plan_min = min((tt - ti) for tt in plan_trigger_set if ti <= tt < ti + lookahead)
            mon_min = min((tt - ti) for tt in mon_trigger_set if ti <= tt < ti + lookahead)
            if plan_min <= mon_min:
                plan_dp.append(ti)
            else:
                mon_dp.append(ti)

    # exec = newline decision point that is neither plan nor mon.
    classified = set(plan_dp) | set(mon_dp)
    exec_candidates = [ti for ti in all_newline_tis if ti not in classified]

    # Down-sample exec candidates (deterministically) to balance class sizes:
    # keep at most 3 × the larger of the plan/mon counts (floor of 10).
    rng = random.Random(rng_seed)
    rng.shuffle(exec_candidates)
    target_exec = max(10, 3 * max(len(plan_dp), len(mon_dp)))
    # BUG FIX: exec_sample_ratio was accepted but never applied. It now acts
    # as an additional cap; the default 1.0 reproduces the old output exactly.
    ratio_cap = int(exec_sample_ratio * len(exec_candidates))
    exec_dp = exec_candidates[: min(target_exec, ratio_cap)]

    return {
        "token_ids": token_ids,
        "offset_mapping": offsets,
        "plan_decision_tis": plan_dp,
        "mon_decision_tis": mon_dp,
        "exec_decision_tis": exec_dp,
        "all_newline_tis": all_newline_tis,
        "n_plan": len(plan_dp),
        "n_mon": len(mon_dp),
        "n_exec": len(exec_dp),
        "n_newlines_total": len(all_newline_tis),
    }