""" CodeSensei — Test Runner. Runs individual test cases against a function and returns structured TestResult objects. Handles test isolation and error capture. """ from __future__ import annotations import traceback from typing import List, Tuple from env.models import TestResult from env.server.sandbox import run_function_with_tests, check_syntax def run_tests( function_code: str, test_cases: List[dict], timeout: int = 5, ) -> Tuple[List[TestResult], int, int, str]: """Run all test cases against a function and return results. Each test case dict has: - "name": str — test description - "code": str — Python assert statement(s) calling the function Args: function_code: The Python function source code. test_cases: List of test case dicts. timeout: Max execution time per test batch. Returns: Tuple of (test_results, passed_count, total_count, raw_error_output). """ # First check syntax is_valid, syntax_error = check_syntax(function_code) if not is_valid: return ( [ TestResult(test_name=tc["name"], passed=False, error_message=syntax_error) for tc in test_cases ], 0, len(test_cases), syntax_error, ) results: List[TestResult] = [] total_passed = 0 total = len(test_cases) raw_errors = [] # Run all tests together first for speed combined_test_code = "\n".join( f"# Test: {tc['name']}\n{tc['code']}" for tc in test_cases ) stdout, stderr, all_success = run_function_with_tests( function_code, combined_test_code, timeout ) if all_success and "ALL_TESTS_PASSED" in stdout: # All tests passed in batch — fast path for tc in test_cases: results.append(TestResult(test_name=tc["name"], passed=True)) return results, total, total, "" # If batch failed, run tests individually to identify which ones fail for tc in test_cases: stdout_i, stderr_i, success_i = run_function_with_tests( function_code, tc["code"], timeout ) if success_i and "ALL_TESTS_PASSED" in stdout_i: results.append(TestResult(test_name=tc["name"], passed=True)) total_passed += 1 else: # Extract the meaningful error error_msg = _extract_error(stderr_i) results.append( TestResult(test_name=tc["name"], passed=False, error_message=error_msg) ) raw_errors.append(f"[{tc['name']}] {error_msg}") return results, total_passed, total, "\n".join(raw_errors) def _extract_error(stderr: str) -> str: """Extract the most meaningful error line from stderr. Args: stderr: Raw stderr output from subprocess. Returns: Cleaned error message string (single line or short). """ if not stderr: return "Unknown error (no output)" lines = stderr.strip().split("\n") # Look for the last line that starts with a known error type for line in reversed(lines): stripped = line.strip() if any( stripped.startswith(err) for err in [ "AssertionError", "AssertionError", "TypeError", "ValueError", "NameError", "IndexError", "KeyError", "AttributeError", "ZeroDivisionError", "RecursionError", "RuntimeError", "StopIteration", "SyntaxError", "IndentationError", "AssertionError", ] ): return stripped # Fallback: last non-empty line, truncated for line in reversed(lines): if line.strip(): return line.strip()[:200] return "Unknown error"