File size: 19,219 Bytes
dc71cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
"""
tests/test_phase4_reflection.py
────────────────────────────────
Unit tests for Phase 4: tools, failure categoriser, trajectory logger,
and the reflection agent loop (mocked LLM, no real API calls).

Run with: pytest tests/test_phase4_reflection.py -v
"""
from __future__ import annotations

import json
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest


# ── AgentTools ────────────────────────────────────────────────────────────────

class TestAgentTools:
    """Exercise the ``AgentTools`` file helpers and ``ToolResult`` prompt rendering."""

    @staticmethod
    def _tools_for(root):
        """Return an ``AgentTools`` built from *root*, importing the class lazily."""
        from agent.tools import AgentTools
        return AgentTools(root)

    def test_read_file_success(self, tmp_path):
        """Reading an existing file succeeds and the output contains its text."""
        source = tmp_path / "foo.py"
        source.write_text("x = 1\ny = 2\n")
        outcome = self._tools_for(tmp_path).read_file("foo.py")
        assert outcome.success
        assert "x = 1" in outcome.output

    def test_read_file_not_found(self, tmp_path):
        """A missing file gives a failed result whose error mentions "not found"."""
        outcome = self._tools_for(tmp_path).read_file("nonexistent.py")
        assert not outcome.success
        assert "not found" in outcome.error.lower()

    def test_read_file_path_traversal_rejected(self, tmp_path):
        """A ``..`` path gives a failed result whose error mentions "traversal"."""
        outcome = self._tools_for(tmp_path).read_file("../../etc/passwd")
        assert not outcome.success
        assert "traversal" in outcome.error.lower()

    def test_read_file_truncation(self, tmp_path):
        """Reading 300 lines with ``max_lines=10`` reports truncation in the output."""
        numbered = [f"line {i}" for i in range(300)]
        (tmp_path / "big.py").write_text("\n".join(numbered))
        outcome = self._tools_for(tmp_path).read_file("big.py", max_lines=10)
        assert outcome.success
        assert "truncated" in outcome.output

    def test_write_patch_success(self, tmp_path):
        """A diff with headers is accepted and ``_agent_patch.diff`` is created."""
        unified_diff = (
            "--- a/foo.py\n"
            "+++ b/foo.py\n"
            "@@ -1 +1 @@\n"
            "-old\n"
            "+new\n"
        )
        outcome = self._tools_for(tmp_path).write_patch(unified_diff)
        assert outcome.success
        assert (tmp_path / "_agent_patch.diff").exists()

    def test_write_patch_empty_rejected(self, tmp_path):
        """An empty patch is refused with an error containing "Empty"."""
        outcome = self._tools_for(tmp_path).write_patch("")
        assert not outcome.success
        assert "Empty" in outcome.error

    def test_write_patch_invalid_format_rejected(self, tmp_path):
        """Text without diff headers is refused."""
        outcome = self._tools_for(tmp_path).write_patch(
            "just some text without diff header"
        )
        assert not outcome.success

    def test_list_files(self, tmp_path):
        """Listing ``**/*.py`` reports the Python files and no ``__pycache__`` entry."""
        for filename, body in (("a.py", "x=1"), ("b.py", "y=2")):
            (tmp_path / filename).write_text(body)
        # NOTE(review): this directory is created empty, so the final assertion
        # may hold regardless of any filtering in list_files; confirm whether
        # a file should be placed inside it to exercise the exclusion.
        (tmp_path / "__pycache__").mkdir()
        listing = self._tools_for(tmp_path).list_files("**/*.py")
        assert listing.success
        for expected in ("a.py", "b.py"):
            assert expected in listing.output
        assert "__pycache__" not in listing.output

    def test_tool_result_to_prompt_str(self):
        """A successful result renders the tool name, "SUCCESS" and its output."""
        from agent.tools import ToolResult
        rendered = ToolResult("read_file", True, "x = 1\n").to_prompt_str()
        for fragment in ("read_file", "SUCCESS", "x = 1"):
            assert fragment in rendered

    def test_tool_result_error_in_prompt(self):
        """A failed result renders "ERROR" together with the error message."""
        from agent.tools import ToolResult
        rendered = ToolResult("run_tests", False, "", "Timeout after 60s").to_prompt_str()
        for fragment in ("ERROR", "Timeout"):
            assert fragment in rendered


