# v2/src/detectors.py
"""
Behavior detectors: count trigger occurrences in generated CoT.
Used for:
- RR (reduction rate) evaluation during steering sweep
- Sanity checks for labeling
KEY UPDATE (Apr 2026):
Added "true reflection vs filler" distinction for monitoring triggers.
In Qwen3-Thinking CoT, "wait" is often used as a filler word
("Wait, 5+3=8") rather than as a real reflection signal
("Wait, that's wrong, let me check"). The new `count_real_monitoring`
excludes filler usage.
Also includes a more robust collapse detector (n-gram based, not
word-based; relative to baseline length, not fixed thresholds).
"""
import re
from typing import Dict, List, Tuple
from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS
class BehaviorDetector:
    """Count regex triggers of a single behavioral dimension.

    The dimension ("planning" or "monitoring") selects which pattern
    table is compiled at construction time; :meth:`detect` then tallies
    every match of every pattern in a piece of generated text.
    """

    def __init__(self, dimension: str):
        assert dimension in ("planning", "monitoring")
        self.dimension = dimension
        self.patterns = PLANNING_PATTERNS if dimension == "planning" else MONITORING_PATTERNS
        # Pre-compile all patterns once, keyed by subtype, so detect()
        # can be called repeatedly without recompilation cost.
        self.compiled = {}
        for subtype, pattern_list in self.patterns.items():
            self.compiled[subtype] = [re.compile(p) for p in pattern_list]

    def detect(self, text: str) -> Dict:
        """Return {"total", "by_type", "spans"} for all trigger matches in *text*.

        "spans" records each match's subtype, character offsets, and the
        matched text (truncated to 50 chars).
        """
        spans: List[Dict] = []
        by_type: Dict[str, int] = {}
        for subtype, regexes in self.compiled.items():
            matches = [m for rx in regexes for m in rx.finditer(text)]
            for m in matches:
                spans.append({
                    "subtype": subtype,
                    "start": m.start(),
                    "end": m.end(),
                    "match": m.group(0)[:50],
                })
            by_type[subtype] = len(matches)
        return {"total": sum(by_type.values()), "by_type": by_type, "spans": spans}
def compute_rr(base_count: int, steered_count: int) -> float:
    """Reduction rate: fraction of baseline triggers removed by steering.

    Returns 0.0 when the baseline has no triggers (reduction undefined).
    A negative value means steering *increased* the trigger count.
    """
    if base_count:
        return (base_count - steered_count) / base_count
    return 0.0
# ============================================================
# Real-reflection vs filler word distinction
# ============================================================
# "wait" is sometimes used as a real monitoring signal
# (followed by reflective content), but other times just as a
# discourse filler before continuing computation.
#
# Real reflection patterns: language indicating self-correction /
# verification / re-evaluation, applied to the CONTEXT AFTER the trigger
# ("wait, ...") — the reflective cue must appear within ~80 chars.
# Anchored at the start of the post-trigger context; up to 80 chars of
# lead-in are allowed before one of the reflective cue phrases below.
# DOTALL lets the lead-in span line breaks inside the CoT.
_REAL_REFLECTION_AFTER_WAIT = [
    re.compile(
        r"^[,.]?\s+.{0,80}?\b("
        r"let\s+me\s+(check|verify|re-?check|reconsider|re-?examine|see|think)|"
        r"i\s+(made|have)\s+(a|an)?\s*(mistake|error|typo|miscalc)|"
        r"that'?s\s+(not\s+right|wrong|incorrect|off|not\s+correct)|"
        r"actually|"
        r"no[,.]?\s+(that|this|i)|"
        r"hold\s+on|"
        r"hmm[,.]?|"
        r"i\s+(think|need\s+to|should|forgot|missed|skipped)|"
        r"(but|because|since)\s+(i|we|the)|"
        r"is\s+that\s+(right|correct)\?|"
        r"does\s+that\s+(make\s+sense|work)"
        r")",
        re.IGNORECASE | re.DOTALL,
    ),
]
# Pure filler usage (just continues to compute, no reflection).
# These regexes check the CONTEXT AFTER "wait" — the "wait" itself is
# already matched by the trigger detector, ctx starts after it.
_WAIT_AS_FILLER = [
    re.compile(r"^[,.]?\s*(\d|\$|\\)"),  # immediately followed by computation (digit, $, or LaTeX \)
    re.compile(r"^[,.]?\s*\b(here|there|so|then|the\s+\w+\s+is)\b"),  # discourse connector, then content
]
def count_real_monitoring(text: str) -> Dict:
    """Count monitoring triggers, splitting real reflection from filler.

    Returns a dict with:
        total_triggers: every monitoring trigger the regexes matched
        real_reflection: triggers backed by reflective content within 80 chars
        filler_only: pure-filler triggers ("wait, 5+3=...")
        ambiguous: triggers that are neither clearly real nor clearly filler
        by_type: per-subtype raw counts

    Prefer `real_reflection` over `total_triggers` when scoring monitoring
    suppression — it ignores cases where the model only kept the surface word.
    """
    raw = BehaviorDetector("monitoring").detect(text)
    tallies = {"real": 0, "filler": 0, "ambiguous": 0}
    for span in raw["spans"]:
        # Only the "error_detection" subtype (the "wait"-style triggers)
        # is ambiguous; all other monitoring subtypes count as real.
        if span["subtype"] != "error_detection":
            tallies["real"] += 1
            continue
        # Classify by the 80 chars of context following the trigger.
        ctx = text[span["end"]:span["end"] + 80]
        if any(rx.search(ctx) for rx in _WAIT_AS_FILLER):
            # Filler takes priority: computation/connector right after trigger.
            tallies["filler"] += 1
        elif any(rx.search(ctx) for rx in _REAL_REFLECTION_AFTER_WAIT):
            tallies["real"] += 1
        else:
            tallies["ambiguous"] += 1
    return {
        "total_triggers": raw["total"],
        "real_reflection": tallies["real"],
        "filler_only": tallies["filler"],
        "ambiguous": tallies["ambiguous"],
        "by_type": raw["by_type"],
    }
