# test_runner.py – Full production version with continuous scoring, dynamic function detection, and randomised tests
import subprocess
import tempfile
import os
import json
import ast
import random
import sys
from typing import Tuple, List, Any, Optional
from dataclasses import dataclass
# Translate fine-grained RedTeam bug ids into the canonical families that
# TestRunner knows how to test. Ids that are absent from this table are used
# unchanged, so evaluation stays stable when bug generators emit richer labels.
BUG_ID_CANONICAL_MAP = {
    # Easy family: bugs around `get_user`-style behaviour.
    "simple_typo": "null_check",
    "default_value": "null_check",
    "empty_return": "null_check",
    "string_index": "off_by_one",

    # Medium family: arithmetic / control-flow aliases.
    "loop_skip": "off_by_one",
    "sign_error": "wrong_operator",
    "swap_args": "wrong_operator",
    "uninitialised_var": "null_check",

    # Hard family: numeric-safety aliases.
    "division_by_zero_empty": "division_by_zero",
    "division_by_zero_zero": "division_by_zero",
    "float_precision": "division_by_zero",
    "abs_usage": "division_by_zero",
    "round_error": "division_by_zero",
}
@dataclass
class TestRunner:
    """Score an agent-supplied fix by running it against fixed and fuzzed cases.

    The agent's code is written to a temporary script together with a small
    harness, executed in a separate Python process, and scored as the
    proportion of test cases whose result equals the expected value.

    Attributes:
        bug_id: Bug identifier; mapped to a canonical family before tests
            are selected.
        timeout_sec: Wall-clock limit for the child process, in seconds.
        max_memory_mb: Address-space cap applied to the child process where
            the ``resource`` module is available (best effort).
        fuzz_rounds: Number of random test cases generated per bug.
    """

    bug_id: str
    timeout_sec: int = 5
    max_memory_mb: int = 256
    fuzz_rounds: int = 3  # number of random test cases per bug

    # The class name starts with "Test"; tell pytest not to collect it.
    __test__ = False

    def run_tests(self, fix_code: str) -> Tuple[float, str]:
        """Run the generated test script against ``fix_code``.

        Args:
            fix_code: Python source submitted by the agent.

        Returns:
            ``(score, output_message)`` where ``score`` is the proportion of
            passed tests (0.0-1.0) and ``output_message`` is the script's
            output or a description of why no score could be computed.
        """
        # 1. Detect the function defined in the agent's code (dynamic).
        func_name = self._get_defined_function_name(fix_code)
        if not func_name:
            return 0.0, "No function definition found in agent code"

        # 2. Normalize bug id so broader RedTeam ids still hit meaningful tests.
        canonical_bug_id = self._canonical_bug_id()

        # 3. Generate the test script (includes fixed + fuzzed test cases).
        test_script = self._generate_test_script(fix_code, func_name, canonical_bug_id)

        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as handle:
                # Record the path before writing so a failed write is still
                # cleaned up by the ``finally`` below.
                tmp_path = handle.name
                handle.write(test_script)
            return self._execute_script(tmp_path)
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass  # best-effort cleanup of the temporary script

    def _execute_script(self, script_path: str) -> Tuple[float, str]:
        """Run ``script_path`` in a child interpreter and score its output."""
        try:
            result = subprocess.run(
                [sys.executable, script_path],
                capture_output=True,
                text=True,
                timeout=self.timeout_sec,
                encoding="utf-8",
                # Apply the memory cap inside the child only; calling
                # setrlimit here would permanently limit this process too.
                preexec_fn=self._memory_limiter(),
            )
            return self._score_from_output(result.stdout, result.stderr)
        except subprocess.TimeoutExpired:
            return 0.0, "Test execution timed out"
        except Exception as e:
            return 0.0, f"Unexpected error: {str(e)}"

    def _memory_limiter(self) -> Optional[Any]:
        """Return a ``preexec_fn`` that caps the child's address space.

        Returns ``None`` when the ``resource`` module cannot be imported, in
        which case the child runs without a memory limit.
        """
        try:
            import resource
        except ImportError:
            return None

        limit_bytes = self.max_memory_mb * 1024 * 1024

        def _apply_limit() -> None:
            try:
                resource.setrlimit(resource.RLIMIT_AS, (limit_bytes, limit_bytes))
            except (ValueError, OSError):
                pass  # limit not supported/permitted here; run unrestricted

        return _apply_limit

    @staticmethod
    def _score_from_output(stdout: str, stderr: str = "") -> Tuple[float, str]:
        """Turn the child's output into ``(score, message)``.

        Only the last non-empty stdout line is parsed as the JSON summary,
        because the harness prints it last and anything before it was printed
        by the agent's own code. Output without a valid summary scores 0.0;
        there is no substring-based fallback, since that would let the agent
        earn a full score just by printing text.
        """
        output = stdout.strip()
        last_line = output.splitlines()[-1] if output else ""
        try:
            summary = json.loads(last_line)
        except json.JSONDecodeError:
            summary = None

        if not isinstance(summary, dict):
            # Surface stderr (e.g. a traceback) when stdout has nothing useful.
            return 0.0, output or stderr.strip() or "Test script produced no output"

        passed = summary.get("passed", 0)
        total = summary.get("total", 1)
        counts_ok = all(
            isinstance(value, (int, float)) and not isinstance(value, bool)
            for value in (passed, total)
        )
        if not counts_ok or total <= 0:
            return 0.0, output
        return max(0.0, min(1.0, passed / total)), output

    def _get_defined_function_name(self, code: str) -> Optional[str]:
        """Extract the target function name from the code.

        Prefers a function named ``fix``; otherwise returns the first
        function definition encountered. Returns ``None`` when the code
        cannot be parsed or defines no function.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError: e.g. source containing null bytes on older Pythons.
            return None

        names = [
            node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)
        ]
        if "fix" in names:
            return "fix"
        return names[0] if names else None

    def _canonical_bug_id(self) -> str:
        """Return canonical bug family used by this test harness."""
        return BUG_ID_CANONICAL_MAP.get(self.bug_id, self.bug_id)

    def _generate_test_script(self, fix_code: str, func_name: str, canonical_bug_id: str) -> str:
        """Generate a script that runs fixed + fuzzed cases and prints JSON.

        The script consists of the agent's code followed by a ``run_tests``
        function that calls ``func_name`` on each case and counts results
        equal to the expected value; exceptions count as failures. When run
        as ``__main__`` it prints ``{"passed": ..., "total": ...}`` as JSON.
        """
        all_cases = (
            self._get_test_cases(canonical_bug_id, func_name)
            + self._generate_fuzzed_cases(canonical_bug_id, func_name)
        )
        script_lines = [
            fix_code,
            "",
            "import json",
            "",
            "def run_tests():",
            # repr() yields a valid Python literal; json.dumps would emit
            # `null` for None, which is a NameError in the generated script.
            f"    test_cases = {all_cases!r}",
            "    passed = 0",
            "    total = len(test_cases)",
            "    for args, expected in test_cases:",
            "        try:",
            f"            result = {func_name}(*args) if isinstance(args, list) else {func_name}(args)",
            "            if result == expected:",
            "                passed += 1",
            "        except Exception:",
            "            pass",
            "    return {'passed': passed, 'total': total}",
            "",
            "if __name__ == '__main__':",
            "    result = run_tests()",
            "    print(json.dumps(result))",
        ]
        return "\n".join(script_lines)

    def _get_test_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Return the fixed ``(arguments, expected_output)`` cases for a family.

        ``func_name`` is accepted for signature parity with the other
        generators and is not used. Families without fixed cases (e.g.
        missing_lock, deadlock_order) yield an empty list.
        """
        fixed_cases = {
            "null_check": [
                ([{"users": {"alice": "Alice"}, "id": "bob"}], None),  # missing key should not crash
                ([{"users": {"alice": "Alice"}, "id": "alice"}], "Alice"),
            ],
            "off_by_one": [
                ([[1, 2, 3, 4]], 4),
                ([[]], 0),
            ],
            "division_by_zero": [
                ([[]], 0),
                ([[1, 2, 3]], 2.0),
            ],
            "wrong_operator": [
                ([5, 3], 8),
                ([-1, 1], 0),
            ],
        }
        return fixed_cases.get(canonical_bug_id, [])

    def _generate_fuzzed_cases(self, canonical_bug_id: str, func_name: str) -> List[Tuple[List[Any], Any]]:
        """Generate ``fuzz_rounds`` random cases to prevent memorisation.

        Only families where meaningful fuzzing is possible produce cases;
        any other family yields an empty list. ``func_name`` is not used.
        """
        cases: List[Tuple[List[Any], Any]] = []
        if canonical_bug_id == "null_check":
            # Random users dictionary, looked up with an existing or missing id.
            for _ in range(self.fuzz_rounds):
                users = {f"user_{i}": f"name_{i}" for i in range(random.randint(1, 5))}
                if random.random() > 0.5:
                    key = random.choice(list(users.keys()))
                    expected = users[key]
                else:
                    key = "missing_" + str(random.randint(100, 999))
                    expected = None
                cases.append(([{"users": users, "id": key}], expected))
        elif canonical_bug_id == "off_by_one":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                cases.append(([list(range(length))], length))
        elif canonical_bug_id == "division_by_zero":
            for _ in range(self.fuzz_rounds):
                length = random.randint(0, 10)
                data = [random.randint(-100, 100) for _ in range(length)]
                # An empty input is expected to produce 0 rather than raise.
                expected = sum(data) / length if length else 0
                cases.append(([data], expected))
        elif canonical_bug_id == "wrong_operator":
            for _ in range(self.fuzz_rounds):
                a = random.randint(-100, 100)
                b = random.randint(-100, 100)
                cases.append(([a, b], a + b))
        return cases
|