# ── Failure Categoriser ───────────────────────────────────────────────────────

class TestFailureCategoriser:
    """Check the category ``categorise_failure`` assigns to representative outputs."""

    def _categorise(self, stdout, apply_ok=True, ftp=None, ptp=None, attempt=1, prev=None):
        """Call ``categorise_failure`` with short aliases for its keyword arguments.

        ``ftp`` / ``ptp`` are the fail-to-pass and pass-to-pass result mappings;
        a falsy value is replaced by an empty dict.
        """
        from agent.failure_categoriser import categorise_failure
        call_kwargs = {
            "test_stdout": stdout,
            "patch_apply_success": apply_ok,
            "fail_to_pass_results": ftp if ftp else {},
            "pass_to_pass_results": ptp if ptp else {},
            "attempt_num": attempt,
            "previous_categories": prev,
        }
        return categorise_failure(**call_kwargs)

    def test_success(self):
        """All targeted tests passing is categorised as "success"."""
        fail_to_pass = {"t::test_x": True}
        pass_to_pass = {"t::test_y": True}
        verdict = self._categorise(
            "1 passed", apply_ok=True, ftp=fail_to_pass, ptp=pass_to_pass
        )
        assert verdict == "success"

    def test_patch_apply_failure_is_syntax_error(self):
        """A patch that fails to apply is reported as "syntax_error"."""
        assert self._categorise("", apply_ok=False) == "syntax_error"

    def test_syntax_error_in_output(self):
        """A ``SyntaxError`` line in the output maps to "syntax_error"."""
        stdout = "SyntaxError: invalid syntax (foo.py, line 5)"
        assert self._categorise(stdout) == "syntax_error"

    def test_import_error(self):
        """A ``ModuleNotFoundError`` line maps to "import_error"."""
        stdout = "ModuleNotFoundError: No module named 'nonexistent'"
        assert self._categorise(stdout) == "import_error"

    def test_hallucinated_api_attribute_error(self):
        """An ``AttributeError`` line maps to "hallucinated_api"."""
        stdout = "AttributeError: 'QuerySet' object has no attribute 'bulk_filer'"
        assert self._categorise(stdout) == "hallucinated_api"

    def test_hallucinated_api_name_error(self):
        """A ``NameError`` line maps to "hallucinated_api"."""
        stdout = "NameError: name 'nonexistent_func' is not defined"
        assert self._categorise(stdout) == "hallucinated_api"

    def test_type_error(self):
        """A ``TypeError`` line maps to "type_error"."""
        stdout = "TypeError: unsupported operand type(s) for +"
        assert self._categorise(stdout) == "type_error"

    def test_assertion_error(self):
        """An ``AssertionError`` line maps to "assertion_error"."""
        stdout = "AssertionError: expected True but got False"
        assert self._categorise(stdout) == "assertion_error"

    def test_incomplete_patch(self):
        """Mixed fail-to-pass results map to "incomplete_patch"."""
        # One target test passed, the other still fails.
        partial_results = {"t::a": True, "t::b": False}
        verdict = self._categorise("2 failed", apply_ok=True, ftp=partial_results, ptp={})
        assert verdict == "incomplete_patch"

    def test_unknown_fallback(self):
        """Output matching none of the cases above falls back to "unknown"."""
        stdout = "some unexpected output with no pattern"
        assert self._categorise(stdout) == "unknown"

    def test_extract_first_error_context(self):
        """The extracted context contains the failing line or its error."""
        from agent.failure_categoriser import extract_first_error_context
        pytest_output = textwrap.dedent("""
            tests/test_foo.py::test_bar FAILED
            AssertionError: expected 1, got 2
            
            tests/test_foo.py::test_baz PASSED
        """)
        snippet = extract_first_error_context(pytest_output)
        assert "FAILED" in snippet or "AssertionError" in snippet


