Spaces:
Runtime error
Runtime error
import ast
import contextlib
import io
import logging
import os
import re
import signal
import tempfile
import threading
import traceback
from typing import Any, Dict, List, Optional, Union

from smolagents.tools import Tool
# Module-wide logging: create this module's logger and make sure INFO-level
# (and above) records are emitted by the root configuration.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class CodeExecutionTool(Tool):
    """
    Executes Python code snippets with static safety checks and timeout protection.

    Useful for data processing, analysis, and transformation. A small set of
    text-processing helpers (``extract_pattern``, ``clean_text``,
    ``parse_table_text``, ``safe_float``) is prepended to every snippet before
    it runs.

    NOTE: the safety analysis is a best-effort static check (banned imports and
    a few banned builtin calls). It is not a sandbox and must not be relied on
    to contain hostile code.
    """

    name = "python_executor"
    description = "Safely executes Python code with enhancements for data processing, parsing, and error recovery."
    inputs = {
        "code_string": {"type": "string", "description": "The Python code to execute.", "nullable": True},
        "filepath": {"type": "string", "description": "Path to a Python file to execute.", "nullable": True},
    }
    outputs = {
        "success": {"type": "boolean", "description": "Whether the code executed successfully."},
        "output": {"type": "string", "description": "The captured stdout or formatted result.", "nullable": True},
        "error": {"type": "string", "description": "Error message if execution failed.", "nullable": True},
        "result_value": {"type": "any", "description": "The final expression value if applicable.", "nullable": True},
    }
    output_type = "object"

    def __init__(self, timeout: int = 10, max_output_size: int = 20000, *args, **kwargs):
        """
        Create the tool.

        Args:
            timeout: Maximum execution time in whole seconds; 0 disables the limit.
            max_output_size: Maximum number of captured stdout characters returned.
            *args, **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.timeout = timeout
        self.max_output_size = max_output_size
        # Matched as substrings of the imported module path (see _analyze_code_safety).
        self.banned_modules = [
            "os",
            "subprocess",
            "sys",
            "builtins",
            "importlib",
            "pickle",
            "requests",
            "socket",
            "shutil",
            "ctypes",
            "multiprocessing",
        ]
        self.is_initialized = True
        self._utility_functions = self._get_utility_functions()

    def _get_utility_functions(self) -> str:
        """
        Return the source of the helper functions prepended to executed code.

        The literal is a *raw* string so that the regular expressions inside it
        reach the executed code exactly as written. With a non-raw literal the
        sequence ``\\'\\"`` collapsed to ``'"`` and terminated the inner pattern
        early, which made the whole helper block a syntax error.
        """
        return r'''
def extract_pattern(text, pattern, group=0, all_matches=False):
    """
    Extract data using regex pattern from text.
    Args:
        text (str): Text to search in
        pattern (str): Regex pattern to use
        group (int): Capture group to return (default 0 - entire match)
        all_matches (bool): If True, return all matches, otherwise just first
    Returns:
        Matched string(s) or None if no match
    """
    import re
    if not text or not pattern:
        print("Warning: Empty text or pattern provided to extract_pattern")
        return None
    try:
        matches = re.finditer(pattern, text)
        results = [m.group(group) if group < len(m.groups())+1 else m.group(0) for m in matches]
        if not results:
            print(f"No matches found for pattern '{pattern}'")
            return None
        return results if all_matches else results[0]
    except Exception as e:
        print(f"Error in extract_pattern: {e}")
        return None

def clean_text(text, remove_extra_whitespace=True, remove_special_chars=False):
    """
    Clean text by removing extra whitespace and optionally special characters.
    Args:
        text (str): Text to clean
        remove_extra_whitespace (bool): If True, replace multiple spaces with single space
        remove_special_chars (bool): If True, remove special characters
    Returns:
        Cleaned string
    """
    import re
    if not text:
        return ""
    # Replace newlines and tabs with spaces
    text = re.sub(r"[\n\t\r]+", " ", text)
    if remove_special_chars:
        # Keep only alphanumeric, spaces, and basic punctuation
        text = re.sub(r"[^\w\s.,;:!?'\"()-]", "", text)
    if remove_extra_whitespace:
        # Replace multiple spaces with single space
        text = re.sub(r"\s+", " ", text)
    return text.strip()

def parse_table_text(table_text):
    """
    Parse table-like text into list of rows.
    Args:
        table_text (str): Text containing table-like data
    Returns:
        List of rows (each row is a list of cells)
    """
    import re
    rows = []
    lines = table_text.strip().split("\n")
    for line in lines:
        # Skip empty lines
        if not line.strip():
            continue
        # Split by whitespace or common separators
        cells = re.split(r"\s{2,}|\t+|\|+", line.strip())
        # Clean up cells
        cells = [cell.strip() for cell in cells if cell.strip()]
        if cells:
            rows.append(cells)
    # Print parsing result for debugging
    print(f"Parsed {len(rows)} rows from table text")
    if rows:
        print(f"First row (columns: {len(rows[0])}): {rows[0]}")
    return rows

def safe_float(text):
    """
    Safely convert text to float, handling various formats.
    Args:
        text (str): Text to convert
    Returns:
        float or None if conversion fails
    """
    import re
    # Only treat missing/empty input as "no value"; a numeric 0 is a valid input.
    if text is None or text == "":
        return None
    # Remove currency symbols, commas in numbers, etc.
    text = re.sub(r"[^0-9.-]", "", str(text))
    try:
        return float(text)
    except ValueError:
        print(f"Warning: Could not convert '{text}' to float")
        return None
'''

    def _analyze_code_safety(self, code: str) -> Dict[str, Any]:
        """
        Perform a best-effort static analysis for potentially harmful code.

        Rejects code that does not parse, imports a module whose dotted path
        contains any entry of ``self.banned_modules`` as a substring, or calls
        ``exec``, ``eval`` or ``__import__`` by name.

        Returns:
            ``{"safe": True}`` or ``{"safe": False, "reason": <str>}``.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            # ValueError covers source containing null bytes on older interpreters.
            return {"safe": False, "reason": "Invalid Python syntax"}

        dangerous_imports: List[str] = []
        banned_call: Optional[str] = None
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                module_names = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom):
                # node.module is None for relative imports such as "from . import x".
                module_names = [node.module] if node.module else []
            else:
                module_names = []
                if (
                    banned_call is None
                    and isinstance(node, ast.Call)
                    and isinstance(node.func, ast.Name)
                    # __import__ would otherwise sidestep the import ban entirely.
                    and node.func.id in ("exec", "eval", "__import__")
                ):
                    banned_call = node.func.id
            # Substring matching is intentionally conservative: it also rejects
            # variants such as "socketserver" or "posix".
            dangerous_imports.extend(
                module_name
                for module_name in module_names
                if any(banned in module_name for banned in self.banned_modules)
            )

        # Import violations take precedence over call violations.
        if dangerous_imports:
            return {
                "safe": False,
                "reason": f"Potentially harmful imports detected: {dangerous_imports}",
            }
        if banned_call in ("exec", "eval"):
            return {"safe": False, "reason": "Contains exec() or eval() calls"}
        if banned_call == "__import__":
            return {"safe": False, "reason": "Contains __import__() calls"}
        return {"safe": True}

    def _timeout_handler(self, signum, frame):
        """Signal handler for SIGALRM: abort the running snippet with TimeoutError."""
        raise TimeoutError(f"Code execution timed out after {self.timeout} seconds")

    @contextlib.contextmanager
    def _time_limit(self):
        """
        Enforce ``self.timeout`` around the managed block when the platform allows it.

        SIGALRM exists only on Unix and signal handlers can only be installed
        from the main thread. Where either condition is not met the block runs
        without a time limit (and a warning is logged) instead of failing every
        execution. The previously installed handler is restored on exit.
        """
        if not self.timeout or self.timeout <= 0:
            # A timeout of 0 means "no limit", matching signal.alarm(0).
            yield
            return
        can_alarm = (
            hasattr(signal, "SIGALRM")
            and threading.current_thread() is threading.main_thread()
        )
        if not can_alarm:
            logging.getLogger(__name__).warning(
                "SIGALRM timeout unavailable (non-Unix platform or non-main thread); "
                "running code without a time limit."
            )
            yield
            return
        previous_handler = signal.signal(signal.SIGALRM, self._timeout_handler)
        signal.alarm(self.timeout)
        try:
            yield
        finally:
            signal.alarm(0)
            # signal.signal() returns None for handlers not installed from Python.
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )

    @staticmethod
    def _to_number(text: str) -> Union[int, float]:
        """Convert *text* to float if it has a decimal point or exponent, else to int.

        Raises:
            ValueError: if *text* is not a valid number.
        """
        if any(marker in text for marker in ".eE"):
            return float(text)
        return int(text)

    def _extract_numeric_value(self, output: str) -> Optional[Union[int, float]]:
        """
        Extract the last numeric value printed in *output*.

        Lines are scanned from the bottom. A line that is entirely a number
        wins; otherwise a number at the very end of the line is used.

        Returns:
            The number as int or float, or None if no line ends with a number.
        """
        if not output:
            return None
        for line in reversed(output.strip().split("\n")):
            line = line.strip()
            try:
                return self._to_number(line)
            except ValueError:
                pass
            # Not a pure number: look for a number at the end of the line.
            match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?$", line)
            if match:
                # Anything this pattern matches is a valid int/float literal.
                return self._to_number(match.group(0))
        return None

    def forward(self, code_string: Optional[str] = None, filepath: Optional[str] = None) -> Dict[str, Any]:
        """
        Main entry point for code execution.

        Exactly one of *code_string* and *filepath* must be given.

        Args:
            code_string: Python source to execute.
            filepath: Path to an existing ``.py`` file to execute.

        Returns:
            On success ``{"success": True, "output": str, "result_value": Any}``;
            on failure ``{"success": False, "error": str}``.
        """
        if not code_string and not filepath:
            return {"success": False, "error": "No code string or filepath provided."}
        if code_string and filepath:
            return {"success": False, "error": "Provide either a code string or a filepath, not both."}

        if filepath:
            if not os.path.exists(filepath):
                return {"success": False, "error": f"File not found: {filepath}"}
            if not filepath.endswith(".py"):
                return {"success": False, "error": f"File is not a Python file: {filepath}"}
            try:
                # Explicit encoding: the platform default is not UTF-8 everywhere.
                with open(filepath, "r", encoding="utf-8") as file:
                    code_to_execute = file.read()
            except (OSError, UnicodeError) as e:
                return {"success": False, "error": f"Error reading file {filepath}: {str(e)}"}
        else:
            code_to_execute = code_string

        # Inject utility functions ahead of the user's code.
        enhanced_code = self._utility_functions + "\n\n" + code_to_execute
        return self._execute_actual_code(enhanced_code)

    def _execute_actual_code(self, code: str) -> Dict[str, Any]:
        """
        Execute Python code and capture its stdout or its error.

        After a successful run, ``result_value`` is taken from the first of the
        variables ``result``, ``answer``, ``output``, ``value``,
        ``final_result``, ``data`` defined by the code; if none is set (or it
        is None) the last number printed to stdout is used instead.
        """
        safety_check = self._analyze_code_safety(code)
        if not safety_check["safe"]:
            return {"success": False, "error": f"Safety check failed: {safety_check['reason']}"}

        stdout_buffer = io.StringIO()
        # One dict serves as both globals and locals. With separate dicts the
        # code runs as if inside a class body, so top-level functions cannot
        # see each other or top-level variables.
        namespace: Dict[str, Any] = {"__name__": "__main__"}
        try:
            with self._time_limit(), contextlib.redirect_stdout(stdout_buffer):
                exec(code, namespace)
        except TimeoutError:
            return {"success": False, "error": f"Code execution timed out after {self.timeout} seconds"}
        except SystemExit as e:
            # exit()/quit()/raise SystemExit in the snippet must not stop the host process.
            return {"success": False, "error": f"Error executing code: SystemExit({e.code!r})"}
        except Exception as e:
            trace = traceback.format_exc()
            return {"success": False, "error": f"Error executing code: {str(e)}\n{trace}"}

        # Try to extract the result from common variable names.
        result_value = None
        for var_name in ("result", "answer", "output", "value", "final_result", "data"):
            if var_name in namespace:
                result_value = namespace[var_name]
                break

        output = stdout_buffer.getvalue()
        if len(output) > self.max_output_size:
            output = output[: self.max_output_size] + f"\n... (output truncated, exceeded {self.max_output_size} characters)"

        # If no result variable was found, fall back to the last printed number.
        if result_value is None:
            result_value = self._extract_numeric_value(output)
        return {"success": True, "output": output, "result_value": result_value}

    # Helper methods for backward compatibility
    def execute_file(self, filepath: str) -> Dict[str, Any]:
        """Helper to execute Python code from file."""
        return self.forward(filepath=filepath)

    def execute_code(self, code: str) -> Dict[str, Any]:
        """Helper to execute Python code from a string."""
        return self.forward(code_string=code)
