Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Sandboxed Python code executor for the REPL environment. | |
| Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed | |
| execution, with RLM-specific features on top: | |
| - Context loading (set_context) | |
| - Variable access (get_variable, list_variables) | |
| - Function injection (inject_function for llm_query, llm_batch) | |
| - Output capped at 8,192 characters per turn (configurable) | |
| - Persistent namespace across code blocks | |
| """ | |
| import json | |
| import logging | |
| import time | |
| import traceback | |
| from collections.abc import Callable | |
| from typing import Any, Dict, List, Optional | |
| from smolagents import LocalPythonExecutor | |
# Module-level logger, named after this module so applications can filter it.
logger = logging.getLogger(__name__)
# Attach a no-op handler: this module never configures log output itself and
# leaves that decision to whatever logging setup the host application installs.
logger.addHandler(logging.NullHandler())
class PythonExecutor:
    """Sandboxed Python code executor with persistent namespace.

    Wraps smolagents.LocalPythonExecutor with RLM-specific features:
    - Context loading for RLM tasks
    - Variable tracking for observation
    - Function injection for llm_query, llm_batch
    - Configurable output length limit (default 8192 chars per Prime Intellect)
    """

    # Modules user code may import when the caller does not supply its own list.
    DEFAULT_ALLOWED_IMPORTS = (
        "re",
        "json",
        "math",
        "random",
        "collections",
        "itertools",
        "functools",
        "operator",
        "string",
        "textwrap",
        "difflib",
        "statistics",
        "decimal",
        "fractions",
        "datetime",
        "copy",
        "pprint",
        "typing",
        "dataclasses",
        "enum",
        "bisect",
        "heapq",
        "array",
        "struct",
        "base64",
        "hashlib",
        "hmac",
        "uuid",
    )

    def __init__(
        self,
        max_output_length: int = 8192,
        timeout: float = 30.0,
        allowed_imports: Optional[List[str]] = None,
    ):
        """Initialize the executor.

        Args:
            max_output_length: Maximum characters kept for stdout and for
                stderr before a truncation notice is appended (default 8192).
            timeout: Execution timeout in seconds. Forwarded to
                LocalPythonExecutor when the installed version accepts it.
            allowed_imports: Module names user code may import. ``None``
                selects ``DEFAULT_ALLOWED_IMPORTS``; an explicit empty list
                is honoured and authorizes no additional modules.
        """
        self.max_output_length = max_output_length
        self.timeout = timeout
        self.allowed_imports = self._resolve_imports(allowed_imports)
        # Backend that actually runs the code.
        self._executor = self._build_executor()
        # Names set by us or created by executed code (for list_variables).
        self._user_variables: set[str] = set()
        # Every tool registered so far; re-sent as a whole on each sync.
        self._registered_tools: Dict[str, Callable[..., Any]] = {}
        self._register_helpers()

    @classmethod
    def _resolve_imports(cls, allowed_imports: Optional[List[str]]) -> List[str]:
        """Return a fresh import allow-list, using the defaults only for None."""
        if allowed_imports is None:
            return list(cls.DEFAULT_ALLOWED_IMPORTS)
        # Copy so later mutation of the caller's list cannot affect us.
        return list(allowed_imports)

    def _build_executor(self) -> Any:
        """Create a LocalPythonExecutor configured from this instance."""
        try:
            return LocalPythonExecutor(
                additional_authorized_imports=self.allowed_imports,
                timeout_seconds=self.timeout,
            )
        except TypeError:
            # Presumably an older smolagents without ``timeout_seconds``;
            # fall back to the backend's own default timeout behaviour.
            logger.debug(
                "LocalPythonExecutor rejected timeout_seconds; using its default",
                exc_info=True,
            )
            return LocalPythonExecutor(
                additional_authorized_imports=self.allowed_imports
            )

    def _register_helpers(self) -> None:
        """Register the built-in helper functions with the executor."""
        self._registered_tools.update(
            {
                "format_exc": traceback.format_exc,
                # Fall back to repr() for objects json cannot serialize.
                "safe_json_dumps": lambda obj: json.dumps(obj, default=repr),
            }
        )
        self._sync_tools()

    def _sync_tools(self) -> None:
        """Send the full set of registered tools to the executor.

        Failures are logged at debug level and otherwise ignored, so the
        executor keeps working without the extra tools.
        """
        try:
            self._executor.send_tools(self._registered_tools)
        except Exception:
            logger.debug(
                "LocalPythonExecutor.send_tools failed; continuing without extra tools",
                exc_info=True,
            )

    def set_context(self, context: str, variable_name: str = "context") -> None:
        """Load context into namespace as a variable.

        Args:
            context: The context string to load
            variable_name: Name of the variable (default "context")
        """
        self.set_variable(variable_name, context)

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the namespace.

        Args:
            name: Variable name
            value: Variable value
        """
        if hasattr(self._executor, "state"):
            # Normal path: write straight into the executor's namespace dict.
            self._executor.state[name] = value
        else:
            # Fallback for executors without a ``state`` mapping: remember the
            # value on the executor so get_variable can find it, and expose a
            # zero-argument accessor under the same name via send_tools.
            injected = getattr(self._executor, "_injected_vars", {})
            self._executor._injected_vars = injected
            injected[name] = value
            try:
                self._executor.send_tools({name: lambda v=value: v})
            except Exception:
                # Best effort only; the value stays retrievable from
                # ``_injected_vars`` even if the accessor was not installed.
                logger.debug(
                    "send_tools failed while injecting variable %r", name, exc_info=True
                )
        self._user_variables.add(name)

    def get_variable(self, name: str) -> Optional[Any]:
        """Retrieve a variable from namespace.

        Args:
            name: Variable name

        Returns:
            The variable value or None if not found
        """
        if hasattr(self._executor, "state"):
            return self._executor.state.get(name)
        # Fallback storage written by set_variable when there is no ``state``.
        if hasattr(self._executor, "_injected_vars"):
            return self._executor._injected_vars.get(name)
        return None

    def list_variables(self) -> List[str]:
        """List non-private variables in namespace.

        Returns:
            Sorted list of names that are either in the executor state or
            tracked by this wrapper, excluding names starting with ``_``.
        """
        names = set(self._user_variables)
        state = getattr(self._executor, "state", None)
        if state is not None:
            names.update(state)
        # Sorting makes the result deterministic across calls.
        return sorted(name for name in names if not name.startswith("_"))

    def _truncate(self, text: str) -> str:
        """Cut ``text`` to max_output_length and append a truncation notice."""
        limit = self.max_output_length
        if len(text) <= limit:
            return text
        return f"{text[:limit]}\n... (truncated, total {len(text)} chars)"

    def _snapshot_new_variables(self, known_keys: set) -> Dict[str, str]:
        """Return ``repr`` strings for public state names not in ``known_keys``.

        Each newly seen name is also recorded in ``_user_variables``.
        """
        snapshot: Dict[str, str] = {}
        state = getattr(self._executor, "state", None)
        if state is None:
            return snapshot
        for key, value in list(state.items()):
            if key in known_keys or key.startswith("_"):
                continue
            self._user_variables.add(key)
            try:
                text = repr(value)
            except Exception:
                snapshot[key] = "<unrepresentable>"
                continue
            # Keep the snapshot compact: at most 500 characters per value.
            snapshot[key] = text if len(text) <= 500 else text[:500] + "..."
        return snapshot

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute Python code and return results.

        Args:
            code: Python code to execute

        Returns:
            Dictionary with stdout, stderr, locals_snapshot, execution_time,
            success, and exception fields. ``locals_snapshot`` maps names that
            first appeared in the namespace during this call to a (possibly
            shortened) ``repr``; names that already existed are not included.
        """
        started = time.perf_counter()  # monotonic, unaffected by clock changes
        success = True
        exception_msg: Optional[str] = None

        # Remember which names existed so new ones can be reported afterwards.
        pre_state_keys: set = set()
        if hasattr(self._executor, "state"):
            pre_state_keys = set(self._executor.state)

        stdout_parts: List[str] = []
        stderr_parts: List[str] = []

        try:
            exec_result = self._executor(code)

            # Printed output collected by the executor.
            try:
                logs = getattr(exec_result, "logs", None)
                if logs:
                    stdout_parts.append(str(logs))
            except Exception:
                logger.debug("Failed to read exec_result.logs", exc_info=True)

            # Value produced by the code, as JSON when possible, else repr.
            try:
                out_val = getattr(exec_result, "output", None)
                if out_val is not None:
                    try:
                        stdout_parts.append(json.dumps(out_val))
                    except Exception:
                        stdout_parts.append(repr(out_val))
            except Exception:
                logger.debug("Failed to read exec_result.output", exc_info=True)

            # Error information reported on the result object, if any.
            try:
                err = getattr(exec_result, "error", None)
                if err:
                    stderr_parts.append(str(err))
                    success = False
                    exception_msg = str(err)
            except Exception:
                logger.debug("Failed to read exec_result.error", exc_info=True)

            try:
                ex = getattr(exec_result, "exception", None)
                if ex:
                    stderr_parts.append(str(ex))
                    success = False
                    exception_msg = str(ex)
            except Exception:
                logger.debug("Failed to read exec_result.exception", exc_info=True)

            # Status flags may only turn success off; they never override a
            # failure that was already detected from error/exception above.
            try:
                if hasattr(exec_result, "exit_code"):
                    exit_code = exec_result.exit_code
                    if exit_code is not None and exit_code != 0:
                        success = False
                elif hasattr(exec_result, "success"):
                    if not exec_result.success:
                        success = False
            except Exception:
                logger.debug("Failed to determine exec_result exit code", exc_info=True)

        except Exception as e:
            # The executor itself raised: report it as a failed run.
            success = False
            exception_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
            stderr_parts.append(exception_msg)

        execution_time = time.perf_counter() - started

        new_locals = self._snapshot_new_variables(pre_state_keys)

        stdout = self._truncate("\n".join(part for part in stdout_parts if part))
        stderr = self._truncate("\n".join(part for part in stderr_parts if part))

        return {
            "stdout": stdout,
            "stderr": stderr,
            "locals_snapshot": new_locals,
            "execution_time": execution_time,
            "success": success,
            "exception": exception_msg,
        }

    def reset(self) -> None:
        """Reset namespace to initial state.

        Replaces the executor with a fresh one and forgets tracked variables
        and every registered tool except the built-in helpers; functions added
        through inject_function must be injected again afterwards.
        """
        self._executor = self._build_executor()
        self._user_variables.clear()
        self._registered_tools.clear()
        self._register_helpers()

    def inject_function(self, name: str, func: Callable[..., Any]) -> None:
        """Inject a callable function into the namespace.

        Used for adding llm_query, llm_batch, etc.

        Args:
            name: Function name in namespace
            func: The callable to inject
        """
        self._registered_tools[name] = func
        self._user_variables.add(name)
        # Re-send the whole tool set so earlier registrations are kept.
        self._sync_tools()