"""
Behavior detectors: count trigger occurrences in generated CoT.
Used for:
- RR (reduction rate) evaluation during steering sweep
- Sanity checks for labeling
KEY UPDATE (Apr 2026):
Added "true reflection vs filler" distinction for monitoring triggers.
In Qwen3-Thinking CoT, "wait" is often used as a filler word
("Wait, 5+3=8") rather than as a real reflection signal
("Wait, that's wrong, let me check"). The new `count_real_monitoring`
excludes filler usage.
Also includes a more robust collapse detector (word n-gram based rather
than single-word based; length judged relative to a baseline CoT rather
than by fixed thresholds).
"""
import re
from typing import Dict, List, Optional, Tuple

from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS
class BehaviorDetector:
    """Count regex trigger matches for one behavior dimension.

    The dimension selects a pattern table from ``configs.patterns``:
    ``PLANNING_PATTERNS`` for "planning", ``MONITORING_PATTERNS`` for
    "monitoring". Each table maps a subtype name to a list of regex patterns,
    which are compiled once in ``__init__``.
    """

    def __init__(self, dimension: str):
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and an invalid dimension would then silently fall
        # through to MONITORING_PATTERNS below.
        if dimension not in ("planning", "monitoring"):
            raise ValueError(
                f"dimension must be 'planning' or 'monitoring', got {dimension!r}"
            )
        self.dimension = dimension
        self.patterns = PLANNING_PATTERNS if dimension == "planning" else MONITORING_PATTERNS
        # subtype -> list of compiled regexes, reused by every detect() call.
        self.compiled = {
            subtype: [re.compile(p) for p in plist]
            for subtype, plist in self.patterns.items()
        }

    def detect(self, text: str) -> Dict:
        """Find every match of every pattern of this dimension in ``text``.

        Returns:
            dict with
            total:   number of matches across all subtypes
            by_type: subtype -> match count (in pattern-table order)
            spans:   one dict per match with "subtype", "start", "end" and
                     "match" (matched text truncated to 50 characters)

        Spans are ordered by subtype, then by pattern, then by position; they
        are not sorted globally by position. If two patterns match the same
        text, both matches are counted.
        """
        res = {"total": 0, "by_type": {}, "spans": []}
        for subtype, regs in self.compiled.items():
            cnt = 0
            for r in regs:
                for m in r.finditer(text):
                    cnt += 1
                    res["spans"].append({
                        "subtype": subtype,
                        "start": m.start(),
                        "end": m.end(),
                        "match": m.group(0)[:50],
                    })
            res["by_type"][subtype] = cnt
            res["total"] += cnt
        return res
def compute_rr(base_count: int, steered_count: int) -> float:
    """Return the reduction rate of trigger counts after steering.

    RR = (base_count - steered_count) / base_count. A value of 1.0 means all
    triggers were removed, 0.0 means no change, and a negative value means
    steering increased the count. A zero baseline yields 0.0 instead of
    dividing by zero.
    """
    return 0.0 if base_count == 0 else (base_count - steered_count) / base_count
# ============================================================
# Real-reflection vs filler word distinction
# ============================================================
# "wait" is sometimes used as a real monitoring signal
# (followed by reflective content), but other times just as a
# discourse filler before continuing computation.
#
# Both pattern lists below are applied to the CONTEXT AFTER the trigger
# (the slice of text that follows the matched trigger, as taken by
# count_real_monitoring), not to the trigger word itself.

# Real reflection: optional "," or "." plus whitespace, then within ~80
# chars, language indicating self-correction / verification /
# re-evaluation (e.g. "Wait, that's wrong, let me check").
_REAL_REFLECTION_AFTER_WAIT = [
    re.compile(
        r"^[,.]?\s+.{0,80}?\b("
        r"let\s+me\s+(check|verify|re-?check|reconsider|re-?examine|see|think)|"
        r"i\s+(made|have)\s+(a|an)?\s*(mistake|error|typo|miscalc)|"
        r"that'?s\s+(not\s+right|wrong|incorrect|off|not\s+correct)|"
        r"actually|"
        r"no[,.]?\s+(that|this|i)|"
        r"hold\s+on|"
        r"hmm[,.]?|"
        r"i\s+(think|need\s+to|should|forgot|missed|skipped)|"
        r"(but|because|since)\s+(i|we|the)|"
        r"is\s+that\s+(right|correct)\?|"
        r"does\s+that\s+(make\s+sense|work)"
        r")",
        re.IGNORECASE | re.DOTALL,
    ),
]

