File size: 3,964 Bytes
c47c81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
CodeSensei — Test Runner.

Runs individual test cases against a function and returns structured
TestResult objects. Handles test isolation and error capture.
"""

from __future__ import annotations

import traceback
from typing import List, Tuple

from env.models import TestResult
from env.server.sandbox import run_function_with_tests, check_syntax


def run_tests(
    function_code: str,
    test_cases: List[dict],
    timeout: int = 5,
) -> Tuple[List[TestResult], int, int, str]:
    """Run every test case against ``function_code`` and report the outcome.

    Strategy:

    1. Syntax-check the function. On failure every test is reported as
       failed with the syntax error as its message; nothing is executed.
    2. Run all tests in one batch. If the sandbox reports success and its
       stdout contains ``ALL_TESTS_PASSED``, every test is marked passed.
    3. Otherwise re-run each test on its own to find out which ones fail,
       and collect one error line per failing test.

    Each test case dict has:
        - "name": str — test description
        - "code": str — Python assert statement(s) calling the function

    Args:
        function_code: The Python function source code.
        test_cases: List of test case dicts.
        timeout: Timeout value forwarded to every sandbox invocation (the
            batch run and each individual re-run).

    Returns:
        Tuple of (test_results, passed_count, total_count, raw_error_output).
        ``raw_error_output`` is the syntax error text, an empty string when
        the batch passed, or one ``[name] message`` line per failing test.
    """
    total = len(test_cases)

    # A function that does not even parse cannot pass anything; skip the sandbox.
    is_valid, syntax_error = check_syntax(function_code)
    if not is_valid:
        failures = [
            TestResult(test_name=tc["name"], passed=False, error_message=syntax_error)
            for tc in test_cases
        ]
        return failures, 0, total, syntax_error

    # Build the batch script. The test name is only a label inside a comment,
    # so collapse any line breaks in it: otherwise the tail of a multi-line
    # name would land on its own line and be executed as code.
    batch_parts = []
    for tc in test_cases:
        label = " ".join(str(tc["name"]).splitlines())
        batch_parts.append(f"# Test: {label}\n{tc['code']}")
    combined_test_code = "\n".join(batch_parts)

    # Fast path: one sandbox run for the whole suite.
    stdout, _, all_success = run_function_with_tests(
        function_code, combined_test_code, timeout
    )
    if all_success and "ALL_TESTS_PASSED" in stdout:
        passed = [TestResult(test_name=tc["name"], passed=True) for tc in test_cases]
        return passed, total, total, ""

    # Slow path: the batch failed somewhere, so isolate each test to learn
    # exactly which ones fail and why.
    results: List[TestResult] = []
    raw_errors: List[str] = []
    total_passed = 0
    for tc in test_cases:
        stdout_i, stderr_i, success_i = run_function_with_tests(
            function_code, tc["code"], timeout
        )

        if success_i and "ALL_TESTS_PASSED" in stdout_i:
            results.append(TestResult(test_name=tc["name"], passed=True))
            total_passed += 1
            continue

        # Reduce the raw stderr to a single meaningful line for the report.
        error_msg = _extract_error(stderr_i)
        results.append(
            TestResult(test_name=tc["name"], passed=False, error_message=error_msg)
        )
        raw_errors.append(f"[{tc['name']}] {error_msg}")

    return results, total_passed, total, "\n".join(raw_errors)


def _extract_error(stderr: str) -> str:
    """Extract the most meaningful error line from stderr.

    Args:
        stderr: Raw stderr output from subprocess.

    Returns:
        Cleaned error message string (single line or short).
    """
    if not stderr:
        return "Unknown error (no output)"

    lines = stderr.strip().split("\n")

    # Look for the last line that starts with a known error type
    for line in reversed(lines):
        stripped = line.strip()
        if any(
            stripped.startswith(err)
            for err in [
                "AssertionError",
                "AssertionError",
                "TypeError",
                "ValueError",
                "NameError",
                "IndexError",
                "KeyError",
                "AttributeError",
                "ZeroDivisionError",
                "RecursionError",
                "RuntimeError",
                "StopIteration",
                "SyntaxError",
                "IndentationError",
                "AssertionError",
            ]
        ):
            return stripped

    # Fallback: last non-empty line, truncated
    for line in reversed(lines):
        if line.strip():
            return line.strip()[:200]

    return "Unknown error"