Spaces:
Sleeping
Sleeping
| # test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests | |
| import subprocess | |
| import tempfile | |
| import os | |
| import json | |
| import ast | |
| import random | |
| import sys | |
| from typing import Tuple, List, Any, Optional | |
| from dataclasses import dataclass | |
# Bridge fine-grained RedTeam ids to canonical TestRunner families.
# This keeps evaluation stable even when bug generators use richer labels.
# Lookups use `.get(bug_id, bug_id)`, so an id absent from this table is
# treated as its own family name.
BUG_ID_CANONICAL_MAP = dict(
    # Easy-family bugs on `get_user`-style behavior.
    simple_typo="null_check",
    default_value="null_check",
    empty_return="null_check",
    string_index="off_by_one",
    # Medium arithmetic/control-flow aliases.
    loop_skip="off_by_one",
    sign_error="wrong_operator",
    swap_args="wrong_operator",
    uninitialised_var="null_check",
    # Hard numeric-safety aliases.
    division_by_zero_empty="division_by_zero",
    division_by_zero_zero="division_by_zero",
    float_precision="division_by_zero",
    abs_usage="division_by_zero",
    round_error="division_by_zero",
)
@dataclass
class TestRunner:
    """Score an agent's bug fix by running it against generated test cases.

    The agent's code plus a small test harness is written to a temporary
    ``.py`` file and executed by a separate Python interpreter
    (``sys.executable``). The harness prints one JSON line
    ``{"passed": P, "total": T}`` and the score is ``P / T``.

    Attributes:
        bug_id: RedTeam bug identifier. It is mapped through
            ``BUG_ID_CANONICAL_MAP`` to pick the test family; ids missing
            from that map are used unchanged.
        timeout_sec: Wall-clock limit, in seconds, for the child interpreter.
        max_memory_mb: Address-space cap (``RLIMIT_AS``) applied inside the
            child interpreter when the ``resource`` module is available.
        fuzz_rounds: Number of random test cases added per bug family.
    """

    # This module is test_runner.py, a file name pytest collects; keep pytest
    # from treating this Test*-named class as a test suite.
    __test__ = False

    bug_id: str
    timeout_sec: int = 5
    max_memory_mb: int = 256
    fuzz_rounds: int = 3  # number of random test cases per bug

    # Bootstrap run with ``python -c <launcher> <script_path> <limit_bytes>``.
    # It caps the child's address space (best effort; skipped where the
    # ``resource`` module is missing) and then runs the generated script as
    # __main__. Doing this in the child leaves the evaluator's own limits
    # untouched, and running the script through runpy keeps a
    # ``from __future__`` import in the agent code at the top of its file.
    _CHILD_LAUNCHER = (
        "import runpy, sys\n"
        "try:\n"
        "    import resource\n"
        "    _limit = int(sys.argv[2])\n"
        "    resource.setrlimit(resource.RLIMIT_AS, (_limit, _limit))\n"
        "except Exception:\n"
        "    pass\n"
        "runpy.run_path(sys.argv[1], run_name='__main__')\n"
    )

    def run_tests(self, fix_code: str) -> Tuple[float, str]:
        """Run the fixed and fuzzed tests for this bug against ``fix_code``.

        Args:
            fix_code: Python source submitted by the agent.

        Returns:
            ``(score, output_message)``: ``score`` is the proportion of passed
            tests (0.0-1.0); ``output_message`` is the harness's JSON report,
            or a human-readable reason when no valid report was produced.
        """
        # 1. Detect the function defined in the agent's code (dynamic)
        func_name = self._get_defined_function_name(fix_code)
        if not func_name:
            return 0.0, "No function definition found in agent code"
        # 2. Normalize bug id so broader RedTeam ids still hit meaningful tests.
        canonical_bug_id = self._canonical_bug_id()
        # 3. Build the cases once, so this process knows how many results the
        #    child must report, then embed them in the test script.
        all_cases = (
            self._get_test_cases(canonical_bug_id, func_name)
            + self._generate_fuzzed_cases(canonical_bug_id, func_name)
        )
        test_script = self._generate_test_script(
            fix_code, func_name, canonical_bug_id, cases=all_cases
        )
        memory_limit = self.max_memory_mb * 1024 * 1024

        tmp_path: Optional[str] = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as f:
                tmp_path = f.name
                f.write(test_script)
            result = subprocess.run(
                [sys.executable, "-c", self._CHILD_LAUNCHER, tmp_path, str(memory_limit)],
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",  # undecodable agent output must not abort scoring
                stdin=subprocess.DEVNULL,  # an input() call fails fast instead of hanging
                timeout=self.timeout_sec,
            )
            return self._parse_report(result, len(all_cases))
        except subprocess.TimeoutExpired:
            return 0.0, "Test execution timed out"
        except Exception as e:
            # Evaluation boundary: any harness failure scores 0 instead of
            # crashing the caller.
            return 0.0, f"Unexpected error: {str(e)}"
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    @staticmethod
    def _parse_report(
        result: subprocess.CompletedProcess, expected_total: int
    ) -> Tuple[float, str]:
        """Turn the child's captured output into ``(score, message)``.

        Only the last non-empty stdout line is parsed, because prints made by
        the agent's own code share stdout with the harness report. The report
        is rejected unless ``total`` equals ``expected_total`` (the number of
        cases this process generated), and ``passed`` is clamped into
        ``[0, total]``.
        """
        stdout = result.stdout.strip()
        last_line = stdout.splitlines()[-1].strip() if stdout else ""
        try:
            report = json.loads(last_line)
        except json.JSONDecodeError:
            report = None
        if not isinstance(report, dict):
            # No report: the script crashed, exited early, or printed junk last.
            # There is deliberately no "True in stdout" fallback -- any agent
            # code printing "True" would have earned full marks.
            details = stdout
            stderr = result.stderr.strip()
            if stderr:
                # Keep only the tail of stderr, where a traceback's cause is.
                details = f"{details}\n{stderr[-2000:]}".strip()
            return 0.0, details
        passed = report.get("passed", 0)
        total = report.get("total", 0)
        if not isinstance(passed, int) or not isinstance(total, int) or total != expected_total:
            return 0.0, f"Malformed test report (expected {expected_total} cases): {last_line}"
        if total <= 0:
            # No tests exist for this bug family (e.g. missing_lock).
            return 0.0, last_line
        passed = max(0, min(passed, total))
        return passed / total, last_line

    def _get_defined_function_name(self, code: str) -> Optional[str]:
        """Return the name of the function the tests should call.

        A function named ``fix`` wins wherever it appears; otherwise the first
        ``def`` yielded by ``ast.walk`` is used. ``async def`` functions are
        ignored.

        Returns:
            The function name, or ``None`` if ``code`` defines no function or
            cannot be parsed.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError: source containing NUL bytes on Python < 3.12.
            return None
        first_func = None
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if node.name == "fix":
                    return "fix"
                if first_func is None:
                    first_func = node.name
        return first_func  # fallback if no 'fix' function exists

    def _canonical_bug_id(self) -> str:
        """Return canonical bug family used by this test harness."""
        return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)

    def _generate_test_script(
        self,
        fix_code: str,
        func_name: str,
        canonical_bug_id: str,
        cases: Optional[List[Tuple[List[Any], Any]]] = None,
    ) -> str:
        """Build a script that runs the agent's function on every case and prints JSON.

        Args:
            fix_code: The agent's source, placed first in the script unchanged.
            func_name: Function under test; it comes from the AST, so it is a
                valid identifier and safe to splice into source.
            canonical_bug_id: Test family used when ``cases`` is not given.
            cases: Pre-built ``(args, expected)`` pairs. When ``None``, the
                fixed and fuzzed cases for ``canonical_bug_id`` are generated.

        Returns:
            Python source whose ``__main__`` run prints
            ``{"passed": P, "total": T}`` as its final stdout line.
        """
        if cases is None:
            cases = (
                self._get_test_cases(canonical_bug_id, func_name)
                + self._generate_fuzzed_cases(canonical_bug_id, func_name)
            )
        # Ship the cases as a JSON *string literal* that the child decodes.
        # Pasting json.dumps() output straight into Python source emits
        # null/true/false, which are NameErrors in Python (every null_check
        # case expects None, so that family could never pass).
        cases_literal = repr(json.dumps(cases))
        # Harness names carry a _harness_/_h_ prefix so they cannot collide
        # with the agent's names -- e.g. a function itself called run_tests,
        # json or result.
        harness = [
            "",
            "",
            "import json as _harness_json",
            "",
            "",
            "def _harness_run_tests():",
            f"    _h_cases = _harness_json.loads({cases_literal})",
            "    _h_passed = 0",
            "    for _h_args, _h_expected in _h_cases:",
            "        try:",
            f"            _h_result = {func_name}(*_h_args) if isinstance(_h_args, list) else {func_name}(_h_args)",
            "            if _h_result == _h_expected:",
            "                _h_passed += 1",
            "        except Exception:",
            "            pass",
            "    return {'passed': _h_passed, 'total': len(_h_cases)}",
            "",
            "",
            "if __name__ == '__main__':",
            "    print(_harness_json.dumps(_harness_run_tests()))",
            "",
        ]
        return "\n".join([fix_code, *harness])

    def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Return the fixed ``(arguments, expected_output)`` pairs for a bug family.

        ``arguments`` is a list that the harness unpacks into the call.
        ``func_name`` is accepted for interface symmetry but not used here.
        Unknown families (e.g. missing_lock, deadlock_order) get no cases.
        """
        if canonical_bug_id == "null_check":
            return [
                ([{"users": {"alice": "Alice"}, "id": "bob"}], None),  # missing key should not crash
                ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
            ]
        elif canonical_bug_id == "off_by_one":
            return [
                ([[1, 2, 3, 4]], 4),
                ([[]], 0),
            ]
        elif canonical_bug_id == "division_by_zero":
            return [
                ([[]], 0),
                ([[1, 2, 3]], 2.0),
            ]
        elif canonical_bug_id == "wrong_operator":
            return [
                ([5, 3], 8),
                ([-1, 1], 0),
            ]
        else:
            # For missing_lock, deadlock_order, etc., return empty list (scored 0.0)
            return []

    def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Generate ``fuzz_rounds`` random cases so fixes cannot memorise fixed answers.

        Only families with a simple reference behaviour are fuzzed; any other
        family yields an empty list. ``func_name`` is not used here.
        """
        cases = []
        if canonical_bug_id == "null_check":
            # Random users dictionary and random ids
            for _ in range(self.fuzz_rounds):
                users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
                # Pick existing or missing key
                if random.random() > 0.5:
                    key = random.choice(list(users.keys()))
                    expected = users[key]
                else:
                    key = "missing_" + str(random.randint(100, 999))
                    expected = None
                cases.append(([{"users": users, "id": key}], expected))
        elif canonical_bug_id == "off_by_one":
            # Expected output is the element count of range(length).
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                arr = list(range(length))
                cases.append(([arr], length))
        elif canonical_bug_id == "division_by_zero":
            # Expected output is the arithmetic mean, or 0 for an empty list.
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                data = [random.randint(-100, 100) for _ in range(length)]
                expected = sum(data) / length if length else 0
                cases.append(([data], expected))
        elif canonical_bug_id == "wrong_operator":
            # Expected output is the sum of the two arguments.
            for _ in range(self.fuzz_rounds):
                a = random.randint(-100, 100)
                b = random.randint(-100, 100)
                cases.append(([a, b], a + b))
        return cases