|
|
| import subprocess
|
| import tempfile
|
| import os
|
| import json
|
| import ast
|
| import random
|
| import sys
|
| from typing import Tuple, List, Any, Optional
|
| from dataclasses import dataclass
|
|
|
|
|
|
|
# Bug identifiers that share a test suite, mapped to that suite's canonical
# family name. TestRunner._canonical_bug_id falls back to the id itself for
# anything not listed here.
BUG_ID_CANONICAL_MAP = dict(
    simple_typo="null_check",
    default_value="null_check",
    empty_return="null_check",
    string_index="off_by_one",
    loop_skip="off_by_one",
    sign_error="wrong_operator",
    swap_args="wrong_operator",
    uninitialised_var="null_check",
    division_by_zero_empty="division_by_zero",
    division_by_zero_zero="division_by_zero",
    float_precision="division_by_zero",
    abs_usage="division_by_zero",
    round_error="division_by_zero",
)
|
|
|
@dataclass
class TestRunner:
    """Grade an agent's bug fix by running it against canned and fuzzed cases.

    The agent code and a generated harness are written to a temporary script
    that runs in a separate Python interpreter. The harness prints a one-line
    JSON report ``{"passed": int, "total": int}``, which is converted into a
    score in ``[0.0, 1.0]``.

    Attributes:
        bug_id: Bug identifier; aliases are resolved through
            ``BUG_ID_CANONICAL_MAP`` before test cases are chosen.
        timeout_sec: Wall-clock limit for the child interpreter.
        max_memory_mb: Address-space cap (``RLIMIT_AS``) applied to the child
            process only, where the ``resource`` module is available.
        fuzz_rounds: Number of randomly generated cases added per run.
    """

    # The class name starts with "Test"; keep pytest from collecting it.
    __test__ = False

    bug_id: str
    timeout_sec: int = 5
    max_memory_mb: int = 256
    fuzz_rounds: int = 3

    def run_tests(self, fix_code: str) -> Tuple[float, str]:
        """Run the agent's fix in a subprocess and score it.

        Args:
            fix_code: Python source submitted by the agent. A top-level
                function named ``fix`` is tested if present, otherwise the
                first top-level function.

        Returns:
            ``(score, output_message)`` where ``score`` is the proportion of
            passed test cases (0.0-1.0). Timeouts, unexpected errors and
            output without a valid report score 0.0.
        """
        func_name = self._get_defined_function_name(fix_code)
        if not func_name:
            return 0.0, "No function definition found in agent code"

        canonical_bug_id = self._canonical_bug_id()
        test_script = self._generate_test_script(fix_code, func_name, canonical_bug_id)

        tmp_path: Optional[str] = None
        try:
            # Created inside the try so the finally clause also removes a
            # partially written file.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as f:
                tmp_path = f.name
                f.write(test_script)

            result = subprocess.run(
                [sys.executable, tmp_path],
                capture_output=True,
                text=True,
                encoding="utf-8",
                # Agent code may emit undecodable bytes; don't abort grading.
                errors="replace",
                timeout=self.timeout_sec,
                # The memory cap is applied in the child only, so the grading
                # process itself is never constrained.
                preexec_fn=self._memory_limiter(),
            )
            return self._score_output(result.stdout, result.stderr)
        except subprocess.TimeoutExpired:
            return 0.0, "Test execution timed out"
        except Exception as e:
            return 0.0, f"Unexpected error: {str(e)}"
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def _memory_limiter(self) -> Optional[Any]:
        """Return a zero-argument ``preexec_fn`` that caps the child's memory.

        Returns None when ``max_memory_mb`` is not positive or the ``resource``
        module is unavailable (e.g. on Windows), in which case the child runs
        without an explicit cap.
        """
        if self.max_memory_mb <= 0:
            return None
        try:
            import resource
        except ImportError:
            return None
        limit_bytes = self.max_memory_mb * 1024 * 1024

        def _apply_limit() -> None:
            # Best effort: if the platform rejects the limit, the test still runs.
            try:
                resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
            except (ValueError, OSError):
                pass

        return _apply_limit

    def _score_output(self, stdout: str, stderr: str = "") -> Tuple[float, str]:
        """Convert the child's output into ``(score, message)``.

        Only the last non-empty stdout line is parsed as the JSON report. The
        agent's code runs first in the same interpreter and may print text of
        its own. A valid report is a dict with integer ``passed``/``total``, and
        the resulting score is clamped to ``[0.0, 1.0]``. Anything else scores
        0.0. There is deliberately no "stdout contains True" fallback, because
        any printed ``True`` would trigger it.
        """
        report_lines = [line for line in stdout.splitlines() if line.strip()]
        if report_lines:
            try:
                data = json.loads(report_lines[-1])
            except json.JSONDecodeError:
                data = None
            if isinstance(data, dict):
                passed = data.get("passed", 0)
                total = data.get("total", 1)
                if isinstance(passed, int) and isinstance(total, int):
                    score = passed / total if total > 0 else 0.0
                    return min(max(score, 0.0), 1.0), stdout.strip()

        message = stdout.strip()
        if stderr.strip():
            message = f"{message}\n{stderr.strip()}".strip()
        return 0.0, message or "No test report found in output"

    def _get_defined_function_name(self, code: str) -> Optional[str]:
        """Extract the name of the function to test from the agent's code.

        Only top-level functions are considered, because the harness calls the
        function by name at module scope. Methods and nested functions are not
        reachable that way. A function named ``fix`` wins; otherwise the first
        top-level function is returned. Returns None for code that does not
        parse or defines no top-level function.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):  # ValueError: e.g. null bytes in source
            return None
        top_level = [node.name for node in tree.body if isinstance(node, ast.FunctionDef)]
        if "fix" in top_level:
            return "fix"
        return top_level[0] if top_level else None

    def _canonical_bug_id(self) -> str:
        """Return canonical bug family used by this test harness."""
        return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)

    def _generate_test_script(self, fix_code: str, func_name: str, canonical_bug_id: str) -> str:
        """Build a standalone script that tests ``func_name`` and prints a JSON report.

        The agent's code comes first so that its ``from __future__`` imports
        remain legal. Test cases are embedded as a JSON *string literal* and
        decoded at run time. Pasting raw ``json.dumps`` output into Python
        source breaks on ``null``/``true``/``false``; for example, every
        ``null_check`` case expects ``None``.
        """
        all_cases = self._get_test_cases(canonical_bug_id, func_name) + self._generate_fuzzed_cases(
            canonical_bug_id, func_name
        )
        cases_literal = repr(json.dumps(all_cases))

        lines = [
            fix_code,
            "",
            "import json",
            "import math as _harness_math",
            "",
            "_HARNESS_PLAIN_TYPES = (type(None), bool, int, float, str, list, dict)",
            "",
            "def _harness_matches(result, expected):",
            "    # Only compare plain builtin values so a custom __eq__ cannot fake a pass.",
            "    if type(result) not in _HARNESS_PLAIN_TYPES:",
            "        return False",
            "    if result == expected:",
            "        return True",
            "    # Tolerate float rounding differences (e.g. a mean computed another way).",
            "    if type(result) in (int, float) and type(expected) in (int, float):",
            "        return _harness_math.isclose(result, expected, rel_tol=1e-9, abs_tol=1e-9)",
            "    return False",
            "",
            "def run_tests():",
            f"    test_cases = json.loads({cases_literal})",
            "    passed = 0",
            "    total = len(test_cases)",
            "    for args, expected in test_cases:",
            "        try:",
            f"            result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)",
            "            if _harness_matches(result, expected):",
            "                passed += 1",
            "        except Exception:",
            "            pass",
            "    return {'passed': passed, 'total': total}",
            "",
            "if __name__ == '__main__':",
            "    result = run_tests()",
            "    print(json.dumps(result))",
        ]
        return "\n".join(lines)

    def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Return the fixed ``(arguments, expected_output)`` cases for a bug family.

        ``func_name`` is accepted for interface compatibility but is not used.
        Unknown families get no fixed cases.
        """
        if canonical_bug_id == "null_check":
            return [
                ([{"users": {"alice": "Alice"}, "id": "bob"}], None),
                ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
            ]
        elif canonical_bug_id == "off_by_one":
            return [
                ([[1, 2, 3, 4]], 4),
                ([[]], 0),
            ]
        elif canonical_bug_id == "division_by_zero":
            return [
                ([[]], 0),
                ([[1, 2, 3]], 2.0),
            ]
        elif canonical_bug_id == "wrong_operator":
            return [
                ([5, 3], 8),
                ([-1, 1], 0),
            ]
        else:
            return []

    def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Generate ``fuzz_rounds`` random cases to prevent memorisation.

        Uses the module-level ``random`` generator. ``func_name`` is unused.
        Unknown families get no fuzzed cases.
        """
        cases = []
        if canonical_bug_id == "null_check":
            for _ in range(self.fuzz_rounds):
                users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
                # Roughly half the cases look up an existing user, half a missing one.
                if random.random() > 0.5:
                    key = random.choice(list(users.keys()))
                    expected = users[key]
                else:
                    key = "missing_" + str(random.randint(100, 999))
                    expected = None
                cases.append(([{"users": users, "id": key}], expected))
        elif canonical_bug_id == "off_by_one":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                arr = list(range(length))
                cases.append(([arr], length))
        elif canonical_bug_id == "division_by_zero":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                data = [random.randint(-100, 100) for _ in range(length)]
                expected = sum(data) / length if length else 0
                cases.append(([data], expected))
        elif canonical_bug_id == "wrong_operator":
            for _ in range(self.fuzz_rounds):
                a = random.randint(-100, 100)
                b = random.randint(-100, 100)
                cases.append(([a, b], a + b))
        return cases
|
|
|