Spaces:
Runtime error
Runtime error
Yago Bolivar committed on
Commit ·
0d2816b
1
Parent(s): b7e30dd
feat: implement CodeExecutionTool for safe code execution and output extraction
Browse files
test: add unit tests for CodeExecutionTool's safety analysis and functionality
- src/python_tool.py +216 -0
- tests/__init__.py +0 -0
- tests/test_python_tool.py +44 -0
src/python_tool.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ast
|
| 2 |
+
import contextlib
|
| 3 |
+
import io
|
| 4 |
+
import signal
|
| 5 |
+
import re
|
| 6 |
+
import traceback
|
| 7 |
+
from typing import Dict, Any, Optional, Union, List
|
| 8 |
+
|
| 9 |
+
class CodeExecutionTool:
    """Tool to execute Python code in a restricted namespace and extract numeric outputs.

    Code is first screened statically (banned imports, direct ``exec``/``eval``
    calls), then run with ``exec`` against a namespace that exposes only a
    whitelist of built-ins plus the public names of the ``math`` module.
    Everything written to stdout/stderr is captured and the last number found
    in that output is reported separately.

    NOTE(review): the timeout relies on ``signal.SIGALRM``; per the Python
    ``signal`` documentation this is Unix-only and handlers can only be
    installed from the main thread -- confirm the deployment satisfies both.
    """

    # Optionally signed integer or decimal embedded in arbitrary text.
    _NUMBER_PATTERN = re.compile(r'[-+]?\d*\.?\d+')

    # Built-ins exposed to executed code; each is registered under its own name.
    _SAFE_BUILTINS = (
        abs, all, any, bin, bool, chr, complex, dict, divmod, enumerate,
        filter, float, format, frozenset, hash, hex, int, isinstance,
        issubclass, len, list, map, max, min, oct, ord, pow, print, range,
        reversed, round, set, sorted, str, sum, tuple, zip,
    )

    def __init__(self, timeout: int = 5, max_output_size: int = 10000):
        """Configure the executor.

        Args:
            timeout: Maximum execution time in seconds (passed to ``signal.alarm``).
            max_output_size: Maximum number of captured characters returned
                in ``raw_output`` before it is truncated.
        """
        self.timeout = timeout  # Maximum execution time in seconds
        self.max_output_size = max_output_size
        # Restricted imports - add more as needed
        self.banned_modules = [
            'os', 'subprocess', 'sys', 'builtins', 'importlib', 'eval',
            'pickle', 'requests', 'socket', 'shutil'
        ]

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """Perform static analysis to check for potentially harmful code.

        Args:
            code: Python source to inspect.

        Returns:
            ``{"safe": True}`` when nothing suspicious is found, otherwise
            ``{"safe": False, "reason": <str>}``. Banned imports are reported
            in preference to ``exec``/``eval`` calls when both are present.
        """
        try:
            parsed = ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError covers sources containing null bytes on Python
            # versions that do not report them as SyntaxError.
            return {"safe": False, "reason": "Invalid Python syntax"}

        imports: List[str] = []
        calls_exec_or_eval = False
        for node in ast.walk(parsed):
            if isinstance(node, ast.Import):
                imports.extend(alias.name for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                # ``module`` is None for purely relative imports such as
                # ``from . import x``; there is no module name to screen.
                if node.module is not None:
                    imports.append(node.module)
            elif isinstance(node, ast.Call):
                func = node.func
                if isinstance(func, ast.Name) and func.id in ('exec', 'eval'):
                    calls_exec_or_eval = True

        # Substring match: any import whose dotted name *contains* a banned
        # name is rejected (so ``os.path`` is caught, as is e.g. ``posix``).
        dangerous_imports = [imp for imp in imports if any(
            banned in imp for banned in self.banned_modules)]

        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}"
            }

        if calls_exec_or_eval:
            return {
                "safe": False,
                "reason": "Contains exec() or eval() calls"
            }

        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """Handler for timeout signal; aborts execution with ``TimeoutError``."""
        raise TimeoutError("Code execution timed out")

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """Extract the final numeric value from output.

        Non-empty lines are scanned from last to first. A line that is itself
        a number wins outright; otherwise the last number embedded in the
        line is used.

        Args:
            output: Captured program output.

        Returns:
            The value as a ``float``, or ``None`` when no line contains a number.
        """
        lines = [line.strip() for line in output.strip().split('\n') if line.strip()]

        for line in reversed(lines):
            # Try direct conversion first
            try:
                return float(line)
            except ValueError:
                pass

            # Otherwise take the last number embedded in the text, so that
            # "Step 1 gives 42" yields 42 rather than 1.
            for candidate in reversed(self._NUMBER_PATTERN.findall(line)):
                try:
                    return float(candidate)
                except ValueError:
                    continue

        return None

    def _build_safe_globals(self) -> Dict[str, Any]:
        """Build the restricted global namespace handed to ``exec``."""
        namespace: Dict[str, Any] = {fn.__name__: fn for fn in self._SAFE_BUILTINS}
        namespace['__builtins__'] = {}  # Empty builtins for extra security

        # Add math module functions, commonly needed. ``setdefault`` keeps the
        # whitelisted built-ins intact: ``math.pow`` must not replace ``pow``.
        import math
        for name in dir(math):
            if not name.startswith('_'):
                namespace.setdefault(name, getattr(math, name))

        return namespace

    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Execute Python code from file and capture the output.

        Args:
            filepath: Path of the Python source file to run.

        Returns:
            The result dictionary of :meth:`execute_code`, or
            ``{"success": False, "error": <str>}`` when the file cannot be
            read or any other exception escapes.
        """
        try:
            # Python source files are UTF-8 by default.
            with open(filepath, 'r', encoding='utf-8') as file:
                code = file.read()

            return self.execute_code(code)

        except FileNotFoundError:
            return {"success": False, "error": f"File not found: {filepath}"}
        except Exception as e:
            return {
                "success": False,
                "error": f"Error reading file: {str(e)}"
            }

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Execute Python code string and capture the output.

        Args:
            code: Python source to run.

        Returns:
            On success: ``success`` (True), ``raw_output`` (captured
            stdout/stderr, truncated to ``max_output_size``), ``numeric_value``
            and ``has_numeric_result``. On failure: ``success`` (False) and
            ``error``; runtime errors additionally carry ``traceback`` and
            the ``raw_output`` captured so far.
        """
        # Check code safety first
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {
                "success": False,
                "error": f"Security check failed: {safety_check['reason']}"
            }

        safe_globals = self._build_safe_globals()

        # Capture output using StringIO
        output_buffer = io.StringIO()

        # Arm the timeout; the previous handler is restored in ``finally``.
        old_handler = signal.getsignal(signal.SIGALRM)
        signal.signal(signal.SIGALRM, self._timeout_handler)
        signal.alarm(self.timeout)

        try:
            # Execute code with stdout/stderr capture
            with contextlib.redirect_stdout(output_buffer), \
                    contextlib.redirect_stderr(output_buffer):
                exec(code, safe_globals)

            output = output_buffer.getvalue()

            # Extract the numeric value from the complete output: truncating
            # first would cut off the final lines where the result lives.
            numeric_result = self._extract_numeric_value(output)

            if len(output) > self.max_output_size:
                output = output[:self.max_output_size] + "... [output truncated]"

            return {
                "success": True,
                "raw_output": output,
                "numeric_value": numeric_result,
                "has_numeric_result": numeric_result is not None
            }

        except TimeoutError:
            return {
                "success": False,
                "error": f"Code execution timed out after {self.timeout} seconds"
            }
        except Exception as e:
            error_info = traceback.format_exc()
            return {
                "success": False,
                "error": str(e),
                "traceback": error_info,
                "raw_output": output_buffer.getvalue()
            }
        finally:
            # Reset alarm and signal handler
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_handler)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# Example usage
|
| 207 |
+
if __name__ == "__main__":
    # Demo: run a small snippet through the executor and show the result dict.
    executor = CodeExecutionTool()
    result = executor.execute_code("""
# Example code that calculates a value
total = 0
for i in range(10):
    total += i * 2
print(f"The result is {total}")
""")
    print(result)
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_python_tool.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
# Add the parent directory to sys.path to find the src module
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 8 |
+
|
| 9 |
+
from src.python_tool import CodeExecutionTool
|
| 10 |
+
|
| 11 |
+
class TestCodeExecutionTool(unittest.TestCase):
    """Unit tests for CodeExecutionTool's static safety analysis."""

    def setUp(self):
        # A fresh tool per test, using its default configuration.
        self.code_tool = CodeExecutionTool()

    def test_analyze_code_safety_imports(self):
        """Test that the tool detects banned imports."""
        code_with_banned_import = "import os"
        result = self.code_tool._analyze_code_safety(code_with_banned_import)
        self.assertFalse(result["safe"])
        self.assertIn("os", result["reason"])

    def test_analyze_code_safety_exec_eval(self):
        """Test that the tool detects exec and eval usage."""
        # Cover both calls named in the docstring, not just exec.
        cases = {
            "exec()": "exec('print(1)')",
            "eval()": "eval('1 + 1')",
        }
        for label, snippet in cases.items():
            with self.subTest(call=label):
                result = self.code_tool._analyze_code_safety(snippet)
                self.assertFalse(result["safe"])
                self.assertIn(label, result["reason"])

    def test_analyze_code_safety_valid_code(self):
        """Test that the tool allows safe code."""
        safe_code = "print(1 + 1)"
        result = self.code_tool._analyze_code_safety(safe_code)
        self.assertTrue(result["safe"])

    def test_common_question_reverse_word(self):
        """Test the reverse word question from common_questions.json."""
        question = ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI"
        reversed_question = question[::-1]
        # Only the reversal is checked; comparing a literal answer with
        # itself would always pass and verify nothing.
        self.assertEqual(reversed_question, "If you understand this sentence, write the opposite of the word \"left\" as the answer.")


if __name__ == "__main__":
    unittest.main()
|