def _run_tests():
    """
    Run a self-test suite for CodeExecutionTool and print a pass/fail summary.

    Each test prints the raw result dict and appends one boolean to
    ``test_results``. Nothing is returned; failures are reported on stdout only.
    Test 4 relies on the SIGALRM timeout, which needs Unix and the main thread.
    """
    tool = CodeExecutionTool(timeout=5)
    test_results = []

    # Test 1: Safe code string
    safe_code = "print('Hello from safe code!'); result = 10 * 2; print(result)"
    print("\n--- Test 1: Safe Code String ---")
    result1 = tool.forward(code_string=safe_code)
    print(result1)
    test_results.append(result1["success"] and "Hello from safe code!" in result1["output"])

    # Test 2: Code with an error
    error_code = "print(1/0)"
    print("\n--- Test 2: Code with Error ---")
    result2 = tool.forward(code_string=error_code)
    print(result2)
    test_results.append(not result2["success"] and "ZeroDivisionError" in result2["error"])

    # Test 3: Code with a banned import
    unsafe_import_code = "import os; print(os.getcwd())"
    print("\n--- Test 3: Unsafe Import ---")
    result3 = tool.forward(code_string=unsafe_import_code)
    print(result3)
    test_results.append(not result3["success"] and "Safety check failed" in result3["error"])

    # Test 4: Timeout (sleeps longer than the 5-second limit)
    timeout_code = "import time; time.sleep(10); print('Done sleeping')"
    print("\n--- Test 4: Timeout ---")
    result4 = tool.forward(code_string=timeout_code)
    print(result4)
    test_results.append(not result4["success"] and "timed out" in result4["error"])

    # Test 5: Execute from file. The script lives in a temporary directory so
    # it is removed even if the run raises, and never pollutes the CWD.
    test_file_content = "print('Hello from file!'); x = 5; y = 7; print(f'Sum: {x+y}')"
    print("\n--- Test 5: Execute from File ---")
    with tempfile.TemporaryDirectory() as tmp_dir:
        test_filename = os.path.join(tmp_dir, "temp_test_script.py")
        with open(test_filename, "w", encoding="utf-8") as f:
            f.write(test_file_content)
        result5 = tool.forward(filepath=test_filename)
    print(result5)
    test_results.append(result5["success"] and "Hello from file!" in result5["output"])

    # Test 6: File not found
    print("\n--- Test 6: File Not Found ---")
    result6 = tool.forward(filepath="non_existent_script.py")
    print(result6)
    test_results.append(not result6["success"] and "File not found" in result6["error"])

    # Test 7: Provide both code_string and filepath
    print("\n--- Test 7: Both code_string and filepath ---")
    result7 = tool.forward(code_string="print('hello')", filepath="dummy.py")
    print(result7)
    test_results.append(
        not result7["success"]
        and "Provide either a code string or a filepath, not both" in result7["error"]
    )

    # Test 8: Provide neither
    print("\n--- Test 8: Neither code_string nor filepath ---")
    result8 = tool.forward()
    print(result8)
    test_results.append(not result8["success"] and "No code string or filepath provided" in result8["error"])

    # Test 9: Function definition and call. The call must be on its own line:
    # after "def f(): return ...;" on one line, a trailing statement belongs to
    # the function body (dead code after return) and would never print.
    func_def_code = "def my_func(a, b):\n    return a + b\nprint(my_func(3,4))"
    print("\n--- Test 9: Function Definition and Call ---")
    result9 = tool.forward(code_string=func_def_code)
    print(result9)
    test_results.append(result9["success"] and "7" in result9["output"])

    print(f"\nTests passed: {sum(test_results)}/{len(test_results)}")
    if all(test_results):
        print("All tests passed!")
    else:
        print("Some tests failed - check output for details.")
# Script entry point: run the self-tests only on direct execution, never on import.
if __name__ == "__main__":
    _run_tests()