# Pure filler usage: the text just keeps computing, with no reflection
# (e.g. "Wait, 5+3=8" or "Wait. So the answer is ...").
# The connector pattern is case-insensitive, like the reflection pattern
# above: after "Wait." the next word is normally capitalized
# ("So", "Then", "The answer is").
_WAIT_AS_FILLER = [
    re.compile(r"^[,.]?\s*(\d|\$|\\)"),  # immediately followed by computation
    re.compile(
        r"^[,.]?\s*\b(here|there|so|then|the\s+\w+\s+is)\b",
        re.IGNORECASE,
    ),
]
def count_real_monitoring(text: str) -> Dict:
    """
    Count monitoring triggers in ``text``, separating real reflection from filler.

    Only spans of the ``error_detection`` subtype are disambiguated
    (presumably the subtype holding "wait"-style triggers; confirm in
    configs.patterns). Spans of every other monitoring subtype count as real
    reflection. For an ``error_detection`` span, the 80 characters after the
    match are classified in this order:

    1. if any ``_WAIT_AS_FILLER`` pattern matches, the span is filler. This
       check runs first, so it wins over reflective wording later on;
    2. otherwise, if any ``_REAL_REFLECTION_AFTER_WAIT`` pattern matches,
       the span is real reflection;
    3. otherwise the span is ambiguous.

    Returns:
        total_triggers:  number of monitoring trigger matches
        real_reflection: triggers classified as real reflection
        filler_only:     triggers classified as pure filler ("wait, 5+3=...")
        ambiguous:       triggers that are neither clearly real nor filler
        by_type:         per-subtype match counts from BehaviorDetector.detect
    total_triggers == real_reflection + filler_only + ambiguous.

    Use ``real_reflection`` rather than ``total_triggers`` when scoring
    monitoring suppression. It ignores triggers where the model kept only the
    surface word.
    """
    detection = BehaviorDetector("monitoring").detect(text)
    tally = {"real": 0, "filler": 0, "ambiguous": 0}

    for hit in detection["spans"]:
        if hit["subtype"] == "error_detection":
            window = text[hit["end"]:hit["end"] + 80]
            if any(rx.search(window) for rx in _WAIT_AS_FILLER):
                verdict = "filler"
            elif any(rx.search(window) for rx in _REAL_REFLECTION_AFTER_WAIT):
                verdict = "real"
            else:
                verdict = "ambiguous"
        else:
            verdict = "real"
        tally[verdict] += 1

    return {
        "total_triggers": detection["total"],
        "real_reflection": tally["real"],
        "filler_only": tally["filler"],
        "ambiguous": tally["ambiguous"],
        "by_type": detection["by_type"],
    }
# ============================================================
# Robust collapse detection
# ============================================================
def is_collapsed(text: str, base_text: Optional[str] = None,
                 ngram: int = 4, ngram_threshold: float = 0.5,
                 length_ratio_low: float = 0.3,
                 length_ratio_high: float = 1.8,
                 min_chars: int = 50) -> Dict:
    """
    Detect generation collapse from word n-gram repetition and, optionally, length.

    Args:
        text: generated CoT
        base_text: optional baseline CoT. When non-empty, the character-length
            ratio len(text) / len(base_text) is checked against the bounds.
        ngram: word n-gram size for repetition (default 4); must be >= 1
        ngram_threshold: repetition fraction strictly above which the text
            counts as collapsed
        length_ratio_low / length_ratio_high: a length ratio strictly outside
            [low, high] counts as collapsed
        min_chars: an empty text, or one shorter than this many characters,
            is reported as collapsed outright (default 50)

    Returns:
        {
            "collapsed": bool,
            "ngram_repetition": float,      # 1 - unique / total word n-grams
            "length_ratio": float or None,  # None without base_text or when
                                            # the text is too short
            "reason": str,  # "empty_or_too_short", "repetition+length",
                            # "repetition", "length" or "none"
        }

    Raises:
        ValueError: if ``ngram`` < 1.
    """
    if ngram < 1:
        raise ValueError(f"ngram must be >= 1, got {ngram}")

    if not text or len(text) < min_chars:
        return {
            "collapsed": True,
            "ngram_repetition": 0.0,
            "length_ratio": None,
            "reason": "empty_or_too_short",
        }

    # Word n-gram repetition: fraction of n-grams that repeat an earlier one.
    # Texts with fewer than 4 * ngram tokens are too short for a meaningful
    # ratio and report 0.0.
    toks = text.split()
    rep = 0.0
    if len(toks) >= ngram * 4:
        ngrams = [tuple(toks[i:i + ngram]) for i in range(len(toks) - ngram + 1)]
        rep = 1.0 - len(set(ngrams)) / len(ngrams)

    # Length anomaly relative to the baseline (skipped for None / "").
    length_ratio = None
    length_anomaly = False
    if base_text:
        length_ratio = len(text) / len(base_text)
        length_anomaly = (length_ratio < length_ratio_low or
                          length_ratio > length_ratio_high)

    rep_anomaly = rep > ngram_threshold
    if rep_anomaly and length_anomaly:
        reason = "repetition+length"
    elif rep_anomaly:
        reason = "repetition"
    elif length_anomaly:
        reason = "length"
    else:
        reason = "none"

    return {
        "collapsed": bool(rep_anomaly or length_anomaly),
        "ngram_repetition": float(rep),
        "length_ratio": length_ratio,
        "reason": reason,
    }
# Legacy export for backward compatibility
def repetition_score(text: str, window: int = 100) -> float:
    """Return the word n-gram repetition fraction of ``text`` (legacy helper).

    Deprecated: call ``is_collapsed()`` directly. ``window`` is accepted only
    so old call sites keep working; it is ignored.
    """
    return is_collapsed(text)["ngram_repetition"]
|