Spaces:
Sleeping
Sleeping
| """ | |
| CodeSensei — Test Runner. | |
| Runs individual test cases against a function and returns structured | |
| TestResult objects. Handles test isolation and error capture. | |
| """ | |
| from __future__ import annotations | |
| import traceback | |
| from typing import List, Tuple | |
| from env.models import TestResult | |
| from env.server.sandbox import run_function_with_tests, check_syntax | |
def run_tests(
    function_code: str,
    test_cases: List[dict],
    timeout: int = 5,
) -> Tuple[List[TestResult], int, int, str]:
    """Execute every test case against *function_code* and collect results.

    Each test case dict contains:
        - "name": str — human-readable test description.
        - "code": str — Python assert statement(s) calling the function.

    Args:
        function_code: Source code of the function under test.
        test_cases: Test case dicts as described above.
        timeout: Maximum execution time per sandbox invocation (seconds).

    Returns:
        Tuple of (test_results, passed_count, total_count, raw_error_output).
    """
    # Reject code that does not even parse before touching the sandbox.
    is_valid, syntax_error = check_syntax(function_code)
    if not is_valid:
        failed = [
            TestResult(test_name=tc["name"], passed=False, error_message=syntax_error)
            for tc in test_cases
        ]
        return failed, 0, len(test_cases), syntax_error

    total = len(test_cases)

    # Fast path: run the whole suite in one sandbox call for speed.
    batch_code = "\n".join(
        f"# Test: {tc['name']}\n{tc['code']}" for tc in test_cases
    )
    stdout, stderr, all_success = run_function_with_tests(
        function_code, batch_code, timeout
    )
    if all_success and "ALL_TESTS_PASSED" in stdout:
        passing = [TestResult(test_name=tc["name"], passed=True) for tc in test_cases]
        return passing, total, total, ""

    # Slow path: rerun one test at a time to pinpoint which ones fail.
    results: List[TestResult] = []
    raw_errors: List[str] = []
    total_passed = 0
    for tc in test_cases:
        stdout_i, stderr_i, success_i = run_function_with_tests(
            function_code, tc["code"], timeout
        )
        if success_i and "ALL_TESTS_PASSED" in stdout_i:
            total_passed += 1
            results.append(TestResult(test_name=tc["name"], passed=True))
            continue
        # Keep only the most informative line of the failure output.
        error_msg = _extract_error(stderr_i)
        results.append(
            TestResult(test_name=tc["name"], passed=False, error_message=error_msg)
        )
        raw_errors.append(f"[{tc['name']}] {error_msg}")
    return results, total_passed, total, "\n".join(raw_errors)
| def _extract_error(stderr: str) -> str: | |
| """Extract the most meaningful error line from stderr. | |
| Args: | |
| stderr: Raw stderr output from subprocess. | |
| Returns: | |
| Cleaned error message string (single line or short). | |
| """ | |
| if not stderr: | |
| return "Unknown error (no output)" | |
| lines = stderr.strip().split("\n") | |
| # Look for the last line that starts with a known error type | |
| for line in reversed(lines): | |
| stripped = line.strip() | |
| if any( | |
| stripped.startswith(err) | |
| for err in [ | |
| "AssertionError", | |
| "AssertionError", | |
| "TypeError", | |
| "ValueError", | |
| "NameError", | |
| "IndexError", | |
| "KeyError", | |
| "AttributeError", | |
| "ZeroDivisionError", | |
| "RecursionError", | |
| "RuntimeError", | |
| "StopIteration", | |
| "SyntaxError", | |
| "IndentationError", | |
| "AssertionError", | |
| ] | |
| ): | |
| return stripped | |
| # Fallback: last non-empty line, truncated | |
| for line in reversed(lines): | |
| if line.strip(): | |
| return line.strip()[:200] | |
| return "Unknown error" | |