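# NOTE: the next three lines appear to be page-header residue from where this
# file was copied (avatar caption, commit title, commit id); not module code.
#   Navya-Sree's picture
#   Create agents/tester_agent.py
#   166441f verified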
import subprocess
import tempfile
import os
import json
from typing import Dict, Any, List
import ast
import sys
class TesterAgent:
    """
    Agent responsible for testing generated code.

    Creates and runs tests to verify functionality: test cases are chosen
    from the user's prompt, a small harness is appended to the generated
    code, the combined script is executed in a separate Python process, and
    the per-test results are returned together with summary metrics.
    """

    # Name returned when no function definition can be found in the code.
    _DEFAULT_FUNCTION_NAME = "main_function"

    # Fallback (keyword in prompt -> key in ``self.test_cases``) pairs, tried
    # in order after the exact-key lookup has failed.
    _PROMPT_KEYWORDS = (
        ("reverse", "reverse_string"),
        ("factorial", "factorial"),
        ("fibonacci", "fibonacci"),
    )

    # Static part of the script appended after the code under test. It relies
    # on ``_harness_cases`` and ``_harness_call`` defined by the generated
    # preamble (see ``_build_test_script``). Each marker is printed on its own
    # line and each JSON document on a single line, which is what
    # ``_parse_output`` relies on.
    _HARNESS_BODY = '''

def _harness_run():
    results = []
    errors = []
    for test_id, (input_val, expected) in enumerate(_harness_cases, start=1):
        try:
            # Pass a deep copy so a function that mutates its argument cannot
            # change the input reported in the results.
            result = _harness_call(_harness_copy.deepcopy(input_val))
            passed = bool(result == expected)
        except Exception as e:
            results.append({
                'test_id': test_id,
                'input': input_val,
                'expected': expected,
                'actual': str(e),
                'passed': False,
                'error': str(e),
            })
            errors.append(f'Test {test_id} error: {e}')
            continue
        results.append({
            'test_id': test_id,
            'input': input_val,
            'expected': expected,
            'actual': result,
            'passed': passed,
        })
        if not passed:
            errors.append(
                f'Test {test_id} failed: expected {expected!r}, got {result!r}'
            )
    return results, errors


if __name__ == '__main__':
    _results, _errors = _harness_run()
    print('TEST_RESULTS_START')
    # default=repr keeps the report usable when the function under test
    # returns something JSON cannot encode (a set, bytes, a custom object).
    print(_harness_json.dumps(_results, default=repr))
    print('TEST_RESULTS_END')
    print('ERRORS_START')
    print(_harness_json.dumps(_errors, default=repr))
    print('ERRORS_END')
'''

    def __init__(self, timeout: int = 10):
        """
        Set up the agent with the built-in test cases.

        Args:
            timeout: Maximum number of seconds the test subprocess may run
                before it is aborted.
        """
        self.timeout = timeout
        self.test_cases = self._load_default_test_cases()

    def _load_default_test_cases(self) -> Dict[str, List[Dict[str, Any]]]:
        """Load default test cases for common functions."""
        return {
            "reverse_string": [
                {"input": "hello", "expected": "olleh"},
                {"input": "world", "expected": "dlrow"},
                {"input": "", "expected": ""},
                {"input": "a", "expected": "a"},
            ],
            "factorial": [
                {"input": 5, "expected": 120},
                {"input": 0, "expected": 1},
                {"input": 1, "expected": 1},
            ],
            "fibonacci": [
                {"input": 5, "expected": 5},
                {"input": 1, "expected": 1},
                {"input": 0, "expected": 0},
            ],
        }

    def extract_function_name(self, code: str) -> str:
        """
        Extract the main function name from code.

        Args:
            code: Python code

        Returns:
            Name of the first ``def`` found by ``ast.walk``, or
            "main_function" when the code cannot be parsed or defines no
            function.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, TypeError, RecursionError, MemoryError):
            # Unparseable input is expected here (the code is generated), so
            # fall back to the placeholder name instead of raising.
            return self._DEFAULT_FUNCTION_NAME
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                return node.name
        return self._DEFAULT_FUNCTION_NAME

    def generate_test_cases(self, code: str, prompt: str) -> List[Dict]:
        """
        Generate test cases based on code and prompt.

        Args:
            code: Generated code (currently not inspected)
            prompt: Original user prompt

        Returns:
            List of test cases, each a dict with "input" and "expected"
        """
        prompt_text = (prompt or "").lower()

        # Exact test-case keys (e.g. "reverse_string") take precedence.
        for key, cases in self.test_cases.items():
            if key in prompt_text:
                return cases

        # Otherwise fall back to looser keyword matches.
        for keyword, key in self._PROMPT_KEYWORDS:
            if keyword in prompt_text and key in self.test_cases:
                return self.test_cases[key]

        # Default placeholder test cases when nothing in the prompt matches.
        return [
            {"input": "test_input", "expected": "expected_output"},
            {"input": 123, "expected": 321},
        ]

    def _build_test_script(
        self, code: str, function_name: str, test_cases: List[Dict]
    ) -> str:
        """Return the code under test followed by the generated test harness."""
        # The cases are embedded as one Python literal, so every input and
        # expected value must have a repr() that evaluates back to the value.
        cases_literal = repr(
            [(case["input"], case["expected"]) for case in test_cases]
        )
        # The harness imports come after the user's code so that a leading
        # "from __future__ import ..." in that code stays the first statement.
        preamble = (
            "import copy as _harness_copy\n"
            "import json as _harness_json\n"
            f"_harness_cases = {cases_literal}\n"
            "def _harness_call(value):\n"
            f"    return {function_name}(value)\n"
        )
        return code + "\n\n" + preamble + self._HARNESS_BODY

    @staticmethod
    def _extract_section(lines: List[str], start_marker: str, end_marker: str):
        """Return the text between the last start-marker line and the next end-marker line, or None."""
        try:
            # Search from the end: the harness prints its markers after any
            # output produced by the code under test.
            start = len(lines) - 1 - lines[::-1].index(start_marker)
            end = lines.index(end_marker, start + 1)
        except ValueError:
            return None
        return "\n".join(lines[start + 1:end])

    def _parse_output(self, stdout: str, stderr: str) -> Dict[str, Any]:
        """Turn the subprocess output into the result dict returned by run_tests."""
        lines = stdout.splitlines()
        results_json = self._extract_section(
            lines, 'TEST_RESULTS_START', 'TEST_RESULTS_END'
        )
        errors_json = self._extract_section(lines, 'ERRORS_START', 'ERRORS_END')
        if results_json is None or errors_json is None:
            # The harness never reported, e.g. the code failed to import.
            return {
                "status": "error",
                "error": f"Execution failed: {stderr}",
                "results": [],
                "errors": [stderr],
            }

        results = json.loads(results_json)
        errors = json.loads(errors_json)

        # Calculate metrics
        total_tests = len(results)
        passed_tests = sum(1 for entry in results if entry.get('passed', False))
        pass_rate = (passed_tests / total_tests * 100) if total_tests > 0 else 0
        return {
            "status": "success",
            "results": results,
            "errors": errors,
            "metrics": {
                "total_tests": total_tests,
                "passed_tests": passed_tests,
                "pass_rate": pass_rate,
                "has_errors": len(errors) > 0,
            },
        }

    def run_tests(self, code: str, test_cases: List[Dict]) -> Dict[str, Any]:
        """
        Run tests on the generated code.

        The code and a generated harness are written to a temporary ``.py``
        file and executed with the current interpreter. The temporary file is
        removed afterwards, including when the run times out or fails.

        Args:
            code: Generated code
            test_cases: List of test cases, each with "input" and "expected"

        Returns:
            Test results. On success: "status", "results", "errors" and
            "metrics". On failure: "status" set to "error" plus "error",
            "results" (empty) and "errors". This method does not raise.
        """
        temp_file_path = None
        try:
            function_name = self.extract_function_name(code)
            script = self._build_test_script(code, function_name, test_cases)

            # Python reads source files as UTF-8 by default, so write the
            # script as UTF-8 regardless of the locale encoding.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                temp_file_path = f.name
                f.write(script)

            # Execute the test file
            completed = subprocess.run(
                [sys.executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=self.timeout,
            )
            return self._parse_output(completed.stdout, completed.stderr)
        except subprocess.TimeoutExpired:
            return {
                "status": "error",
                "error": f"Test execution timed out after {self.timeout} seconds",
                "results": [],
                "errors": ["Timeout error"],
            }
        except Exception as e:
            # Boundary handler: callers expect a result dict, never a raise.
            return {
                "status": "error",
                "error": str(e),
                "results": [],
                "errors": [str(e)],
            }
        finally:
            # Clean up on every path, not only after a successful run.
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    # Best effort: a leftover temp file must not mask the
                    # test outcome being returned.
                    pass

    def test_code(self, code: str, prompt: str) -> Dict[str, Any]:
        """
        Complete testing workflow.

        Args:
            code: Generated code
            prompt: Original user prompt

        Returns:
            Complete test results: the "test_cases" used, the "test_results"
            from run_tests, and the "function_name" that was called
        """
        # Generate test cases
        test_cases = self.generate_test_cases(code, prompt)
        # Run tests
        test_results = self.run_tests(code, test_cases)
        return {
            "test_cases": test_cases,
            "test_results": test_results,
            "function_name": self.extract_function_name(code),
        }