Yago Bolivar commited on
Commit
0d2816b
·
1 Parent(s): b7e30dd

feat: implement CodeExecutionTool for safe code execution and output extraction

Browse files

test: add unit tests for CodeExecutionTool's safety analysis and functionality

src/python_tool.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import contextlib
3
+ import io
4
+ import signal
5
+ import re
6
+ import traceback
7
+ from typing import Dict, Any, Optional, Union, List
8
+
9
+ class CodeExecutionTool:
10
+ """Tool to safely execute Python code files and extract numeric outputs."""
11
+
12
+ def __init__(self, timeout: int = 5, max_output_size: int = 10000):
13
+ self.timeout = timeout # Maximum execution time in seconds
14
+ self.max_output_size = max_output_size
15
+ # Restricted imports - add more as needed
16
+ self.banned_modules = [
17
+ 'os', 'subprocess', 'sys', 'builtins', 'importlib', 'eval',
18
+ 'pickle', 'requests', 'socket', 'shutil'
19
+ ]
20
+
21
+ def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
22
+ """Perform static analysis to check for potentially harmful code."""
23
+ try:
24
+ parsed = ast.parse(code)
25
+
26
+ # Check for banned imports
27
+ imports = []
28
+ for node in ast.walk(parsed):
29
+ if isinstance(node, ast.Import):
30
+ imports.extend(n.name for n in node.names)
31
+ elif isinstance(node, ast.ImportFrom):
32
+ imports.append(node.module)
33
+
34
+ dangerous_imports = [imp for imp in imports if any(
35
+ banned in imp for banned in self.banned_modules)]
36
+
37
+ if dangerous_imports:
38
+ return {
39
+ "safe": False,
40
+ "reason": f"Potentially harmful imports detected: {dangerous_imports}"
41
+ }
42
+
43
+ # Check for exec/eval usage
44
+ for node in ast.walk(parsed):
45
+ if isinstance(node, ast.Call) and hasattr(node, 'func'):
46
+ if isinstance(node.func, ast.Name) and node.func.id in ['exec', 'eval']:
47
+ return {
48
+ "safe": False,
49
+ "reason": "Contains exec() or eval() calls"
50
+ }
51
+
52
+ return {"safe": True}
53
+ except SyntaxError:
54
+ return {"safe": False, "reason": "Invalid Python syntax"}
55
+
56
+ def _timeout_handler(self, signum, frame):
57
+ """Handler for timeout signal."""
58
+ raise TimeoutError("Code execution timed out")
59
+
60
+ def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
61
+ """Extract the final numeric value from output."""
62
+ # First try to get the last line that's a number
63
+ lines = [line.strip() for line in output.strip().split('\n') if line.strip()]
64
+
65
+ for line in reversed(lines):
66
+ # Try direct conversion first
67
+ try:
68
+ return float(line)
69
+ except ValueError:
70
+ pass
71
+
72
+ # Try to extract numeric portion if embedded in text
73
+ numeric_match = re.search(r'[-+]?\d*\.?\d+', line)
74
+ if numeric_match:
75
+ try:
76
+ return float(numeric_match.group())
77
+ except ValueError:
78
+ pass
79
+
80
+ return None
81
+
82
+ def execute_file(self, filepath: str) -> Dict[str, Any]:
83
+ """Execute Python code from file and capture the output."""
84
+ try:
85
+ with open(filepath, 'r') as file:
86
+ code = file.read()
87
+
88
+ return self.execute_code(code)
89
+
90
+ except FileNotFoundError:
91
+ return {"success": False, "error": f"File not found: {filepath}"}
92
+ except Exception as e:
93
+ return {
94
+ "success": False,
95
+ "error": f"Error reading file: {str(e)}"
96
+ }
97
+
98
+ def execute_code(self, code: str) -> Dict[str, Any]:
99
+ """Execute Python code string and capture the output."""
100
+ # Check code safety first
101
+ safety_check = self._analyze_code_safety(code)
102
+ if not safety_check["safe"]:
103
+ return {
104
+ "success": False,
105
+ "error": f"Security check failed: {safety_check['reason']}"
106
+ }
107
+
108
+ # Prepare a clean globals dictionary with minimal safe functions
109
+ safe_globals = {
110
+ 'abs': abs,
111
+ 'all': all,
112
+ 'any': any,
113
+ 'bin': bin,
114
+ 'bool': bool,
115
+ 'chr': chr,
116
+ 'complex': complex,
117
+ 'dict': dict,
118
+ 'divmod': divmod,
119
+ 'enumerate': enumerate,
120
+ 'filter': filter,
121
+ 'float': float,
122
+ 'format': format,
123
+ 'frozenset': frozenset,
124
+ 'hash': hash,
125
+ 'hex': hex,
126
+ 'int': int,
127
+ 'isinstance': isinstance,
128
+ 'issubclass': issubclass,
129
+ 'len': len,
130
+ 'list': list,
131
+ 'map': map,
132
+ 'max': max,
133
+ 'min': min,
134
+ 'oct': oct,
135
+ 'ord': ord,
136
+ 'pow': pow,
137
+ 'print': print,
138
+ 'range': range,
139
+ 'reversed': reversed,
140
+ 'round': round,
141
+ 'set': set,
142
+ 'sorted': sorted,
143
+ 'str': str,
144
+ 'sum': sum,
145
+ 'tuple': tuple,
146
+ 'zip': zip,
147
+ '__builtins__': {}, # Empty builtins for extra security
148
+ }
149
+
150
+ # Add math module functions, commonly needed
151
+ try:
152
+ import math
153
+ for name in dir(math):
154
+ if not name.startswith('_'):
155
+ safe_globals[name] = getattr(math, name)
156
+ except ImportError:
157
+ pass
158
+
159
+ # Capture output using StringIO
160
+ output_buffer = io.StringIO()
161
+
162
+ # Set timeout handler
163
+ old_handler = signal.getsignal(signal.SIGALRM)
164
+ signal.signal(signal.SIGALRM, self._timeout_handler)
165
+ signal.alarm(self.timeout)
166
+
167
+ try:
168
+ # Execute code with stdout/stderr capture
169
+ with contextlib.redirect_stdout(output_buffer):
170
+ with contextlib.redirect_stderr(output_buffer):
171
+ exec(code, safe_globals)
172
+
173
+ output = output_buffer.getvalue()
174
+ if len(output) > self.max_output_size:
175
+ output = output[:self.max_output_size] + "... [output truncated]"
176
+
177
+ # Extract the numeric value
178
+ numeric_result = self._extract_numeric_value(output)
179
+
180
+ return {
181
+ "success": True,
182
+ "raw_output": output,
183
+ "numeric_value": numeric_result,
184
+ "has_numeric_result": numeric_result is not None
185
+ }
186
+
187
+ except TimeoutError:
188
+ return {
189
+ "success": False,
190
+ "error": f"Code execution timed out after {self.timeout} seconds"
191
+ }
192
+ except Exception as e:
193
+ error_info = traceback.format_exc()
194
+ return {
195
+ "success": False,
196
+ "error": str(e),
197
+ "traceback": error_info,
198
+ "raw_output": output_buffer.getvalue()
199
+ }
200
+ finally:
201
+ # Reset alarm and signal handler
202
+ signal.alarm(0)
203
+ signal.signal(signal.SIGALRM, old_handler)
204
+
205
+
206
+ # Example usage
207
+ if __name__ == "__main__":
208
+ executor = CodeExecutionTool()
209
+ result = executor.execute_code("""
210
+ # Example code that calculates a value
211
+ total = 0
212
+ for i in range(10):
213
+ total += i * 2
214
+ print(f"The result is {total}")
215
+ """)
216
+ print(result)
tests/__init__.py ADDED
File without changes
tests/test_python_tool.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import sys
3
+ import os
4
+ from pathlib import Path
5
+
6
+ # Add the parent directory to sys.path to find the src module
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8
+
9
+ from src.python_tool import CodeExecutionTool
10
+
11
+ class TestCodeExecutionTool(unittest.TestCase):
12
+ def setUp(self):
13
+ self.code_tool = CodeExecutionTool()
14
+
15
+ def test_analyze_code_safety_imports(self):
16
+ """Test that the tool detects banned imports."""
17
+ code_with_banned_import = "import os"
18
+ result = self.code_tool._analyze_code_safety(code_with_banned_import)
19
+ self.assertFalse(result["safe"])
20
+ self.assertIn("os", result["reason"])
21
+
22
+ def test_analyze_code_safety_exec_eval(self):
23
+ """Test that the tool detects exec and eval usage."""
24
+ code_with_exec = "exec('print(1)')"
25
+ result = self.code_tool._analyze_code_safety(code_with_exec)
26
+ self.assertFalse(result["safe"])
27
+ self.assertIn("exec()", result["reason"])
28
+
29
+ def test_analyze_code_safety_valid_code(self):
30
+ """Test that the tool allows safe code."""
31
+ safe_code = "print(1 + 1)"
32
+ result = self.code_tool._analyze_code_safety(safe_code)
33
+ self.assertTrue(result["safe"])
34
+
35
+ def test_common_question_reverse_word(self):
36
+ """Test the reverse word question from common_questions.json."""
37
+ question = ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI"
38
+ expected_answer = "Right"
39
+ reversed_question = question[::-1]
40
+ self.assertEqual(reversed_question, "If you understand this sentence, write the opposite of the word \"left\" as the answer.")
41
+ self.assertEqual(expected_answer, "Right")
42
+
43
+ if __name__ == "__main__":
44
+ unittest.main()