# ── Trajectory Logger ─────────────────────────────────────────────────────────

class TestTrajectoryLogger:
    """Cover logging, reloading, statistics and export in ``TrajectoryLogger``."""

    @staticmethod
    def _new_logger(tmp_path):
        """Return a ``TrajectoryLogger`` writing to ``traj.jsonl`` under *tmp_path*."""
        from agent.trajectory_logger import TrajectoryLogger
        return TrajectoryLogger(tmp_path / "traj.jsonl")

    @staticmethod
    def _entry(instance_id, resolved, failure_category, **overrides):
        """Build a ``TrajectoryEntry``; *overrides* replace the shared field values."""
        from agent.trajectory_logger import TrajectoryEntry
        fields = {
            "instance_id": instance_id,
            "repo": "r",
            "attempt": 1,
            "patch": "",
            "test_stdout": "",
            "fail_to_pass_results": {},
            "pass_to_pass_results": {},
            "resolved": resolved,
            "failure_category": failure_category,
            "elapsed_seconds": 1.0,
        }
        fields.update(overrides)
        return TrajectoryEntry(**fields)

    def test_log_and_load(self, tmp_path):
        """A logged entry is returned by ``load_all`` with its fields intact."""
        traj = self._new_logger(tmp_path)
        traj.log(
            self._entry(
                "test__repo-1",
                False,
                "assertion_error",
                repo="test/repo",
                patch="--- a/foo.py\n+++ b/foo.py\n",
                test_stdout="1 failed",
                fail_to_pass_results={"t::test_x": False},
                elapsed_seconds=5.2,
            )
        )
        records = traj.load_all()
        assert len(records) == 1
        first = records[0]
        assert first.instance_id == "test__repo-1"
        assert first.failure_category == "assertion_error"

    def test_multiple_entries(self, tmp_path):
        """Five logged entries are counted and all five can be reloaded."""
        traj = self._new_logger(tmp_path)
        for idx in range(5):
            is_even = idx % 2 == 0
            traj.log(
                self._entry(
                    f"inst-{idx}",
                    is_even,
                    "success" if is_even else "wrong_file_edit",
                    repo="test/repo",
                )
            )
        assert traj.total_logged == 5
        assert len(traj.load_all()) == 5

    def test_stats(self, tmp_path):
        """``stats`` reports the total, resolved count and resolved rate."""
        traj = self._new_logger(tmp_path)
        for idx in range(4):
            solved = idx < 2
            traj.log(
                self._entry(
                    f"inst-{idx}",
                    solved,
                    "success" if solved else "assertion_error",
                )
            )
        summary = traj.stats()
        assert summary["total"] == 4
        assert summary["resolved"] == 2
        assert abs(summary["resolved_rate"] - 0.5) < 1e-6

    def test_export_for_finetuning(self, tmp_path):
        """One entry exports as one JSON line with system/user/assistant keys."""
        traj = self._new_logger(tmp_path)
        traj.log(
            self._entry(
                "inst-1",
                True,
                "success",
                patch="--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-bug\n+fix\n",
                problem_statement="Fix the null pointer bug",
            )
        )
        destination = tmp_path / "ft_data.jsonl"
        exported = traj.export_for_finetuning(destination)
        assert exported == 1
        record = json.loads(destination.read_text().strip())
        for key in ("system", "user", "assistant"):
            assert key in record

    def test_filter_by_category(self, tmp_path):
        """Only entries in the requested categories are counted in the export."""
        traj = self._new_logger(tmp_path)
        for category in ["success", "assertion_error", "syntax_error", "unknown"]:
            traj.log(
                self._entry(
                    category,
                    category == "success",
                    category,
                    patch="--- a/f.py\n+++ b/f.py\n",
                    problem_statement="test issue",
                )
            )
        destination = tmp_path / "filtered.jsonl"
        wanted = ["assertion_error", "syntax_error"]
        exported = traj.export_for_finetuning(destination, filter_categories=wanted)
        assert exported == 2

    def test_instruction_pair_format(self, tmp_path):
        """``to_instruction_pair`` carries the issue, category, patch and attempt."""
        sample = self._entry(
            "test-1",
            False,
            "assertion_error",
            attempt=2,
            patch="--- a/f.py\n+++ b/f.py\n@@ -1 +1 @@\n-x\n+y\n",
            test_stdout="AssertionError: expected 1, got 2",
            fail_to_pass_results={"t::test_x": False},
            elapsed_seconds=3.0,
            problem_statement="Fix the assertion in the filter method",
            localised_files=["models/query.py"],
        )
        pair = sample.to_instruction_pair()
        user_turn = pair["user"]
        assert "Fix the assertion" in user_turn
        assert "assertion_error" in user_turn
        assert pair["assistant"] == sample.patch
        assert pair["metadata"]["attempt"] == 2


