Spaces:
Build error
Build error
| import subprocess | |
| import tempfile | |
| import os | |
| import json | |
| from typing import Dict, Any, List | |
| import ast | |
| import sys | |
class TesterAgent:
    """
    Agent responsible for testing generated code.
    Creates and runs tests to verify functionality.

    NOTE(security): ``run_tests`` executes the supplied code unmodified in a
    child Python interpreter. Only a timeout is applied; this is not a
    sandbox, so callers must not feed it code they would not run themselves.
    """

    # Name reported when no function definition can be found in the code.
    DEFAULT_FUNCTION_NAME = "main_function"

    # Markers the generated harness prints around its JSON payloads so the
    # payloads can be located even when the tested code prints to stdout.
    _RESULTS_START = "TEST_RESULTS_START"
    _RESULTS_END = "TEST_RESULTS_END"
    _ERRORS_START = "ERRORS_START"
    _ERRORS_END = "ERRORS_END"

    def __init__(self):
        self.test_cases = self._load_default_test_cases()

    def _load_default_test_cases(self) -> Dict[str, List[Dict[str, Any]]]:
        """Load default test cases for common functions."""
        return {
            "reverse_string": [
                {"input": "hello", "expected": "olleh"},
                {"input": "world", "expected": "dlrow"},
                {"input": "", "expected": ""},
                {"input": "a", "expected": "a"},
            ],
            "factorial": [
                {"input": 5, "expected": 120},
                {"input": 0, "expected": 1},
                {"input": 1, "expected": 1},
            ],
            "fibonacci": [
                {"input": 5, "expected": 5},
                {"input": 1, "expected": 1},
                {"input": 0, "expected": 0},
            ],
        }

    def extract_function_name(self, code: str) -> str:
        """
        Extract the main function name from code.

        Args:
            code: Python code

        Returns:
            Name of the first ``def`` that ``ast.walk`` yields, or
            ``"main_function"`` when the code cannot be parsed or defines
            no function.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError, TypeError, RecursionError, MemoryError):
            # Unparseable input is expected here (the code is generated);
            # anything else, e.g. KeyboardInterrupt, must propagate.
            return self.DEFAULT_FUNCTION_NAME

        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                return node.name
        return self.DEFAULT_FUNCTION_NAME

    def generate_test_cases(self, code: str, prompt: str) -> List[Dict]:
        """
        Generate test cases based on code and prompt.

        Args:
            code: Generated code (currently not inspected; kept for
                interface compatibility)
            prompt: Original user prompt

        Returns:
            List of test cases. The dictionaries are copies, so callers may
            modify them without altering the agent's defaults.
        """
        prompt_lower = prompt.lower()

        # A predefined suite wins when its key appears verbatim in the prompt.
        for key, cases in self.test_cases.items():
            if key in prompt_lower:
                return [dict(case) for case in cases]

        # "reverse_string" rarely appears literally in a prompt, so the bare
        # word "reverse" is accepted as an alias for that suite.
        if "reverse" in prompt_lower and "reverse_string" in self.test_cases:
            return [dict(case) for case in self.test_cases["reverse_string"]]

        # Default test cases
        return [
            {"input": "test_input", "expected": "expected_output"},
            {"input": 123, "expected": 321},
        ]

    def _build_test_script(
        self, code: str, function_name: str, test_cases: List[Dict]
    ) -> str:
        """Return the source of a script that runs ``test_cases`` against ``code``."""
        # Each input literal is emitted twice: one copy is passed to the
        # function, the other is reported, so a function that mutates its
        # argument cannot change what the report shows as the input.
        case_rows = [
            f"    ({case['input']!r}, {case['input']!r}, {case['expected']!r}),"
            for case in test_cases
        ]
        harness_lines = [
            "",
            "",
            # Imported under a private alias *after* the code under test so
            # that a missing ``import json`` in that code is not masked.
            "import json as _json",
            "",
            "_TEST_CASES = [",
            *case_rows,
            "]",
            "",
            "",
            "def _run_all_tests():",
            "    results = []",
            "    errors = []",
            "    for test_id, (call_input, shown_input, expected) in enumerate(",
            "        _TEST_CASES, start=1",
            "    ):",
            "        try:",
            f"            result = {function_name}(call_input)",
            "            passed = bool(result == expected)",
            "        except Exception as exc:",
            "            results.append({",
            "                'test_id': test_id,",
            "                'input': shown_input,",
            "                'expected': expected,",
            "                'actual': str(exc),",
            "                'passed': False,",
            "                'error': str(exc),",
            "            })",
            "            errors.append(f'Test {test_id} error: {exc}')",
            "            continue",
            "        results.append({",
            "            'test_id': test_id,",
            "            'input': shown_input,",
            "            'expected': expected,",
            "            'actual': result,",
            "            'passed': passed,",
            "        })",
            "        if not passed:",
            "            errors.append(",
            "                f'Test {test_id} failed: expected {expected!r}, got {result}'",
            "            )",
            "    return results, errors",
            "",
            "",
            "if __name__ == '__main__':",
            "    _results, _errors = _run_all_tests()",
            f"    print({self._RESULTS_START!r})",
            # default=repr keeps the report alive when the tested function
            # returns something JSON cannot encode (a set, an object, ...).
            "    print(_json.dumps(_results, default=repr))",
            f"    print({self._RESULTS_END!r})",
            f"    print({self._ERRORS_START!r})",
            "    print(_json.dumps(_errors))",
            f"    print({self._ERRORS_END!r})",
        ]
        return code + "\n" + "\n".join(harness_lines) + "\n"

    @staticmethod
    def _section(output: str, start: str, end: str) -> str:
        """Return the stripped text found between ``start`` and ``end`` in ``output``."""
        _, found, rest = output.partition(start)
        if not found:
            raise ValueError(f"marker {start!r} missing from test output")
        return rest.partition(end)[0].strip()

    def _parse_test_output(self, stdout: str, stderr: str) -> Dict[str, Any]:
        """Turn the harness output into the result dictionary of ``run_tests``."""
        if self._RESULTS_START not in stdout:
            # The script died before reporting (syntax error, early exit...).
            return {
                "status": "error",
                "error": f"Execution failed: {stderr}",
                "results": [],
                "errors": [stderr],
            }

        results = json.loads(
            self._section(stdout, self._RESULTS_START, self._RESULTS_END)
        )
        errors = json.loads(
            self._section(stdout, self._ERRORS_START, self._ERRORS_END)
        )

        total_tests = len(results)
        passed_tests = sum(1 for entry in results if entry.get("passed", False))
        pass_rate = (passed_tests / total_tests * 100) if total_tests > 0 else 0
        return {
            "status": "success",
            "results": results,
            "errors": errors,
            "metrics": {
                "total_tests": total_tests,
                "passed_tests": passed_tests,
                "pass_rate": pass_rate,
                "has_errors": len(errors) > 0,
            },
        }

    def run_tests(
        self, code: str, test_cases: List[Dict], timeout: float = 10
    ) -> Dict[str, Any]:
        """
        Run tests on the generated code.

        The code and a generated harness are written to a temporary script
        that is executed with the current interpreter. The script is removed
        afterwards, including when execution times out or fails.

        Args:
            code: Generated code
            test_cases: List of test cases, each with "input" and "expected"
            timeout: Seconds to wait for the test script before giving up

        Returns:
            Test results. On success: "status", "results", "errors" and
            "metrics". On failure: "status" == "error" plus "error",
            "results" (empty) and "errors". This method does not raise for
            failures of the tested code.
        """
        script_path = None
        try:
            function_name = self.extract_function_name(code)
            script = self._build_test_script(code, function_name, test_cases)

            # Python reads source files as UTF-8 by default, so the script is
            # written as UTF-8 regardless of the locale encoding.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".py", delete=False, encoding="utf-8"
            ) as handle:
                script_path = handle.name
                handle.write(script)

            completed = subprocess.run(
                [sys.executable, script_path],
                capture_output=True,
                text=True,
                timeout=timeout,
                # Code that calls input() gets EOF instead of blocking until
                # the timeout expires.
                stdin=subprocess.DEVNULL,
            )
            return self._parse_test_output(completed.stdout, completed.stderr)
        except subprocess.TimeoutExpired:
            return {
                "status": "error",
                "error": f"Test execution timed out after {timeout} seconds",
                "results": [],
                "errors": ["Timeout error"],
            }
        except Exception as e:
            # Boundary handler: callers receive a result dictionary rather
            # than an exception, whatever went wrong while testing.
            return {
                "status": "error",
                "error": str(e),
                "results": [],
                "errors": [str(e)],
            }
        finally:
            if script_path is not None:
                try:
                    os.unlink(script_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file must not
                    # replace the test outcome with an error.
                    pass

    def test_code(self, code: str, prompt: str) -> Dict[str, Any]:
        """
        Complete testing workflow.

        Args:
            code: Generated code
            prompt: Original user prompt

        Returns:
            Dictionary with the chosen "test_cases", the "test_results"
            produced by ``run_tests`` and the tested "function_name".
        """
        test_cases = self.generate_test_cases(code, prompt)
        test_results = self.run_tests(code, test_cases)
        return {
            "test_cases": test_cases,
            "test_results": test_results,
            "function_name": self.extract_function_name(code),
        }