Spaces:
Build error
Build error
Create agents/tester_agent.py
Browse files- agents/tester_agent.py +251 -0
agents/tester_agent.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
import tempfile
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
import ast
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
class TesterAgent:
    """
    Agent responsible for testing generated code.

    Picks a list of input/expected test cases for a piece of generated
    code, runs them against the code's first function in a separate
    Python process, and reports per-test results and aggregate metrics.
    """

    # Despite the "Test" prefix this is not a pytest test class; keep
    # pytest from trying to collect it when the module is imported in tests.
    __test__ = False

    def __init__(self):
        self.test_cases = self._load_default_test_cases()

    def _load_default_test_cases(self) -> Dict:
        """Load default test cases for common functions.

        Returns:
            Mapping of a function keyword to a list of
            ``{"input": ..., "expected": ...}`` dictionaries.
        """
        return {
            "reverse_string": [
                {"input": "hello", "expected": "olleh"},
                {"input": "world", "expected": "dlrow"},
                {"input": "", "expected": ""},
                {"input": "a", "expected": "a"},
            ],
            "factorial": [
                {"input": 5, "expected": 120},
                {"input": 0, "expected": 1},
                {"input": 1, "expected": 1},
            ],
            "fibonacci": [
                {"input": 5, "expected": 5},
                {"input": 1, "expected": 1},
                {"input": 0, "expected": 0},
            ],
        }

    def extract_function_name(self, code: str) -> str:
        """
        Extract the main function name from code.

        Args:
            code: Python code

        Returns:
            Name of the first ``def`` found while walking the parsed code,
            or "main_function" when the code cannot be parsed or contains
            no function definition.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, TypeError, RecursionError):
            # Unparseable input is expected here (the code is generated);
            # fall back to the placeholder name instead of raising.
            return "main_function"

        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                return node.name
        return "main_function"

    def generate_test_cases(self, code: str, prompt: str) -> List[Dict]:
        """
        Generate test cases based on code and prompt.

        Args:
            code: Generated code (currently not inspected; kept for
                interface compatibility)
            prompt: Original user prompt

        Returns:
            List of test cases
        """
        prompt_lower = prompt.lower()

        # Check if we have predefined test cases whose key appears verbatim
        for key, cases in self.test_cases.items():
            if key in prompt_lower:
                return cases

        # "reverse_string" rarely appears literally in a prompt, so also
        # accept the bare keyword.
        if "reverse" in prompt_lower:
            return self.test_cases["reverse_string"]

        # Default test cases
        return [
            {"input": "test_input", "expected": "expected_output"},
            {"input": 123, "expected": 321},
        ]

    def _build_test_script(self, code: str, function_name: str,
                           test_cases: List[Dict]) -> str:
        """Return the source of a standalone script that runs the tests.

        The script consists of ``code`` followed by a generated
        ``run_all_tests`` function and a ``__main__`` section that prints
        the results and errors as JSON between marker lines.

        Args:
            code: Code under test, placed first so that any
                ``from __future__`` import in it stays at the top.
            function_name: Name of the function to call for each case.
            test_cases: Dictionaries with "input" and "expected" keys.

        Returns:
            Complete script source.
        """
        lines = [
            code,
            "",
            "",
            # Aliased so the harness does not rebind a ``json`` name the
            # tested code may define for itself.
            "import json as _json",
            "import sys",
            "",
            "def run_all_tests():",
            "    results = []",
            "    errors = []",
        ]

        for number, test_case in enumerate(test_cases, start=1):
            input_val = repr(test_case["input"])
            expected = repr(test_case["expected"])
            # Embed the message prefix as a proper string literal: the
            # expected value's repr may itself contain quotes or braces,
            # which would break a hand-built quoted f-string.
            failure_prefix = repr(
                f"Test {number} failed: expected {expected}, got "
            )
            lines.extend([
                f"    # Test case {number}",
                "    try:",
                f"        result = {function_name}({input_val})",
                f"        passed = result == {expected}",
                "        results.append({",
                f"            'test_id': {number},",
                f"            'input': {input_val},",
                f"            'expected': {expected},",
                "            'actual': result,",
                "            'passed': passed",
                "        })",
                "        if not passed:",
                f"            errors.append({failure_prefix} + str(result))",
                "    except Exception as e:",
                "        results.append({",
                f"            'test_id': {number},",
                f"            'input': {input_val},",
                f"            'expected': {expected},",
                "            'actual': str(e),",
                "            'passed': False,",
                "            'error': str(e)",
                "        })",
                f"        errors.append(f'Test {number} error: {{e}}')",
            ])

        lines.extend([
            "    return results, errors",
            "",
            "if __name__ == '__main__':",
            "    results, errors = run_all_tests()",
            "    print('TEST_RESULTS_START')",
            # default=repr keeps the report usable when the tested function
            # returns something JSON cannot encode (e.g. a set).
            "    print(_json.dumps(results, default=repr))",
            "    print('TEST_RESULTS_END')",
            "    print('ERRORS_START')",
            "    print(_json.dumps(errors, default=repr))",
            "    print('ERRORS_END')",
            "",
        ])
        return "\n".join(lines)

    @staticmethod
    def _extract_section(output: str, start_marker: str, end_marker: str):
        """Return the text between the last start marker and its end marker.

        The last occurrence of ``start_marker`` is used because the harness
        prints its markers after anything the tested code printed.

        Returns:
            The stripped text, or None when either marker is missing.
        """
        start = output.rfind(start_marker)
        if start == -1:
            return None
        start += len(start_marker)
        end = output.find(end_marker, start)
        if end == -1:
            return None
        return output[start:end].strip()

    def run_tests(self, code: str, test_cases: List[Dict],
                  timeout: float = 10) -> Dict[str, Any]:
        """
        Run tests on the generated code.

        NOTE: ``code`` is executed as-is in a child Python process with no
        sandboxing beyond the timeout; only feed it code you are prepared
        to run on this machine.

        Args:
            code: Generated code
            test_cases: List of test cases
            timeout: Seconds to wait for the child process (default 10)

        Returns:
            Test results: on success a dict with "status", "results",
            "errors" and "metrics"; otherwise a dict with "status" set to
            "error" plus "error", "results" and "errors". Failures are
            reported in the returned dict rather than raised.
        """
        temp_file_path = None
        try:
            function_name = self.extract_function_name(code)
            script = self._build_test_script(code, function_name, test_cases)

            # Create a temporary test file (UTF-8 is Python's default
            # source encoding, independent of the platform locale)
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                temp_file_path = f.name
                f.write(script)

            # Execute the test file
            result = subprocess.run(
                [sys.executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=timeout,
            )

            # Parse results
            output = result.stdout
            results_json = self._extract_section(
                output, 'TEST_RESULTS_START', 'TEST_RESULTS_END'
            )
            errors_json = self._extract_section(
                output, 'ERRORS_START', 'ERRORS_END'
            )

            if results_json is None or errors_json is None:
                # The harness never reached its report section, e.g. the
                # tested code failed at import time.
                return {
                    "status": "error",
                    "error": f"Execution failed: {result.stderr}",
                    "results": [],
                    "errors": [result.stderr]
                }

            results = json.loads(results_json)
            errors = json.loads(errors_json)

            # Calculate metrics
            total_tests = len(results)
            passed_tests = sum(1 for r in results if r.get('passed', False))
            pass_rate = (passed_tests / total_tests * 100) if total_tests > 0 else 0

            return {
                "status": "success",
                "results": results,
                "errors": errors,
                "metrics": {
                    "total_tests": total_tests,
                    "passed_tests": passed_tests,
                    "pass_rate": pass_rate,
                    "has_errors": len(errors) > 0
                }
            }

        except subprocess.TimeoutExpired:
            return {
                "status": "error",
                "error": f"Test execution timed out after {timeout} seconds",
                "results": [],
                "errors": ["Timeout error"]
            }
        except Exception as e:
            # Top-level boundary: callers expect an error dict, not a raise.
            return {
                "status": "error",
                "error": str(e),
                "results": [],
                "errors": [str(e)]
            }
        finally:
            # Clean up on every path, including timeouts and failures.
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    pass

    def test_code(self, code: str, prompt: str) -> Dict[str, Any]:
        """
        Complete testing workflow.

        Args:
            code: Generated code
            prompt: Original user prompt

        Returns:
            Dict with the chosen "test_cases", the "test_results" from
            run_tests, and the "function_name" that was tested.
        """
        # Generate test cases
        test_cases = self.generate_test_cases(code, prompt)

        # Run tests
        test_results = self.run_tests(code, test_cases)

        return {
            "test_cases": test_cases,
            "test_results": test_results,
            "function_name": self.extract_function_name(code)
        }
|