File size: 18,645 Bytes
0b0338d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
# server/causal_probe.py
"""
Causal Reasoning Probe β€” v4.0

The key scientific question: Did the agent understand WHY the bug exists,
or did it accidentally fix it by pattern matching?

We measure causal understanding by checking if the agent traversed the
COMPLETE causal chain: Failing test β†’ tested function β†’ return path β†’ root cause.

An agent that reads only the test and immediately rewrites the function
is guessing. An agent that reads test β†’ traces the call stack β†’ finds the
actual cause first is reasoning causally.

This is NOT in any current benchmark. SWE-bench only checks if the test passes.
We check HOW the agent got there.
"""
from __future__ import annotations
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from enum import Enum


class CausalUnderstandingLevel(str, Enum):
    """Ordinal verdict on how deeply the agent's actions reflect causal reasoning.

    Inherits from ``str`` so values serialize directly into JSON reports.
    Ordered from strongest to weakest evidence of understanding.
    """

    DEEP = "DEEP"                # full causal chain traversed, in order
    PARTIAL = "PARTIAL"          # chain partially traversed; some links skipped
    SUPERFICIAL = "SUPERFICIAL"  # jumped straight from test to rewrite, no chain
    RANDOM = "RANDOM"            # actions show no causal pattern at all


@dataclass
class CausalChainNode:
    """A single link in the reconstructed causal chain for one episode.

    Nodes are compared by value (dataclass equality), so two nodes with the
    same file/role/read state are interchangeable in reports.
    """

    file: str                  # path of the file this node represents
    role: str                  # one of: "test", "caller", "called", "root_cause", "missed"
    was_read: bool             # whether the agent ever read this file
    read_order: Optional[int]  # step number of the agent's first read, or None if never read


@dataclass
class CausalProbeReport:
    """
    Complete causal-reasoning analysis for a single episode.

    This is the primary artifact produced by :class:`CausalProbe`. Field
    order is part of the constructor interface — do not reorder.
    """
    episode_id: str
    task: str

    # Core verdict
    understanding_level: CausalUnderstandingLevel
    causal_score: float            # composite score in [0.0, 1.0]

    # Chain analysis
    expected_chain: List[CausalChainNode]  # files the agent SHOULD have read
    actual_chain_coverage: float           # fraction of the chain actually read
    chain_order_score: float               # how well read order matched causal order

    # Behavioral signals
    read_before_write: bool        # every written file was read first
    test_informed_navigation: bool # source files were read after the tests
    search_before_navigate: bool   # search_code preceded source reads
    submit_after_test: bool        # run_tests preceded the submit

    # Signal: understanding vs guessing
    guessing_indicators: List[str]       # evidence the agent guessed
    understanding_indicators: List[str]  # evidence the agent understood

    # Calibration
    false_confidence_detected: bool  # submitted early without any test run
    shortcut_learning_detected: bool # wrote a fix without ever reading the source

    explanation: str
    recommendations: List[str]

    def to_dict(self) -> dict:
        """Render the report as a plain JSON-serializable dictionary."""
        # Flatten each chain node into a small record.
        chain_records = []
        for node in self.expected_chain:
            chain_records.append({
                "file": node.file,
                "role": node.role,
                "read": node.was_read,
                "order": node.read_order,
            })

        signals = {
            "read_before_write": self.read_before_write,
            "test_informed_navigation": self.test_informed_navigation,
            "search_before_navigate": self.search_before_navigate,
            "submit_after_test": self.submit_after_test,
        }
        diagnostics = {
            "false_confidence_detected": self.false_confidence_detected,
            "shortcut_learning_detected": self.shortcut_learning_detected,
        }

        return {
            "episode_id": self.episode_id,
            "task": self.task,
            "understanding_level": self.understanding_level.value,
            "causal_score": round(self.causal_score, 3),
            "chain_coverage": round(self.actual_chain_coverage, 3),
            "chain_order_score": round(self.chain_order_score, 3),
            "behavioral_signals": signals,
            "guessing_indicators": self.guessing_indicators,
            "understanding_indicators": self.understanding_indicators,
            "diagnostics": diagnostics,
            "expected_chain": chain_records,
            "explanation": self.explanation,
            "recommendations": self.recommendations,
        }


