File size: 11,693 Bytes
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
 
 
 
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
 
 
07fffa0
 
 
 
 
 
 
 
 
 
 
 
4721d14
 
 
 
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
07fffa0
 
 
4721d14
 
 
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
07fffa0
 
 
4721d14
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
07fffa0
4721d14
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
 
 
07fffa0
 
 
 
4721d14
 
 
 
07fffa0
 
 
 
4721d14
 
 
07fffa0
 
 
4721d14
 
 
07fffa0
 
 
 
 
4721d14
07fffa0
4721d14
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4721d14
 
 
 
07fffa0
 
4721d14
 
 
 
07fffa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Sandboxed Python code executor for the REPL environment.

Uses smolagents.LocalPythonExecutor as the backend for battle-tested sandboxed
execution, with RLM-specific features on top:
- Context loading (set_context)
- Variable access (get_variable, list_variables)
- Function injection (inject_function for llm_query, llm_query_batched)
- Output capped at 8,192 characters per turn (configurable)
- Persistent namespace across code blocks
"""

import json
import logging
import time
import traceback
from collections.abc import Callable
from typing import Any, Dict, List, Optional

from smolagents import LocalPythonExecutor

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Attach a no-op handler so this module installs no output handler of its
# own; where (and whether) records are emitted is left to the application.
logger.addHandler(logging.NullHandler())


class PythonExecutor:
    """Sandboxed Python code executor with persistent namespace.

    Wraps smolagents.LocalPythonExecutor with RLM-specific features:
    - Context loading for RLM tasks
    - Variable tracking for observation
    - Function injection for llm_query, llm_query_batched
    - Configurable output length limit (default 8192 chars per Prime Intellect)
    """

    # Modules sandboxed code may import when the caller does not supply a list.
    _DEFAULT_ALLOWED_IMPORTS = (
        "re",
        "json",
        "math",
        "random",
        "collections",
        "itertools",
        "functools",
        "operator",
        "string",
        "textwrap",
        "difflib",
        "statistics",
        "decimal",
        "fractions",
        "datetime",
        "copy",
        "pprint",
        "typing",
        "dataclasses",
        "enum",
        "bisect",
        "heapq",
        "array",
        "struct",
        "base64",
        "hashlib",
        "hmac",
        "uuid",
    )

    # Maximum length of a variable's repr in ``locals_snapshot``.
    _MAX_REPR_LENGTH = 500

    def __init__(
        self,
        max_output_length: int = 8192,
        allowed_imports: Optional[List[str]] = None,
    ):
        """Initialize the executor.

        Args:
            max_output_length: Maximum characters of stdout/stderr kept before
                a truncation notice is appended (default 8192).
            allowed_imports: Module names sandboxed code may import. ``None``
                selects the built-in default list; an explicit list (including
                an empty one) is used as given.

        Note:
            Carried over from the original implementation (verify against the
            installed smolagents version): smolagents.LocalPythonExecutor does
            NOT support wall-clock timeouts. Instead, it limits operations
            (10M ops) and while iterations (1M).
        """
        self.max_output_length = max_output_length
        self.allowed_imports = self._resolve_allowed_imports(allowed_imports)

        # Backend that actually runs the code.
        self._executor = LocalPythonExecutor(
            additional_authorized_imports=self.allowed_imports
        )

        # Names we have set or observed (reported by list_variables).
        self._user_variables: set[str] = set()

        # Callables handed to the backend through send_tools.
        self._callable_tools: Dict[str, Callable[..., Any]] = {}

        self._register_helpers()

    @classmethod
    def _resolve_allowed_imports(
        cls, allowed_imports: Optional[List[str]]
    ) -> List[str]:
        """Return the effective import allow-list as a fresh list.

        Only ``None`` falls back to the defaults; an empty list is honoured so
        callers can request that no additional imports be authorised.
        """
        if allowed_imports is None:
            return list(cls._DEFAULT_ALLOWED_IMPORTS)
        return list(allowed_imports)

    @staticmethod
    def _safe_json_dumps(obj: Any) -> str:
        """Serialize ``obj`` to JSON, using ``repr`` for unsupported values."""
        return json.dumps(obj, default=repr)

    def _state(self) -> Optional[Dict[str, Any]]:
        """Return the backend's ``state`` mapping, or None if it has none."""
        return getattr(self._executor, "state", None)

    def _register_helpers(self) -> None:
        """Register the built-in helper functions with the executor."""
        helpers: Dict[str, Callable[..., Any]] = {
            "format_exc": traceback.format_exc,
            "safe_json_dumps": self._safe_json_dumps,
        }
        # Record all helpers first, then sync once instead of once per helper.
        self._callable_tools.update(helpers)
        self._user_variables.update(helpers)
        self._sync_callable_tools()

    def _sync_callable_tools(self) -> None:
        """Push every registered callable to the executor via send_tools.

        Best-effort: a failure is logged at debug level and otherwise ignored.
        """
        if not self._callable_tools:
            return
        try:
            # Type ignore: smolagents accepts callables despite Tool type hint
            self._executor.send_tools(self._callable_tools)  # type: ignore[arg-type]
        except Exception:
            logger.debug(
                "send_tools failed; continuing without extra tools",
                exc_info=True,
            )

    def set_context(self, context: str, variable_name: str = "context") -> None:
        """Load context into namespace as a variable.

        Args:
            context: The context string to load
            variable_name: Name of the variable (default "context")
        """
        self.set_variable(variable_name, context)

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the namespace.

        Args:
            name: Variable name
            value: Variable value
        """
        state = self._state()
        if state is not None:
            state[name] = value
        else:
            # Fallback: keep the value on the executor object so that
            # get_variable can still find it.
            injected = getattr(self._executor, "_injected_vars", None)
            if injected is None:
                injected = {}
                self._executor._injected_vars = injected
            injected[name] = value

        self._user_variables.add(name)

    def get_variable(self, name: str) -> Optional[Any]:
        """Retrieve a variable from namespace.

        Args:
            name: Variable name

        Returns:
            The variable value or None if not found
        """
        state = self._state()
        if state is not None:
            return state.get(name)

        # Fallback store written by set_variable when there is no state.
        injected = getattr(self._executor, "_injected_vars", None)
        if injected is not None:
            return injected.get(name)

        return None

    def list_variables(self) -> List[str]:
        """List non-private variables in namespace.

        Returns:
            Sorted list of names found in the executor state or tracked by
            this wrapper, excluding any name that starts with an underscore.
        """
        names = set(self._user_variables)

        state = self._state()
        if state is not None:
            names.update(state)

        # Filter after merging so that private names are excluded no matter
        # which source they came from; sort for a deterministic result.
        return sorted(name for name in names if not name.startswith("_"))

    def execute(self, code: str) -> Dict[str, Any]:
        """Execute Python code and return results.

        Args:
            code: Python code to execute

        Returns:
            Dictionary with the keys:
            - ``stdout``: captured logs plus the serialized output value
            - ``stderr``: error text reported by, or raised from, the backend
            - ``locals_snapshot``: name -> truncated repr of each non-private
              variable created by this call (rebinding an existing name is
              not reported)
            - ``execution_time``: elapsed seconds
            - ``success``: False if any error was detected
            - ``exception``: error message, or None
        """
        # perf_counter is monotonic, so clock adjustments cannot skew timing.
        started = time.perf_counter()

        state = self._state()
        pre_state_keys = set(state) if state is not None else set()

        stdout_parts: list[str] = []
        stderr_parts: list[str] = []
        success = True
        exception_msg: Optional[str] = None

        try:
            exec_result = self._executor(code)
        except Exception as e:
            success = False
            exception_msg = (
                f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
            )
            stderr_parts.append(exception_msg)
        else:
            success, exception_msg = self._collect_result(
                exec_result, stdout_parts, stderr_parts
            )

        execution_time = time.perf_counter() - started

        new_locals = self._snapshot_new_locals(pre_state_keys)

        # Compose stdout/stderr, skipping empty fragments
        stdout = "\n".join(part for part in stdout_parts if part)
        stderr = "\n".join(part for part in stderr_parts if part)

        return {
            "stdout": self._truncate(stdout),
            "stderr": self._truncate(stderr),
            "locals_snapshot": new_locals,
            "execution_time": execution_time,
            "success": success,
            "exception": exception_msg,
        }

    def _collect_result(
        self,
        exec_result: Any,
        stdout_parts: List[str],
        stderr_parts: List[str],
    ) -> "tuple[bool, Optional[str]]":
        """Extract output and error information from a backend result.

        Each attribute is read defensively; a failure to read one is logged
        at debug level and does not prevent reading the others.

        Args:
            exec_result: Object returned by calling the backend executor.
            stdout_parts: List to which stdout fragments are appended.
            stderr_parts: List to which stderr fragments are appended.

        Returns:
            ``(success, exception_msg)`` for the execution.
        """
        success = True
        exception_msg: Optional[str] = None

        # Logs / prints
        try:
            logs = getattr(exec_result, "logs", None)
            if logs:
                stdout_parts.append(str(logs))
        except Exception:
            logger.debug("Failed to read exec_result.logs", exc_info=True)

        # Result / output value
        try:
            out_val = getattr(exec_result, "output", None)
            if out_val is not None:
                try:
                    stdout_parts.append(json.dumps(out_val))
                except Exception:
                    # Not JSON-serializable: fall back to repr.
                    stdout_parts.append(repr(out_val))
        except Exception:
            logger.debug("Failed to read exec_result.output", exc_info=True)

        # Errors reported on the result object; a later attribute's message
        # replaces an earlier one in exception_msg.
        for attr in ("error", "exception"):
            try:
                problem = getattr(exec_result, attr, None)
                if problem:
                    text = str(problem)
                    stderr_parts.append(text)
                    success = False
                    exception_msg = text
            except Exception:
                logger.debug(
                    "Failed to read exec_result.%s", attr, exc_info=True
                )

        # Exit status, if the result exposes one
        try:
            if hasattr(exec_result, "exit_code"):
                exit_code = exec_result.exit_code
                if exit_code is not None and exit_code != 0:
                    success = False
            elif hasattr(exec_result, "success"):
                # Only ever downgrade: a truthy flag must not hide an error
                # that was already detected above.
                success = success and bool(exec_result.success)
        except Exception:
            logger.debug(
                "Failed to determine exec_result exit code", exc_info=True
            )

        return success, exception_msg

    def _snapshot_new_locals(self, pre_state_keys: "set[str]") -> Dict[str, str]:
        """Return truncated reprs of non-private variables created by a run.

        Args:
            pre_state_keys: Keys present in the executor state before the run.

        Returns:
            Mapping of each new variable name to its repr, cut to
            ``_MAX_REPR_LENGTH`` characters, or ``"<unrepresentable>"`` if the
            value could not be read or represented.
        """
        snapshot: Dict[str, str] = {}
        state = self._state()
        if state is None:
            return snapshot

        for key in state:
            if key in pre_state_keys or key.startswith("_"):
                continue
            # Track the name first so it is listed even if repr() fails.
            self._user_variables.add(key)
            try:
                val_repr = repr(state[key])
                if len(val_repr) > self._MAX_REPR_LENGTH:
                    val_repr = val_repr[: self._MAX_REPR_LENGTH] + "..."
                snapshot[key] = val_repr
            except Exception:
                snapshot[key] = "<unrepresentable>"
        return snapshot

    def _truncate(self, text: str) -> str:
        """Cut ``text`` to ``max_output_length`` chars and append a notice.

        Text within the limit is returned unchanged. The notice reports the
        original length and is added on top of the kept characters.
        """
        if len(text) <= self.max_output_length:
            return text
        return (
            text[: self.max_output_length]
            + f"\n... (truncated, total {len(text)} chars)"
        )

    def reset(self) -> None:
        """Reset namespace to initial state.

        Creates a fresh backend executor and forgets all tracked variables
        and injected functions; only the built-in helpers are registered
        again, so callers must re-inject their own functions afterwards.
        """
        self._executor = LocalPythonExecutor(
            additional_authorized_imports=self.allowed_imports
        )
        self._user_variables.clear()
        self._callable_tools.clear()
        self._register_helpers()

    def inject_function(self, name: str, func: Callable[..., Any]) -> None:
        """Inject a callable function into the namespace.

        Used for adding llm_query, llm_query_batched, FINAL, etc.

        Args:
            name: Function name in namespace
            func: The callable to inject
        """
        # Add to callable tools and sync with executor
        self._callable_tools[name] = func
        self._user_variables.add(name)
        self._sync_callable_tools()