# ── Reflection Agent (mocked LLM) ─────────────────────────────────────────────

class TestReflectionAgent:
    """Tests for the agent loop — LLM calls are mocked."""

    def _make_agent(self, tmp_path, trajectory_logger=None):
        """Build a ``ReflectionAgent`` with no sandbox or localisation pipeline.

        ``tmp_path`` is kept in the signature for existing call sites but is
        not used when constructing the agent.
        """
        from agent.reflection_agent import ReflectionAgent
        return ReflectionAgent(
            model="gpt-4o",
            max_attempts=3,
            sandbox=None,
            localisation_pipeline=None,
            trajectory_logger=trajectory_logger,
        )

    def _mock_llm_patch(self, monkeypatch, patch_text: str, tokens: int = 100):
        """Mock _call_llm to return a fixed patch without API calls."""
        import agent.reflection_agent as ra
        usage = {
            "total_tokens": tokens,
            "prompt_tokens": 80,
            "completion_tokens": 20,
        }
        monkeypatch.setattr(ra, "_call_llm", lambda *args, **kwargs: (patch_text, usage))

    @staticmethod
    def _retry_state(workspace_dir, *, resolved, current_attempt):
        """Return a minimal ``AgentState`` for exercising ``should_retry``."""
        from agent.reflection_agent import AgentState
        return AgentState(
            instance_id="t", repo="r", problem_statement="p",
            base_commit="a", fail_to_pass=[], pass_to_pass=[],
            workspace_dir=workspace_dir, resolved=resolved,
            current_attempt=current_attempt,
        )

    def test_agent_state_initialisation(self, tmp_path):
        """A fresh state starts at attempt 0, unresolved, with zero tokens."""
        from agent.reflection_agent import AgentState
        state = AgentState(
            instance_id="test-1",
            repo="test/repo",
            problem_statement="Fix bug",
            base_commit="abc123",
            fail_to_pass=["tests::test_x"],
            pass_to_pass=[],
            workspace_dir=tmp_path,
        )
        assert state.current_attempt == 0
        assert state.resolved is False
        assert state.total_tokens == 0

    def test_should_retry_when_not_resolved(self, tmp_path):
        """Unresolved with attempts remaining -> "retry"."""
        from agent.reflection_agent import should_retry
        state = self._retry_state(tmp_path, resolved=False, current_attempt=1)
        assert should_retry(state, max_attempts=3) == "retry"

    def test_should_done_when_resolved(self, tmp_path):
        """Resolved -> "done" even though attempts remain."""
        from agent.reflection_agent import should_retry
        state = self._retry_state(tmp_path, resolved=True, current_attempt=1)
        assert should_retry(state, max_attempts=3) == "done"

    def test_should_done_when_max_attempts_reached(self, tmp_path):
        """Unresolved but at the attempt limit -> "done"."""
        from agent.reflection_agent import should_retry
        state = self._retry_state(tmp_path, resolved=False, current_attempt=3)
        assert should_retry(state, max_attempts=3) == "done"

    def test_node_generate_patch_increments_attempt(self, tmp_path, monkeypatch):
        """Generating a patch bumps the attempt counter and stores the patch."""
        from agent.reflection_agent import AgentState, node_generate_patch
        self._mock_llm_patch(monkeypatch, "--- a/foo.py\n+++ b/foo.py\n@@ -1 +1 @@\n-x\n+y\n")
        state = AgentState(
            instance_id="t", repo="r", problem_statement="fix the bug please",
            base_commit="abc", fail_to_pass=[], pass_to_pass=[],
            workspace_dir=tmp_path,
        )
        state = node_generate_patch(state)
        assert state.current_attempt == 1
        assert "--- a/foo.py" in state.last_patch

    def test_node_generate_patch_uses_reflection_on_retry(self, tmp_path, monkeypatch):
        """On a retry the prompt sent to the LLM mentions the previous attempt."""
        from agent.reflection_agent import AgentState, node_generate_patch
        prompts_seen = []

        def mock_call_llm(user_prompt, *args, **kwargs):
            prompts_seen.append(user_prompt)
            return ("--- a/f.py\n+++ b/f.py\n", {"total_tokens": 50, "prompt_tokens": 40, "completion_tokens": 10})

        import agent.reflection_agent as ra
        monkeypatch.setattr(ra, "_call_llm", mock_call_llm)

        state = AgentState(
            instance_id="t", repo="r",
            problem_statement="fix the long detailed issue description here",
            base_commit="abc", fail_to_pass=[], pass_to_pass=[],
            workspace_dir=tmp_path,
            current_attempt=1,                         # simulate already one attempt
            last_test_stdout="AssertionError: expected 1",
            last_failure_category="assertion_error",
            last_patch="--- a/wrong.py\n+++ b/wrong.py\n",
            attempts=[{"attempt_num": 1}],
        )
        state = node_generate_patch(state)
        # Fail with a readable message instead of an IndexError below if the
        # mocked LLM was never invoked.
        assert prompts_seen, "node_generate_patch did not call _call_llm"
        # Should use reflection prompt (contains "Previous Attempt")
        assert "Previous Attempt" in prompts_seen[-1]

    def test_agent_logs_trajectories(self, tmp_path, monkeypatch):
        """A run that resolves immediately still logs at least one trajectory."""
        from agent.trajectory_logger import TrajectoryLogger
        import agent.reflection_agent as ra

        traj_logger = TrajectoryLogger(tmp_path / "traj.jsonl")

        # Mock node_apply_and_test to mark as resolved immediately
        def mock_apply(state, sandbox=None):
            state.resolved = True
            state.last_test_stdout = "1 passed"
            state.last_failure_category = "success"
            state.attempts.append({
                "attempt_num": state.current_attempt,
                "patch": state.last_patch,
                "test_stdout": "1 passed",
                "fail_to_pass_results": {},
                "pass_to_pass_results": {},
                "resolved": True,
                "failure_category": "success",
            })
            return state

        monkeypatch.setattr(ra, "node_apply_and_test", mock_apply)
        monkeypatch.setattr(
            ra, "_call_llm",
            lambda *a, **kw: (
                "--- a/f.py\n+++ b/f.py\n",
                {"total_tokens": 10, "prompt_tokens": 8, "completion_tokens": 2},
            ),
        )

        agent = self._make_agent(tmp_path, trajectory_logger=traj_logger)
        state = agent.run(
            instance_id="test-1",
            repo="test/repo",
            problem_statement="fix the bug",
            base_commit="abc",
            fail_to_pass=[],
            pass_to_pass=[],
            workspace_dir=tmp_path,
        )
        assert state.resolved
        assert traj_logger.total_logged >= 1

    def test_strip_code_fences(self):
        """Markdown code fences are removed while the diff text is kept."""
        from agent.reflection_agent import _strip_code_fences
        raw = "```diff\n--- a/f.py\n+++ b/f.py\n```"
        cleaned = _strip_code_fences(raw)
        assert "```" not in cleaned
        assert "--- a/f.py" in cleaned

    def test_build_file_context(self):
        """The built context names every file and includes file contents."""
        from agent.reflection_agent import _build_file_context
        contents = {
            "a.py": "def foo(): pass",
            "b.py": "class Bar: pass",
        }
        ctx = _build_file_context(contents)
        assert "a.py" in ctx
        assert "b.py" in ctx
        assert "def foo" in ctx