class CausalProbe:
    """
    Analyzes whether an agent engaged in true causal reasoning.

    The core insight: for a bug in src/X.py called from tests/test_X.py,
    the causal chain is:
        tests/test_X.py → (calls) → src/X.py → (calls) → src/utils.py (maybe)

    A causally-aware agent reads in this order.
    A shortcut agent reads the test, guesses the bug, rewrites without reading source.

    We score order, coverage, and behavioral signals.
    """

    def probe(
        self,
        episode_id: str,
        task: str,
        trajectory_steps: List[dict],
        variant_meta: dict,
        files_read: List[str],
        files_written: List[str],
        final_score: float,
    ) -> CausalProbeReport:
        """Run the causal probe on an episode's trajectory.

        Args:
            episode_id: Unique identifier of the episode under analysis.
            task: Task name, echoed verbatim into the report.
            trajectory_steps: Ordered step dicts; each may carry
                "action_type", "action_path", and "step_number".
            variant_meta: Task metadata; "test_files", "bug_files",
                "dependencies", "read_first_files" and "files_to_implement"
                are consulted when present, otherwise the chain is inferred
                from the trajectory itself.
            files_read: Paths the agent read at least once.
            files_written: Paths the agent wrote (unused here; kept for
                interface stability with other probes).
            final_score: Final task score in [0, 1].

        Returns:
            A fully-populated CausalProbeReport.
        """

        # ── Build expected causal chain from variant metadata ─────────────────
        # FIX: match the trajectory fallback below by lowercasing before the
        # "test" substring check, so Test_*.py / TESTS/ are also detected.
        test_files = variant_meta.get("test_files", []) or [
            f for f in variant_meta.get("read_first_files", []) if "test" in f.lower()
        ]
        bug_files = variant_meta.get("bug_files", []) or variant_meta.get("files_to_implement", [])
        dep_files = variant_meta.get("dependencies", []) or []

        # If metadata sparse, infer from trajectory
        all_files_in_traj = list({
            s.get("action_path") for s in trajectory_steps
            if s.get("action_path") and s.get("action_type") in ("read_file", "write_file")
        })

        if not test_files:
            test_files = [f for f in all_files_in_traj if "test" in f.lower()]
        if not bug_files:
            bug_files = [f for f in all_files_in_traj
                        if "test" not in f.lower() and f.endswith(".py")]

        # Build expected chain. read_order records the FIRST step at which
        # each file was read.
        expected_chain: List[CausalChainNode] = []
        read_set = set(files_read)
        read_order: Dict[str, int] = {}
        for step in trajectory_steps:
            if step.get("action_type") == "read_file" and step.get("action_path"):
                path = step["action_path"]
                if path not in read_order:
                    read_order[path] = step.get("step_number", len(read_order) + 1)

        for tf in test_files:
            expected_chain.append(CausalChainNode(
                file=tf, role="test",
                was_read=tf in read_set,
                read_order=read_order.get(tf),
            ))
        for bf in bug_files:
            expected_chain.append(CausalChainNode(
                file=bf, role="root_cause",
                was_read=bf in read_set,
                read_order=read_order.get(bf),
            ))
        for df in dep_files:
            expected_chain.append(CausalChainNode(
                file=df, role="caller",
                was_read=df in read_set,
                read_order=read_order.get(df),
            ))

        if not expected_chain:
            # Fallback: any file is better than none
            for f in all_files_in_traj[:3]:
                expected_chain.append(CausalChainNode(
                    file=f, role="unknown",
                    was_read=True,
                    read_order=read_order.get(f),
                ))

        # ── Chain coverage ────────────────────────────────────────────────────
        # Note: no node above is ever created with role "missed", so coverage
        # is simply the fraction of expected nodes that were read.
        chain_files_read = [n for n in expected_chain if n.was_read]
        actual_chain_coverage = (
            len(chain_files_read) / len(expected_chain) if expected_chain else 0.0
        )

        # ── Chain order score (tests before src = good causal order) ──────────
        # FIX: filter on "is not None" rather than truthiness so that a
        # legitimate step number 0 (0-indexed trajectories) is not dropped.
        chain_order_score = 0.0
        test_orders = [n.read_order for n in expected_chain
                       if n.role == "test" and n.read_order is not None]
        src_orders = [n.read_order for n in expected_chain
                      if n.role in ("root_cause", "caller") and n.read_order is not None]

        if test_orders and src_orders:
            # Good: all tests read before source files
            correct_order_pairs = sum(
                1 for to in test_orders for so in src_orders if to < so
            )
            total_pairs = len(test_orders) * len(src_orders)
            chain_order_score = correct_order_pairs / total_pairs if total_pairs > 0 else 0.0
        elif test_orders and not src_orders:
            chain_order_score = 0.3  # Partial — read tests but not source
        elif src_orders and not test_orders:
            chain_order_score = 0.2  # Read source without reading tests = weaker

        # ── Behavioral signals ────────────────────────────────────────────────
        action_types = [s.get("action_type", "") for s in trajectory_steps]
        action_paths = [s.get("action_path") for s in trajectory_steps]

        # read_before_write: all written files were read at least once before write
        # NOTE(review): the asymmetric defaults (99 for reads, 0 for writes)
        # make this pessimistic when step_number is missing — a missing read
        # step can never count as "before" a missing write step.
        read_before_write = True
        for step in trajectory_steps:
            if step.get("action_type") == "write_file" and step.get("action_path"):
                p = step["action_path"]
                step_n = step.get("step_number", 0)
                was_read_before = any(
                    s2.get("action_type") == "read_file"
                    and s2.get("action_path") == p
                    and s2.get("step_number", 99) < step_n
                    for s2 in trajectory_steps
                )
                if not was_read_before:
                    read_before_write = False
                    break

        # test_informed_navigation: did agent read source files AFTER reading tests?
        test_read_step = min(
            (s.get("step_number", 99) for s in trajectory_steps
             if s.get("action_type") == "read_file"
             and any(tf in (s.get("action_path") or "") for tf in test_files)),
            default=None
        )
        src_read_after_test = test_read_step is not None and any(
            s.get("action_type") == "read_file"
            and s.get("step_number", 0) > test_read_step
            and any(bf in (s.get("action_path") or "") for bf in bug_files)
            for s in trajectory_steps
        )
        test_informed_navigation = src_read_after_test

        # search_before_navigate: used search_code before reading source files
        search_steps = [s for s in trajectory_steps if s.get("action_type") == "search_code"]
        first_src_read = min(
            (s.get("step_number", 99) for s in trajectory_steps
             if s.get("action_type") == "read_file"
             and any(bf in (s.get("action_path") or "") for bf in bug_files)),
            default=None
        )
        search_before_navigate = bool(search_steps) and (
            first_src_read is None or
            any(s.get("step_number", 99) < first_src_read for s in search_steps)
        )

        # submit_after_test: ran tests before submitting
        test_runs = [s for s in trajectory_steps if s.get("action_type") == "run_tests"]
        submit_step = next(
            (s.get("step_number", 99) for s in trajectory_steps
             if s.get("action_type") == "submit"), None
        )
        submit_after_test = bool(test_runs) and submit_step is not None and any(
            s.get("step_number", 0) < submit_step for s in test_runs
        )

        # ── Guessing vs understanding indicators ──────────────────────────────
        guessing_indicators = []
        understanding_indicators = []

        total = len(trajectory_steps)

        # Guessing: short episode with low score
        if total <= 3 and final_score < 0.5:
            guessing_indicators.append(f"Submitted in only {total} steps with score {final_score:.2f}")

        # Guessing: wrote without reading
        if not read_before_write:
            guessing_indicators.append("Wrote to file(s) without first reading them")

        # Guessing: skipped test files
        if not any(n.was_read for n in expected_chain if n.role == "test"):
            guessing_indicators.append("Never read any test files")

        # Guessing: skipped source files
        if not any(n.was_read for n in expected_chain if n.role == "root_cause"):
            guessing_indicators.append("Never read the bug/source file")

        # Understanding: search used
        if search_steps:
            understanding_indicators.append(
                f"Used search_code {len(search_steps)}× to locate bug"
            )

        # Understanding: read tests first
        if chain_order_score > 0.7:
            understanding_indicators.append("Read tests before source files (correct causal order)")

        # Understanding: tested before submitting
        if submit_after_test:
            understanding_indicators.append("Verified fix with run_tests before submitting")

        # Understanding: explored full chain
        if actual_chain_coverage > 0.7:
            understanding_indicators.append(
                f"Covered {actual_chain_coverage:.0%} of expected causal chain"
            )

        # ── Diagnostics ───────────────────────────────────────────────────────
        # False confidence: submitted very early without testing
        false_confidence_detected = (
            submit_step is not None and submit_step <= 3 and not test_runs
        )

        # Shortcut learning: wrote a fix without ever reading a source file.
        # (Approximation of "read test → immediate write → submit".)
        has_write = "write_file" in action_types
        has_src_read = any(
            bf in (s.get("action_path") or "")
            for s in trajectory_steps
            if s.get("action_type") == "read_file"
            for bf in bug_files
        )
        shortcut_sequence = has_write and not has_src_read
        shortcut_learning_detected = shortcut_sequence

        # ── Composite causal score ─────────────────────────────────────────────
        # Weights sum to 1.0; clamped defensively below.
        scores = {
            "chain_coverage": actual_chain_coverage * 0.30,
            "chain_order": chain_order_score * 0.25,
            "read_before_write": (0.15 if read_before_write else 0.0),
            "test_informed": (0.15 if test_informed_navigation else 0.0),
            "verified": (0.10 if submit_after_test else 0.0),
            "searched": (0.05 if search_before_navigate else 0.0),
        }
        causal_score = sum(scores.values())
        causal_score = max(0.0, min(1.0, causal_score))

        # ── Understanding level classification ────────────────────────────────
        if causal_score >= 0.75:
            level = CausalUnderstandingLevel.DEEP
        elif causal_score >= 0.45:
            level = CausalUnderstandingLevel.PARTIAL
        elif causal_score >= 0.20:
            level = CausalUnderstandingLevel.SUPERFICIAL
        else:
            level = CausalUnderstandingLevel.RANDOM

        # ── Explanation ───────────────────────────────────────────────────────
        level_explanations = {
            CausalUnderstandingLevel.DEEP: (
                "Agent demonstrated genuine causal reasoning: read tests to understand expected "
                "behavior, traced the call chain to the root cause, made a targeted fix, and "
                "verified with tests before submitting."
            ),
            CausalUnderstandingLevel.PARTIAL: (
                "Agent showed partial causal understanding. Some chain links were traversed "
                "but the reasoning was incomplete — likely missed tracing deeper dependencies "
                "or skipped test verification."
            ),
            CausalUnderstandingLevel.SUPERFICIAL: (
                "Agent showed superficial reasoning. Actions did not follow a clear causal "
                "path from test → failure → root cause. Likely pattern-matched on filenames "
                "or guessed the fix location."
            ),
            CausalUnderstandingLevel.RANDOM: (
                "Agent showed no discernible causal reasoning. Actions appear random relative "
                "to the causal structure of the bug. This is the profile of pure trial-and-error."
            ),
        }
        explanation = level_explanations[level]

        # ── Recommendations ───────────────────────────────────────────────────
        recs = []
        if not any(n.was_read for n in expected_chain if n.role == "test"):
            recs.append("Always read the failing test first — it defines the expected behavior.")
        if not read_before_write:
            recs.append("Never write to a file before reading it — blind writes cause more bugs.")
        if not submit_after_test:
            recs.append("Run tests after every write to verify your fix is correct.")
        if not search_steps:
            recs.append("Use search_code to find function definitions before navigating blindly.")
        if actual_chain_coverage < 0.5:
            recs.append(
                "Explore more of the causal chain. The bug's root cause may be deeper than the first file."
            )
        if not recs:
            recs.append("Excellent reasoning! Maintain this systematic approach.")

        return CausalProbeReport(
            episode_id=episode_id,
            task=task,
            understanding_level=level,
            causal_score=causal_score,
            expected_chain=expected_chain,
            actual_chain_coverage=actual_chain_coverage,
            chain_order_score=chain_order_score,
            read_before_write=read_before_write,
            test_informed_navigation=test_informed_navigation,
            search_before_navigate=search_before_navigate,
            submit_after_test=submit_after_test,
            guessing_indicators=guessing_indicators,
            understanding_indicators=understanding_indicators,
            false_confidence_detected=false_confidence_detected,
            shortcut_learning_detected=shortcut_learning_detected,
            explanation=explanation,
            recommendations=recs,
        )