"""
CodeSensei — Test Runner.
Runs individual test cases against a function and returns structured
TestResult objects. Handles test isolation and error capture.
"""
from __future__ import annotations
import traceback
from typing import List, Tuple
from env.models import TestResult
from env.server.sandbox import run_function_with_tests, check_syntax
def run_tests(
    function_code: str,
    test_cases: List[dict],
    timeout: int = 5,
) -> Tuple[List[TestResult], int, int, str]:
    """Execute every test case against a function and collect the outcomes.

    Each test case dict has:
        - "name": str — test description
        - "code": str — Python assert statement(s) calling the function

    Args:
        function_code: The Python function source code.
        test_cases: List of test case dicts.
        timeout: Max execution time per test batch.

    Returns:
        Tuple of (test_results, passed_count, total_count, raw_error_output).
    """
    # Reject syntactically-broken code up front; every test is marked failed.
    ok, syntax_err = check_syntax(function_code)
    if not ok:
        failed = [
            TestResult(test_name=tc["name"], passed=False, error_message=syntax_err)
            for tc in test_cases
        ]
        return failed, 0, len(test_cases), syntax_err

    total = len(test_cases)

    # Fast path: run every test in one batch — one subprocess instead of N.
    batch_code = "\n".join(
        f"# Test: {tc['name']}\n{tc['code']}" for tc in test_cases
    )
    batch_stdout, _batch_stderr, batch_ok = run_function_with_tests(
        function_code, batch_code, timeout
    )
    if batch_ok and "ALL_TESTS_PASSED" in batch_stdout:
        passed_all = [
            TestResult(test_name=tc["name"], passed=True) for tc in test_cases
        ]
        return passed_all, total, total, ""

    # Slow path: the batch failed somewhere, so re-run each test alone to
    # pinpoint exactly which ones fail and why.
    outcomes: List[TestResult] = []
    passed_count = 0
    failure_log: List[str] = []
    for tc in test_cases:
        out_i, err_i, ok_i = run_function_with_tests(
            function_code, tc["code"], timeout
        )
        if ok_i and "ALL_TESTS_PASSED" in out_i:
            outcomes.append(TestResult(test_name=tc["name"], passed=True))
            passed_count += 1
            continue
        # Distill the subprocess stderr down to a single meaningful line.
        msg = _extract_error(err_i)
        outcomes.append(
            TestResult(test_name=tc["name"], passed=False, error_message=msg)
        )
        failure_log.append(f"[{tc['name']}] {msg}")
    return outcomes, passed_count, total, "\n".join(failure_log)
def _extract_error(stderr: str) -> str:
"""Extract the most meaningful error line from stderr.
Args:
stderr: Raw stderr output from subprocess.
Returns:
Cleaned error message string (single line or short).
"""
if not stderr:
return "Unknown error (no output)"
lines = stderr.strip().split("\n")
# Look for the last line that starts with a known error type
for line in reversed(lines):
stripped = line.strip()
if any(
stripped.startswith(err)
for err in [
"AssertionError",
"AssertionError",
"TypeError",
"ValueError",
"NameError",
"IndexError",
"KeyError",
"AttributeError",
"ZeroDivisionError",
"RecursionError",
"RuntimeError",
"StopIteration",
"SyntaxError",
"IndentationError",
"AssertionError",
]
):
return stripped
# Fallback: last non-empty line, truncated
for line in reversed(lines):
if line.strip():
return line.strip()[:200]
return "Unknown error"
|