# ============================================================
# Robust collapse detection
# ============================================================
def is_collapsed(text: str, base_text: str = None,
                 ngram: int = 4, ngram_threshold: float = 0.5,
                 length_ratio_low: float = 0.3,
                 length_ratio_high: float = 1.8) -> Dict:
    """Detect generation collapse via n-gram repetition and length anomaly.

    Args:
        text: generated CoT.
        base_text: optional baseline CoT; enables the length-ratio check.
        ngram: n-gram size used for the repetition statistic (default 4).
        ngram_threshold: repeated-n-gram fraction above which text is collapsed.
        length_ratio_low / length_ratio_high: acceptable len(text)/len(base_text)
            band; a ratio outside it counts as collapsed.

    Returns:
        {"collapsed": bool, "ngram_repetition": float,
         "length_ratio": float | None, "reason": str}
    """
    # Degenerate output: empty or trivially short text is always collapsed.
    if not text or len(text) < 50:
        return {
            "collapsed": True,
            "ngram_repetition": 0.0,
            "length_ratio": None,
            "reason": "empty_or_too_short",
        }

    # Repetition statistic: fraction of duplicated word n-grams. Only
    # computed when there are enough tokens for the statistic to mean much.
    words = text.split()
    repetition = 0.0
    if len(words) >= 4 * ngram:
        grams = list(zip(*(words[shift:] for shift in range(ngram))))
        if grams:
            repetition = 1.0 - len(set(grams)) / len(grams)

    # Length anomaly relative to the baseline, when one was supplied.
    ratio = None
    length_anomaly = False
    if base_text:
        ratio = len(text) / max(len(base_text), 1)
        length_anomaly = not (length_ratio_low <= ratio <= length_ratio_high)

    rep_anomaly = repetition > ngram_threshold
    if rep_anomaly and length_anomaly:
        cause = "repetition+length"
    elif rep_anomaly:
        cause = "repetition"
    elif length_anomaly:
        cause = "length"
    else:
        cause = "none"

    return {
        "collapsed": bool(rep_anomaly or length_anomaly),
        "ngram_repetition": float(repetition),
        "length_ratio": ratio,
        "reason": cause,
    }
# Legacy export for backward compatibility
def repetition_score(text: str, window: int = 100) -> float:
    """Legacy shim: return only the n-gram repetition fraction.

    Kept for callers predating is_collapsed(); *window* is retained for
    signature compatibility but is ignored. Use is_collapsed() instead.
    """
    return is_collapsed(text)["ngram_repetition"]