Navya-Sree committed on
Commit
166441f
·
verified ·
1 Parent(s): bfc9de1

Create agents/tester_agent.py

Browse files
Files changed (1) hide show
  1. agents/tester_agent.py +251 -0
agents/tester_agent.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import tempfile
3
+ import os
4
+ import json
5
+ from typing import Dict, Any, List
6
+ import ast
7
+ import sys
8
+
9
class TesterAgent:
    """
    Agent responsible for testing generated code.

    Selects a list of input/expected pairs for a piece of generated Python
    code, runs them against the code's main function in a separate
    interpreter process, and reports per-test results and summary metrics.
    """

    # Markers the generated harness prints around its two JSON payloads so
    # they can be located even if the code under test prints to stdout too.
    _RESULTS_START = "TEST_RESULTS_START"
    _RESULTS_END = "TEST_RESULTS_END"
    _ERRORS_START = "ERRORS_START"
    _ERRORS_END = "ERRORS_END"

    # Name reported when no function definition can be found in the code.
    _FALLBACK_FUNCTION_NAME = "main_function"

    # Prompt keywords that select a predefined test set whose key is not
    # itself spelled out in the prompt.
    _KEYWORD_ALIASES = {"reverse": "reverse_string"}

    def __init__(self):
        self.test_cases = self._load_default_test_cases()

    def _load_default_test_cases(self) -> Dict:
        """Load default test cases for common functions."""
        return {
            "reverse_string": [
                {"input": "hello", "expected": "olleh"},
                {"input": "world", "expected": "dlrow"},
                {"input": "", "expected": ""},
                {"input": "a", "expected": "a"},
            ],
            "factorial": [
                {"input": 5, "expected": 120},
                {"input": 0, "expected": 1},
                {"input": 1, "expected": 1},
            ],
            "fibonacci": [
                {"input": 5, "expected": 5},
                {"input": 1, "expected": 1},
                {"input": 0, "expected": 0},
            ],
        }

    def extract_function_name(self, code: str) -> str:
        """
        Extract the main function name from code.

        The first top-level ``def`` wins; if there is none, the first
        function definition found anywhere in the tree is used.

        Args:
            code: Python code

        Returns:
            Name of the main function, or "main_function" when the code
            cannot be parsed or defines no function.
        """
        try:
            tree = ast.parse(code)
        except Exception:
            # Best effort: unparsable input (SyntaxError, ValueError on null
            # bytes, non-string input, ...) falls back to the default name.
            # Unlike a bare ``except`` this lets KeyboardInterrupt and
            # SystemExit propagate.
            return self._FALLBACK_FUNCTION_NAME

        # Prefer module-level functions explicitly rather than relying on
        # the traversal order of ast.walk, which is documented as unspecified.
        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                return node.name

        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                return node.name

        return self._FALLBACK_FUNCTION_NAME

    def generate_test_cases(self, code: str, prompt: str) -> List[Dict]:
        """
        Generate test cases based on code and prompt.

        Selection is driven by the prompt only: a predefined set is used
        when its key (e.g. "factorial") or a known keyword (e.g. "reverse")
        occurs in the lower-cased prompt.

        Args:
            code: Generated code (currently not used for selection)
            prompt: Original user prompt

        Returns:
            List of test cases, each a dict with "input" and "expected".
            The dicts are copies, so callers may modify them without
            affecting the agent's predefined sets.
        """
        prompt_lower = prompt.lower()

        # Exact key mentioned in the prompt.
        for key, cases in self.test_cases.items():
            if key in prompt_lower:
                return [dict(case) for case in cases]

        # Looser keyword match for sets whose key is unlikely to be typed
        # verbatim by a user.
        for keyword, key in self._KEYWORD_ALIASES.items():
            if keyword in prompt_lower and key in self.test_cases:
                return [dict(case) for case in self.test_cases[key]]

        # Default test cases.
        # NOTE(review): these look like placeholders; arbitrary code is
        # unlikely to satisfy them -- confirm this is the intended fallback.
        return [
            {"input": "test_input", "expected": "expected_output"},
            {"input": 123, "expected": 321},
        ]

    @staticmethod
    def _build_harness(function_name: str, test_cases: List[Dict]) -> str:
        """Return the Python source of the test driver appended to the code."""
        parts = [
            "\n\n",
            # Imported after the code under test and under a private alias
            # so the harness neither precedes a `from __future__` import in
            # that code nor rebinds a `json` name the code defines itself.
            "import json as _tester_json\n",
            "\n",
            "def _tester_run_all_tests():\n",
            "    _results = []\n",
            "    _errors = []\n",
        ]

        for test_id, test_case in enumerate(test_cases, start=1):
            input_repr = repr(test_case["input"])
            expected_repr = repr(test_case["expected"])

            # The expected value is bound to a local and formatted at run
            # time. Pasting its repr directly into the generated f-string
            # breaks the file for any string value (nested quotes).
            parts.append(
                f"    # Test case {test_id}\n"
                f"    _expected = {expected_repr}\n"
                "    try:\n"
                f"        _result = {function_name}({input_repr})\n"
                "        _passed = bool(_result == _expected)\n"
                "        _results.append({\n"
                f"            'test_id': {test_id},\n"
                f"            'input': {input_repr},\n"
                "            'expected': _expected,\n"
                "            'actual': _result,\n"
                "            'passed': _passed,\n"
                "        })\n"
                "        if not _passed:\n"
                f"            _errors.append(f'Test {test_id} failed: "
                "expected {_expected!r}, got {_result}')\n"
                "    except Exception as e:\n"
                "        _results.append({\n"
                f"            'test_id': {test_id},\n"
                f"            'input': {input_repr},\n"
                "            'expected': _expected,\n"
                "            'actual': str(e),\n"
                "            'passed': False,\n"
                "            'error': str(e),\n"
                "        })\n"
                f"        _errors.append(f'Test {test_id} error: {{e}}')\n"
            )

        parts.append(
            "    return _results, _errors\n"
            "\n"
            "if __name__ == '__main__':\n"
            "    _results, _errors = _tester_run_all_tests()\n"
            f"    print({TesterAgent._RESULTS_START!r})\n"
            # default=repr: a return value JSON cannot encode (set, bytes,
            # custom object) is reported by its repr instead of crashing
            # the harness before the end marker is printed.
            "    print(_tester_json.dumps(_results, default=repr))\n"
            f"    print({TesterAgent._RESULTS_END!r})\n"
            f"    print({TesterAgent._ERRORS_START!r})\n"
            "    print(_tester_json.dumps(_errors, default=repr))\n"
            f"    print({TesterAgent._ERRORS_END!r})\n"
        )
        return "".join(parts)

    @staticmethod
    def _extract_section(output: str, start: str, end: str):
        """Return the text between two markers, or None if either is missing."""
        _, found_start, rest = output.partition(start)
        if not found_start:
            return None
        section, found_end, _ = rest.partition(end)
        if not found_end:
            return None
        return section.strip()

    def _parse_output(self, stdout: str, stderr: str) -> Dict[str, Any]:
        """Turn the harness process output into the result dictionary."""
        results_json = self._extract_section(
            stdout, self._RESULTS_START, self._RESULTS_END
        )
        errors_json = self._extract_section(
            stdout, self._ERRORS_START, self._ERRORS_END
        )

        if results_json is None or errors_json is None:
            # The script died before (or while) printing its report, e.g. a
            # syntax error in the generated code; stderr carries the reason.
            return {
                "status": "error",
                "error": f"Execution failed: {stderr}",
                "results": [],
                "errors": [stderr],
            }

        results = json.loads(results_json)
        errors = json.loads(errors_json)

        # Calculate metrics
        total_tests = len(results)
        passed_tests = sum(1 for r in results if r.get('passed', False))
        pass_rate = (passed_tests / total_tests * 100) if total_tests > 0 else 0

        return {
            "status": "success",
            "results": results,
            "errors": errors,
            "metrics": {
                "total_tests": total_tests,
                "passed_tests": passed_tests,
                "pass_rate": pass_rate,
                "has_errors": len(errors) > 0,
            },
        }

    def run_tests(
        self, code: str, test_cases: List[Dict], timeout: float = 10
    ) -> Dict[str, Any]:
        """
        Run tests on the generated code.

        The code and a generated test driver are written to a temporary
        script that is executed with the current interpreter. The script is
        always removed afterwards, including on timeout or failure.

        SECURITY NOTE: `code` is executed as-is in a child process with the
        privileges of the caller; there is no sandboxing here.

        Args:
            code: Generated code
            test_cases: List of test cases (dicts with "input"/"expected")
            timeout: Maximum run time of the test script in seconds

        Returns:
            Test results. On success: "status" == "success" with "results",
            "errors" and "metrics". Otherwise "status" == "error" with an
            "error" message and empty "results". No exception is raised for
            ordinary failures.
        """
        temp_file_path = None
        try:
            function_name = self.extract_function_name(code)
            harness = self._build_harness(function_name, test_cases)

            # Python reads source files as UTF-8 by default, so write the
            # script as UTF-8 regardless of the platform's locale encoding.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.py', delete=False, encoding='utf-8'
            ) as f:
                temp_file_path = f.name
                f.write(code)
                f.write(harness)

            # Execute the test file
            result = subprocess.run(
                [sys.executable, temp_file_path],
                capture_output=True,
                text=True,
                timeout=timeout,
            )

            return self._parse_output(result.stdout, result.stderr)

        except subprocess.TimeoutExpired:
            return {
                "status": "error",
                "error": f"Test execution timed out after {timeout} seconds",
                "results": [],
                "errors": ["Timeout error"],
            }
        except Exception as e:
            return {
                "status": "error",
                "error": str(e),
                "results": [],
                "errors": [str(e)],
            }
        finally:
            # Clean up on every path, not only after a successful run.
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    pass

    def test_code(self, code: str, prompt: str) -> Dict[str, Any]:
        """
        Complete testing workflow.

        Args:
            code: Generated code
            prompt: Original user prompt

        Returns:
            Complete test results: the selected "test_cases", the
            "test_results" from run_tests, and the detected "function_name".
        """
        # Generate test cases
        test_cases = self.generate_test_cases(code, prompt)

        # Run tests
        test_results = self.run_tests(code, test_cases)

        return {
            "test_cases": test_cases,
            "test_results": test_results,
            "function_name": self.extract_function_name(code),
        }