File size: 7,717 Bytes
e53f10b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
Behavior detectors: count trigger occurrences in generated CoT.

Used for:
  - RR (reduction rate) evaluation during steering sweep
  - Sanity checks for labeling

KEY UPDATE (Apr 2026):
  Added "true reflection vs filler" distinction for monitoring triggers.
  In Qwen3-Thinking CoT, "wait" is often used as a filler word
  ("Wait, 5+3=8") rather than as a real reflection signal
  ("Wait, that's wrong, let me check"). The new `count_real_monitoring`
  excludes filler usage.

Also includes a more robust collapse detector (n-gram based, not
word-based; relative to baseline length, not fixed thresholds).
"""
import re
from typing import Dict, List, Optional, Tuple

from configs.patterns import MONITORING_PATTERNS, PLANNING_PATTERNS


class BehaviorDetector:
    """Count regex trigger matches for one dimension ("planning" or "monitoring").

    The pattern table is ``PLANNING_PATTERNS`` for "planning" and
    ``MONITORING_PATTERNS`` for "monitoring". Each table maps a subtype name to
    an iterable of regex patterns; every pattern is compiled once, here.
    """

    def __init__(self, dimension: str):
        # Explicit raise instead of ``assert``: assertions are stripped under
        # ``python -O``, and a misspelled dimension would then silently fall
        # through to the monitoring table below.
        if dimension not in ("planning", "monitoring"):
            raise ValueError(
                f"dimension must be 'planning' or 'monitoring', got {dimension!r}"
            )
        self.dimension = dimension
        self.patterns = PLANNING_PATTERNS if dimension == "planning" else MONITORING_PATTERNS
        # subtype -> list of compiled patterns, in table order.
        self.compiled = {
            subtype: [re.compile(p) for p in plist]
            for subtype, plist in self.patterns.items()
        }

    def detect(self, text: str) -> Dict:
        """Find every trigger match in *text*.

        Returns a dict with:
            total:   number of matches across all subtypes
            by_type: subtype -> match count (every subtype is present, 0 if
                     it had no matches)
            spans:   one dict per match with "subtype", "start", "end" and
                     "match" (the matched text truncated to 50 chars)

        Spans are grouped by subtype, then by pattern (table order), then by
        position; they are not globally sorted by offset. Each pattern is
        scanned independently, so text matched by two patterns counts twice.
        """
        res = {"total": 0, "by_type": {}, "spans": []}
        for subtype, regs in self.compiled.items():
            cnt = 0
            for r in regs:
                for m in r.finditer(text):
                    cnt += 1
                    res["spans"].append({
                        "subtype": subtype,
                        "start": m.start(),
                        "end":   m.end(),
                        "match": m.group(0)[:50],
                    })
            res["by_type"][subtype] = cnt
            res["total"] += cnt
        return res


def compute_rr(base_count: int, steered_count: int) -> float:
    """Return the reduction rate (RR) of a trigger count under steering.

    RR = (base_count - steered_count) / base_count. A zero baseline has nothing
    to reduce, so it maps to 0.0 instead of dividing by zero. A negative result
    means the steered run produced more triggers than the baseline.
    """
    return 0.0 if base_count == 0 else (base_count - steered_count) / base_count


# ============================================================
# Real-reflection vs filler word distinction
# ============================================================
# "wait" is sometimes used as a real monitoring signal
# (followed by reflective content), but other times just as a
# discourse filler before continuing computation.
#
# Real reflection patterns: language indicating self-correction /
# verification / re-evaluation, searched in the CONTEXT AFTER the trigger
# ("wait, ..."). The context must begin with an optional single "," or "."
# followed by at least one whitespace character; the cue itself may then
# start anywhere within the next 0-80 characters. Case-insensitive.
# NOTE(review): the leading-context requirement rejects forms such as
# "Wait... let me check", "Wait! That's wrong" or "Wait—let me see"
# (only one "," / "." is allowed, and whitespace is mandatory).
# NOTE(review): several cues are broad ("actually", "hmm", "i think",
# "but the") and can fire on non-reflective continuations.
_REAL_REFLECTION_AFTER_WAIT = [
    re.compile(
        r"^[,.]?\s+.{0,80}?\b("  # leading punctuation/whitespace, then any cue below
        r"let\s+me\s+(check|verify|re-?check|reconsider|re-?examine|see|think)|"
        r"i\s+(made|have)\s+(a|an)?\s*(mistake|error|typo|miscalc)|"
        r"that'?s\s+(not\s+right|wrong|incorrect|off|not\s+correct)|"
        r"actually|"
        r"no[,.]?\s+(that|this|i)|"
        r"hold\s+on|"
        r"hmm[,.]?|"
        r"i\s+(think|need\s+to|should|forgot|missed|skipped)|"
        r"(but|because|since)\s+(i|we|the)|"
        r"is\s+that\s+(right|correct)\?|"
        r"does\s+that\s+(make\s+sense|work)"
        r")",
        re.IGNORECASE | re.DOTALL,
    ),
]

# Pure filler usage (just continues to compute, no reflection).
# These regexes check the CONTEXT AFTER "wait" — the "wait" itself is
# already matched by the trigger detector, ctx starts after it.
# The connector pattern is case-insensitive, matching the case-insensitive
# reflection regex: a sentence-initial connector after "Wait." ("Wait. So ...",
# "Wait. Then ...") is filler just like the lowercase "wait, so ..." form.
_WAIT_AS_FILLER = [
    re.compile(r"^[,.]?\s*(\d|\$|\\)"),   # immediately followed by computation
    re.compile(
        r"^[,.]?\s*\b(here|there|so|then|the\s+\w+\s+is)\b",  # continuation connector
        re.IGNORECASE,
    ),
]


def count_real_monitoring(text: str) -> Dict:
    """
    Count monitoring triggers and split them into real reflection vs filler.

    Only "error_detection" triggers are disambiguated. The 80 characters that
    follow each one are checked first against the filler patterns and, only
    if none match, against the reflection patterns. Triggers of every other
    monitoring subtype are counted as real reflection unconditionally.

    Returns a dict with:
        total_triggers:  all monitoring trigger matches
        real_reflection: non-error_detection triggers, plus error_detection
                         triggers whose following context reads as reflection
        filler_only:     error_detection triggers followed directly by
                         computation or a connector ("wait, 5+3=...")
        ambiguous:       error_detection triggers matching neither check
        by_type:         per-subtype counts reported by the detector

    Prefer `real_reflection` over `total_triggers` when scoring monitoring
    suppression: it discounts cases where the model kept only the surface word.
    """
    detection = BehaviorDetector("monitoring").detect(text)

    tally = {"real": 0, "filler": 0, "ambiguous": 0}
    for span in detection["spans"]:
        if span["subtype"] == "error_detection":
            # Context window starts right after the trigger match.
            window = text[span["end"]:span["end"] + 80]
            # Filler wins over reflection: it is tested first.
            if any(rx.search(window) for rx in _WAIT_AS_FILLER):
                label = "filler"
            elif any(rx.search(window) for rx in _REAL_REFLECTION_AFTER_WAIT):
                label = "real"
            else:
                label = "ambiguous"
        else:
            label = "real"
        tally[label] += 1

    return {
        "total_triggers":  detection["total"],
        "real_reflection": tally["real"],
        "filler_only":     tally["filler"],
        "ambiguous":       tally["ambiguous"],
        "by_type":         detection["by_type"],
    }


# ============================================================
# Robust collapse detection
# ============================================================
def is_collapsed(text: str, base_text: Optional[str] = None,
                 ngram: int = 4, ngram_threshold: float = 0.5,
                 length_ratio_low: float = 0.3,
                 length_ratio_high: float = 1.8,
                 min_chars: int = 50) -> Dict:
    """
    Detect generation collapse using repetition and length signals.

    Args:
        text: generated CoT
        base_text: optional baseline CoT. When non-empty, the character-length
            ratio len(text) / len(base_text) is checked against the band below.
        ngram: size of the whitespace-token n-grams used for repetition
        ngram_threshold: repetition fraction above which the text is collapsed
        length_ratio_low / length_ratio_high: a length ratio strictly outside
            [low, high] counts as collapsed
        min_chars: empty texts and texts shorter than this many characters are
            reported as collapsed outright (default 50, the former fixed value)

    Returns:
        {
          "collapsed": bool,
          "ngram_repetition": float,   # 1 - unique/total n-grams, 0.0 if not computed
          "length_ratio": float or None,
          "reason": str                # "empty_or_too_short", "repetition+length",
                                       # "repetition", "length" or "none"
        }
    """
    if not text or len(text) < min_chars:
        return {
            "collapsed": True,
            "ngram_repetition": 0.0,
            "length_ratio": None,
            "reason": "empty_or_too_short",
        }

    # n-gram repetition over whitespace-split tokens: the fraction of n-grams
    # that duplicate an earlier one. Only computed once there are at least
    # 4 * ngram tokens; shorter texts keep rep = 0.0.
    toks = text.split()
    rep = 0.0
    if len(toks) >= ngram * 4:
        ngrams = [tuple(toks[i:i + ngram]) for i in range(len(toks) - ngram + 1)]
        # Non-empty here: len(toks) - ngram + 1 >= 1 given the guard above.
        rep = 1.0 - len(set(ngrams)) / len(ngrams)

    # Length anomaly relative to baseline (character lengths).
    length_ratio = None
    length_anomaly = False
    if base_text:
        length_ratio = len(text) / max(len(base_text), 1)
        length_anomaly = (length_ratio < length_ratio_low or
                          length_ratio > length_ratio_high)

    rep_anomaly = rep > ngram_threshold

    if rep_anomaly and length_anomaly:
        reason = "repetition+length"
    elif rep_anomaly:
        reason = "repetition"
    elif length_anomaly:
        reason = "length"
    else:
        reason = "none"

    return {
        "collapsed": bool(rep_anomaly or length_anomaly),
        "ngram_repetition": float(rep),
        "length_ratio": length_ratio,
        "reason": reason,
    }


# Legacy export for backward compatibility
def repetition_score(text: str, window: int = 100) -> float:
    """Legacy wrapper kept for backward compatibility; prefer is_collapsed().

    Returns the "ngram_repetition" value that is_collapsed(text) reports with
    its default settings. ``window`` is accepted only so existing call sites
    keep working; it is not used.
    """
    return is_collapsed(text)["ngram_repetition"]