Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import ast | |
| import builtins | |
| import difflib | |
| import inspect | |
| import logging | |
| import math | |
| import re | |
| from collections.abc import Mapping | |
| from importlib import import_module | |
| from types import ModuleType | |
| from typing import Any, Callable, Dict, List, Optional, Set, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| from .tools import Tool | |
| from .utils import BASE_BUILTIN_MODULES, truncate_content | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class InterpreterError(ValueError):
    """Raised when the interpreter cannot evaluate a piece of Python code.

    Typical causes are a syntax error or an operation the interpreter does not
    support. Because it subclasses ``ValueError``, ordinary ``ValueError``
    handlers also catch it.
    """
# Every exception class exposed by the `builtins` module, keyed by name
# (e.g. "ValueError" -> ValueError), in dir() order. evaluate_name and
# evaluate_call fall back to this table, so interpreted code can raise and
# catch builtin exceptions without importing anything.
ERRORS = {
    exc_name: exc_obj
    for exc_name, exc_obj in ((attr, getattr(builtins, attr)) for attr in dir(builtins))
    if isinstance(exc_obj, type) and issubclass(exc_obj, BaseException)
}
# Resource limits for interpreted code.
# NOTE(review): DEFAULT_MAX_LEN_OUTPUT and MAX_OPERATIONS are not referenced in
# this part of the file; presumably they cap output length / total evaluated
# operations elsewhere — confirm against their users.
DEFAULT_MAX_LEN_OUTPUT = 50_000
MAX_OPERATIONS = 10_000_000
# evaluate_while raises InterpreterError once a single loop exceeds this many iterations.
MAX_WHILE_ITERATIONS = 1_000_000
def custom_print(*args: Any) -> None:
    """Stand-in ``print`` tool that ignores its arguments and returns ``None``.

    evaluate_call intercepts calls made through the bare name ``print`` and records
    their text itself, so this function never produces output.
    """
    del args  # intentionally unused


# Callables available to interpreted code by bare name, in a fixed order.
# NOTE(review): evaluate_call rejects builtin functions that are absent from
# `static_tools`; presumably `static_tools` is seeded from this table — confirm.
# NOTE(review): getattr/setattr/type give broad reflective access; keep in mind
# when auditing the sandbox.
BASE_PYTHON_TOOLS = dict(
    # constructors / core builtins
    print=custom_print,
    isinstance=isinstance,
    range=range,
    float=float,
    int=int,
    bool=bool,
    str=str,
    set=set,
    list=list,
    dict=dict,
    tuple=tuple,
    round=round,
    # math helpers exposed under their bare names
    ceil=math.ceil,
    floor=math.floor,
    log=math.log,
    exp=math.exp,
    sin=math.sin,
    cos=math.cos,
    tan=math.tan,
    asin=math.asin,
    acos=math.acos,
    atan=math.atan,
    atan2=math.atan2,
    degrees=math.degrees,
    radians=math.radians,
    pow=pow,
    sqrt=math.sqrt,
    # aggregation / iteration builtins
    len=len,
    sum=sum,
    max=max,
    min=min,
    abs=abs,
    enumerate=enumerate,
    zip=zip,
    reversed=reversed,
    sorted=sorted,
    all=all,
    any=any,
    map=map,
    filter=filter,
    ord=ord,
    chr=chr,
    next=next,
    iter=iter,
    divmod=divmod,
    # introspection builtins
    callable=callable,
    getattr=getattr,
    hasattr=hasattr,
    setattr=setattr,
    issubclass=issubclass,
    type=type,
    complex=complex,
)
# Name fragments treated as dangerous.
# NOTE(review): not referenced in this part of the file; presumably matched
# against module/attribute names elsewhere — confirm against the consumer.
DANGEROUS_PATTERNS = (
    # OS / process access
    "_os", "os", "subprocess", "_subprocess", "pty", "system", "popen", "spawn",
    # filesystem and interpreter internals
    "shutil", "sys", "pathlib", "io",
    # network
    "socket",
    # dynamic code execution
    "compile", "eval", "exec",
    # parallel processes
    "multiprocessing",
)
class PrintContainer:
    """Mutable text buffer that accumulates captured output.

    evaluate_call appends to ``state["_print_outputs"]`` with ``+=``; presumably that
    entry holds an instance of this class — confirm where the state is initialised.
    """

    def __init__(self):
        self.value = ""

    def append(self, text):
        """Append ``text`` verbatim (no ``str()`` conversion) and return ``self``."""
        self.value = self.value + text
        return self

    def __iadd__(self, other):
        """Support ``container += obj``; ``obj`` is converted with ``str()`` first."""
        self.value = self.value + str(other)
        return self

    def __str__(self):
        """Return the accumulated text."""
        return self.value

    def __repr__(self):
        """Debug representation showing the accumulated text."""
        return f"PrintContainer({self.value})"

    def __len__(self):
        """Length of the accumulated text, so ``len(container)`` works."""
        return len(self.value)
class BreakException(Exception):
    """Control-flow signal for ``break``; caught by evaluate_while / evaluate_for.

    NOTE(review): the raise site is outside this view — presumably the ``ast.Break`` handler.
    """


class ContinueException(Exception):
    """Control-flow signal for ``continue``; caught by evaluate_while / evaluate_for.

    NOTE(review): the raise site is outside this view — presumably the ``ast.Continue`` handler.
    """


class ReturnException(Exception):
    """Control-flow signal carrying a function's return value out of its body.

    create_function catches it and returns ``.value`` to the caller.
    """

    def __init__(self, value):
        # Value produced by the `return` statement.
        self.value = value
def get_iterable(obj):
    """Return ``obj`` as a list.

    A list is returned as-is (not copied). Any other object with ``__iter__`` is
    materialised with ``list()``.

    Raises:
        InterpreterError: if ``obj`` has no ``__iter__``.
    """
    if isinstance(obj, list):
        return obj
    if not hasattr(obj, "__iter__"):
        raise InterpreterError("Object is not iterable")
    return list(obj)
def fix_final_answer_code(code: str) -> str:
    """Rename variable uses of ``final_answer`` so the ``final_answer()`` tool survives.

    LLM-written code sometimes assigns to ``final_answer``, which would shadow the
    tool. When the code both calls ``final_answer(`` and assigns to a bare
    ``final_answer`` name, every non-call, non-attribute occurrence is renamed to
    ``final_answer_variable``; calls to ``final_answer(...)`` are left untouched.
    Otherwise the code is returned unchanged, to avoid rewriting code the model
    will see again in later steps.
    """
    # Bare `final_answer` (not `obj.final_answer`, not `my_final_answer`) followed by `=`.
    assignment_re = re.compile(r"(?<!\.)(?<!\w)(\bfinal_answer)(\s*=)")
    if "final_answer(" not in code or assignment_re.search(code) is None:
        return code

    renamed = assignment_re.sub(r"final_answer_variable\2", code)

    # Remaining bare uses that are not immediately followed by a call parenthesis.
    usage_re = re.compile(r"(?<!\.)(?<!\w)(\bfinal_answer\b)(?!\s*\()")
    return usage_re.sub("final_answer_variable", renamed)
def evaluate_unaryop(
    expression: ast.UnaryOp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate a unary operator applied to one operand: ``-x``, ``+x``, ``not x`` or ``~x``.

    Raises:
        InterpreterError: for any other unary operator node.
    """
    value = evaluate_ast(expression.operand, state, static_tools, custom_tools, authorized_imports)
    op_node = expression.op
    if isinstance(op_node, ast.Not):
        return not value
    if isinstance(op_node, ast.USub):
        return -value
    if isinstance(op_node, ast.UAdd):
        return value
    if isinstance(op_node, ast.Invert):
        return ~value
    raise InterpreterError(f"Unary operation {op_node.__class__.__name__} is not supported.")
def evaluate_lambda(
    lambda_expression: ast.Lambda,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Callable:
    """Turn an ``ast.Lambda`` node into a callable that evaluates its body in the interpreter.

    Parameter binding supports positional parameters (with defaults), keyword
    arguments, ``*args``, keyword-only parameters (with defaults) and ``**kwargs``.
    As in Python, default values are evaluated once, when the lambda expression
    itself is evaluated.

    Each call evaluates the body in a shallow copy of ``state`` with the parameters
    bound, so bindings made during the call do not leak into ``state``. For backward
    compatibility, surplus positional arguments without ``*args`` are silently dropped
    and missing arguments are left unbound rather than raising.

    Raises (from the returned callable):
        TypeError: for an unknown keyword without ``**kwargs``, or a keyword that
            repeats an argument already given positionally.
    """
    spec = lambda_expression.args
    positional = [arg.arg for arg in (*spec.posonlyargs, *spec.args)]
    keyword_only = [arg.arg for arg in spec.kwonlyargs]
    named = set(positional) | set(keyword_only)

    def evaluate(node: ast.AST) -> Any:
        return evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)

    # Positional defaults belong to the *last* len(defaults) positional parameters.
    default_values = [evaluate(node) for node in spec.defaults]
    defaults = dict(zip(positional[len(positional) - len(default_values) :], default_values))
    # kw_defaults holds None for keyword-only parameters without a default.
    defaults.update(
        (arg.arg, evaluate(node)) for arg, node in zip(spec.kwonlyargs, spec.kw_defaults) if node is not None
    )

    def lambda_func(*values: Any, **keywords: Any) -> Any:
        new_state = state.copy()
        bound = set()
        for name, value in zip(positional, values):
            new_state[name] = value
            bound.add(name)
        if spec.vararg is not None:
            new_state[spec.vararg.arg] = values[len(positional) :]

        extra = {}
        for name, value in keywords.items():
            if name not in named:
                extra[name] = value
            elif name in bound:
                raise TypeError(f"<lambda>() got multiple values for argument '{name}'")
            else:
                new_state[name] = value
                bound.add(name)
        if spec.kwarg is not None:
            new_state[spec.kwarg.arg] = extra
        elif extra:
            raise TypeError(f"<lambda>() got an unexpected keyword argument '{next(iter(extra))}'")

        # Defaults only fill parameters the caller did not supply (an outer variable
        # with the same name must not hide the default).
        for name, value in defaults.items():
            if name not in bound:
                new_state[name] = value

        return evaluate_ast(
            lambda_expression.body,
            new_state,
            static_tools,
            custom_tools,
            authorized_imports,
        )

    return lambda_func
def evaluate_while(
    while_loop: ast.While,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Execute a ``while`` loop, including its optional ``else`` clause.

    ``break``/``continue`` in the body surface as BreakException / ContinueException.
    A ``continue`` skips the rest of the body. A ``break`` leaves the loop without
    running ``else``. The ``else`` clause runs only when the test becomes false.

    Raises:
        InterpreterError: if more than MAX_WHILE_ITERATIONS iterations run.
    """
    iterations = 0
    while evaluate_ast(while_loop.test, state, static_tools, custom_tools, authorized_imports):
        try:
            for node in while_loop.body:
                evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
        except BreakException:
            return None
        except ContinueException:
            pass  # abandon the rest of the body and re-test the condition
        iterations += 1
        if iterations > MAX_WHILE_ITERATIONS:
            raise InterpreterError(f"Maximum number of {MAX_WHILE_ITERATIONS} iterations in While loop exceeded")

    # Reached only when the condition became false (a `break` returned above).
    for node in while_loop.orelse:
        evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
    return None
def create_function(
    func_def: ast.FunctionDef,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Callable:
    """Build a callable that executes ``func_def``'s body inside the interpreter.

    Each call runs in a shallow copy of ``state`` with parameters bound as in Python:
    - positional arguments fill the named positional parameters in order;
    - ``*args`` receives only the surplus positional arguments;
    - keywords naming a parameter bind it; other keywords go to ``**kwargs``, or,
      when there is no ``**kwargs`` parameter, are bound as locals (lenient, unlike CPython);
    - defaults (positional and keyword-only) fill only the parameters the caller
      did not supply. They are evaluated against ``state`` on every call, not once at
      definition time as in CPython.

    If the first positional parameter is ``self`` and a positional argument was given,
    ``__class__`` is also bound to that argument's class (used by zero-argument ``super()``).

    Returns:
        The value of a ``return`` statement. Without one, the value of the last evaluated
        body statement is returned (CPython would return None). ``__init__`` always returns None.
    """
    spec = func_def.args
    positional_names = [arg.arg for arg in (*spec.posonlyargs, *spec.args)]
    parameter_names = set(positional_names) | {arg.arg for arg in spec.kwonlyargs}

    def new_func(*args: Any, **kwargs: Any) -> Any:
        func_state = state.copy()
        default_values = [
            evaluate_ast(d, state, static_tools, custom_tools, authorized_imports) for d in spec.defaults
        ]
        # Positional defaults belong to the *last* len(defaults) positional parameters.
        positional_defaults = dict(
            zip(positional_names[len(positional_names) - len(default_values) :], default_values)
        )
        bound = set()

        for name, value in zip(positional_names, args):
            func_state[name] = value
            bound.add(name)
        if spec.vararg is not None:
            func_state[spec.vararg.arg] = args[len(positional_names) :]

        extra_kwargs = {}
        for name, value in kwargs.items():
            if name in parameter_names:
                func_state[name] = value
                bound.add(name)
            else:
                extra_kwargs[name] = value
        if spec.kwarg is not None:
            func_state[spec.kwarg.arg] = extra_kwargs
        else:
            func_state.update(extra_kwargs)

        # Track supplied parameters explicitly: func_state starts as a copy of the outer
        # state, so a same-named outer variable must not suppress the default.
        for name, value in positional_defaults.items():
            if name not in bound:
                func_state[name] = value
        for arg, default in zip(spec.kwonlyargs, spec.kw_defaults):
            # kw_defaults holds None for keyword-only parameters without a default.
            if default is not None and arg.arg not in bound:
                func_state[arg.arg] = evaluate_ast(default, state, static_tools, custom_tools, authorized_imports)

        if positional_names and positional_names[0] == "self" and args:
            func_state["self"] = args[0]
            func_state["__class__"] = args[0].__class__

        result = None
        try:
            for stmt in func_def.body:
                result = evaluate_ast(stmt, func_state, static_tools, custom_tools, authorized_imports)
        except ReturnException as e:
            result = e.value

        if func_def.name == "__init__":
            return None
        return result

    return new_func
def evaluate_function_def(
    func_def: ast.FunctionDef,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Callable:
    """Handle a ``def`` statement by building the function with create_function.

    The new callable is registered in ``custom_tools`` under the function's name (not
    in ``state``) and also returned.
    """
    function = create_function(func_def, state, static_tools, custom_tools, authorized_imports)
    custom_tools[func_def.name] = function
    return function
def evaluate_class_def(
    class_def: ast.ClassDef,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> type:
    """Evaluate a ``class`` statement, bind the new class in ``state`` and return it.

    Supported body statements:
    - method definitions (built via evaluate_function_def, which also registers them
      in ``custom_tools``);
    - assignments, whose right-hand side is evaluated once even for ``a = b = value``;
    - annotated assignments with a value (``x: int = 0``); bare annotations are ignored;
    - a leading string literal, which becomes ``__doc__``; other bare constants are no-ops;
    - ``pass``.

    Raises:
        InterpreterError: for any other statement in the class body.
    """
    class_name = class_def.name
    bases = [evaluate_ast(base, state, static_tools, custom_tools, authorized_imports) for base in class_def.bases]
    class_dict = {}

    for stmt in class_def.body:
        if isinstance(stmt, ast.FunctionDef):
            class_dict[stmt.name] = evaluate_function_def(stmt, state, static_tools, custom_tools, authorized_imports)
        elif isinstance(stmt, ast.Assign):
            value = evaluate_ast(stmt.value, state, static_tools, custom_tools, authorized_imports)
            for target in stmt.targets:
                if isinstance(target, ast.Name):
                    class_dict[target.id] = value
                elif isinstance(target, ast.Attribute):
                    class_dict[target.attr] = value
        elif isinstance(stmt, ast.AnnAssign):
            if stmt.value is not None and isinstance(stmt.target, ast.Name):
                class_dict[stmt.target.id] = evaluate_ast(
                    stmt.value, state, static_tools, custom_tools, authorized_imports
                )
        elif isinstance(stmt, ast.Pass):
            continue
        elif isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant):
            if stmt is class_def.body[0] and isinstance(stmt.value.value, str):
                class_dict["__doc__"] = stmt.value.value
        else:
            raise InterpreterError(f"Unsupported statement in class body: {stmt.__class__.__name__}")

    new_class = type(class_name, tuple(bases), class_dict)
    state[class_name] = new_class
    return new_class
def evaluate_augassign(
    expression: ast.AugAssign,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate an augmented assignment such as ``x += 1`` or ``obj.items[k] *= 2``.

    Steps:
    1. Read the target's current value.
    2. Combine it with the right-hand side using the matching in-place operator, so
       mutable objects such as lists are updated in place.
    3. Store the result back through set_value.

    Returns:
        The new value of the target.

    Raises:
        InterpreterError: for an unsupported target or operator, or when ``+=`` would
            add a non-list value to a list.
    """

    def get_current_value(target: ast.AST) -> Any:
        # Read the value currently stored at `target`.
        if isinstance(target, ast.Name):
            # NOTE: an undefined name reads as 0 here (CPython would raise NameError).
            return state.get(target.id, 0)
        if isinstance(target, ast.Subscript):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
            return obj[key]
        if isinstance(target, ast.Attribute):
            obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
            return getattr(obj, target.attr)
        if isinstance(target, ast.Tuple):
            return tuple(get_current_value(elt) for elt in target.elts)
        if isinstance(target, ast.List):
            return [get_current_value(elt) for elt in target.elts]
        raise InterpreterError(f"AugAssign not supported for {type(target)} targets.")

    current_value = get_current_value(expression.target)
    value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools, authorized_imports)

    op = expression.op
    if isinstance(op, ast.Add):
        if isinstance(current_value, list) and not isinstance(value_to_add, list):
            raise InterpreterError(f"Cannot add non-list value {value_to_add} to a list.")
        current_value += value_to_add
    elif isinstance(op, ast.Sub):
        current_value -= value_to_add
    elif isinstance(op, ast.Mult):
        current_value *= value_to_add
    elif isinstance(op, ast.Div):
        current_value /= value_to_add
    elif isinstance(op, ast.Mod):
        current_value %= value_to_add
    elif isinstance(op, ast.Pow):
        current_value **= value_to_add
    elif isinstance(op, ast.FloorDiv):
        current_value //= value_to_add
    elif isinstance(op, ast.BitAnd):
        current_value &= value_to_add
    elif isinstance(op, ast.BitOr):
        current_value |= value_to_add
    elif isinstance(op, ast.BitXor):
        current_value ^= value_to_add
    elif isinstance(op, ast.LShift):
        current_value <<= value_to_add
    elif isinstance(op, ast.RShift):
        current_value >>= value_to_add
    elif isinstance(op, ast.MatMult):
        current_value @= value_to_add
    else:
        raise InterpreterError(f"Operation {type(op).__name__} is not supported.")

    # Write back: needed for immutable values; harmless for objects updated in place.
    set_value(
        expression.target,
        current_value,
        state,
        static_tools,
        custom_tools,
        authorized_imports,
    )
    return current_value
def evaluate_boolop(
    node: ast.BoolOp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate ``and`` / ``or`` with Python's short-circuit semantics.

    Operands are evaluated left to right, stopping at the first one that decides the
    outcome. That operand's value is returned (not a coerced bool), as in Python:
    - ``a or b`` yields the first truthy operand, else the last operand;
    - ``a and b`` yields the first falsy operand, else the last operand.
    So ``None or "default"`` evaluates to ``"default"``.
    """
    # `or` stops at the first truthy operand, `and` at the first falsy one.
    stop_when_truthy = isinstance(node.op, ast.Or)
    result = None
    for value_node in node.values:
        result = evaluate_ast(value_node, state, static_tools, custom_tools, authorized_imports)
        if bool(result) == stop_when_truthy:
            return result
    return result
def evaluate_binop(
    binop: ast.BinOp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate a binary operation such as ``a + b``, ``a // b`` or ``a @ b``.

    Both operands are evaluated (left first) before the operator is applied.

    Raises:
        NotImplementedError: for an operator node that is not handled.
    """
    left_val = evaluate_ast(binop.left, state, static_tools, custom_tools, authorized_imports)
    right_val = evaluate_ast(binop.right, state, static_tools, custom_tools, authorized_imports)

    op = binop.op
    if isinstance(op, ast.Add):
        return left_val + right_val
    elif isinstance(op, ast.Sub):
        return left_val - right_val
    elif isinstance(op, ast.Mult):
        return left_val * right_val
    elif isinstance(op, ast.Div):
        return left_val / right_val
    elif isinstance(op, ast.Mod):
        return left_val % right_val
    elif isinstance(op, ast.Pow):
        return left_val**right_val
    elif isinstance(op, ast.FloorDiv):
        return left_val // right_val
    elif isinstance(op, ast.BitAnd):
        return left_val & right_val
    elif isinstance(op, ast.BitOr):
        return left_val | right_val
    elif isinstance(op, ast.BitXor):
        return left_val ^ right_val
    elif isinstance(op, ast.LShift):
        return left_val << right_val
    elif isinstance(op, ast.RShift):
        return left_val >> right_val
    elif isinstance(op, ast.MatMult):
        # Matrix multiplication, e.g. between numpy arrays.
        return left_val @ right_val
    else:
        raise NotImplementedError(f"Binary operation {type(op).__name__} is not implemented.")
def evaluate_assign(
    assign: ast.Assign,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate an assignment statement and return the assigned value.

    The right-hand side is evaluated exactly once. In a chained assignment such as
    ``a = b = value``, every target receives that same value, left to right, as in
    Python. Destructuring of each target is delegated to set_value.
    """
    result = evaluate_ast(assign.value, state, static_tools, custom_tools, authorized_imports)
    for target in assign.targets:
        set_value(target, result, state, static_tools, custom_tools, authorized_imports)
    return result
def set_value(
    target: ast.AST,
    value: Any,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Assign ``value`` to an assignment target.

    Supported targets:
    - plain names;
    - tuple and list destructuring, nested, with at most one ``*starred`` element,
      which receives a list as in Python;
    - subscripts (``obj[key] = value``);
    - attributes (``obj.attr = value``).

    Raises:
        InterpreterError: when the assignment would shadow a static tool; when
            unpacking a non-iterable or a ``str``/``bytes`` value (CPython would
            iterate strings); when the number of values does not fit the targets;
            or for an unsupported target type.
    """
    if isinstance(target, ast.Name):
        if target.id in static_tools:
            raise InterpreterError(f"Cannot assign to name '{target.id}': doing this would erase the existing tool!")
        state[target.id] = value
    elif isinstance(target, (ast.Tuple, ast.List)):
        if not isinstance(value, tuple):
            if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
                value = tuple(value)
            else:
                raise InterpreterError("Cannot unpack non-tuple value")
        elts = target.elts
        starred = [i for i, elt in enumerate(elts) if isinstance(elt, ast.Starred)]
        if len(starred) > 1:
            raise InterpreterError("Multiple starred expressions in assignment")
        if not starred:
            if len(elts) != len(value):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            for elt, item in zip(elts, value):
                set_value(elt, item, state, static_tools, custom_tools, authorized_imports)
            return
        # `a, *rest, z = value`: the starred target absorbs everything between the
        # fixed targets before and after it.
        star = starred[0]
        n_after = len(elts) - star - 1
        if len(value) < len(elts) - 1:
            raise InterpreterError("Cannot unpack tuple of wrong size")
        split = len(value) - n_after
        for elt, item in zip(elts[:star], value[:star]):
            set_value(elt, item, state, static_tools, custom_tools, authorized_imports)
        set_value(elts[star].value, list(value[star:split]), state, static_tools, custom_tools, authorized_imports)
        for elt, item in zip(elts[star + 1 :], value[split:]):
            set_value(elt, item, state, static_tools, custom_tools, authorized_imports)
    elif isinstance(target, ast.Subscript):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
        obj[key] = value
    elif isinstance(target, ast.Attribute):
        obj = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        setattr(obj, target.attr, value)
    else:
        raise InterpreterError(f"Cannot assign to target of type {type(target).__name__}.")
def evaluate_call(
    call: ast.Call,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate a function or method call.

    Callee resolution depends on its form:
    - attribute: ``obj.method(...)``;
    - bare name: looked up in ``state``, then ``static_tools``, then ``custom_tools``,
      then the builtin exception classes;
    - subscript: ``handlers[key](...)``.

    Both ``*iterable`` and ``**mapping`` argument unpacking are supported.

    Calls through the bare names ``print`` and ``super`` are special-cased. ``print``
    appends its text (honouring ``sep``/``end``) to ``state["_print_outputs"]`` and
    returns None. Methods that merely share those names are called normally.

    Raises:
        InterpreterError: if the callee cannot be resolved, a subscripted callee is not
            callable, or the callee is a builtin function that is not among
            ``static_tools``' values.
    """
    func_node = call.func
    if not isinstance(func_node, (ast.Attribute, ast.Name, ast.Subscript)):
        raise InterpreterError(f"This is not a correct function: {call.func}).")

    # Only bare-name calls get the print/super treatment below.
    is_bare_name = isinstance(func_node, ast.Name)

    if isinstance(func_node, ast.Attribute):
        obj = evaluate_ast(func_node.value, state, static_tools, custom_tools, authorized_imports)
        func_name = func_node.attr
        if not hasattr(obj, func_name):
            raise InterpreterError(f"Object {obj} has no attribute {func_name}")
        func = getattr(obj, func_name)
    elif is_bare_name:
        func_name = func_node.id
        if func_name in state:
            func = state[func_name]
        elif func_name in static_tools:
            func = static_tools[func_name]
        elif func_name in custom_tools:
            func = custom_tools[func_name]
        elif func_name in ERRORS:
            func = ERRORS[func_name]
        else:
            raise InterpreterError(
                f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
            )
    else:
        func = evaluate_subscript(func_node, state, static_tools, custom_tools, authorized_imports)
        if not callable(func):
            raise InterpreterError(f"This is not a correct function: {call.func}).")
        func_name = None

    args = []
    for arg in call.args:
        if isinstance(arg, ast.Starred):
            args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools, authorized_imports))
        else:
            args.append(evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports))

    kwargs = {}
    for keyword in call.keywords:
        keyword_value = evaluate_ast(keyword.value, state, static_tools, custom_tools, authorized_imports)
        if keyword.arg is None:
            # `f(**mapping)`: the ast stores the unpacked mapping as a keyword with arg=None.
            kwargs.update(keyword_value)
        else:
            kwargs[keyword.arg] = keyword_value

    if is_bare_name and func_name == "super":
        if not args:
            if "__class__" in state and "self" in state:
                return super(state["__class__"], state["self"])
            raise InterpreterError("super() needs at least one argument")
        cls = args[0]
        if not isinstance(cls, type):
            raise InterpreterError("super() argument 1 must be type")
        if len(args) == 1:
            return super(cls)
        if len(args) == 2:
            return super(cls, args[1])
        raise InterpreterError("super() takes at most 2 arguments")

    if is_bare_name and func_name == "print":
        sep = kwargs.get("sep")
        end = kwargs.get("end")
        text = (" " if sep is None else sep).join(map(str, args))
        state["_print_outputs"] += text + ("\n" if end is None else end)
        return None

    # Builtin functions are only callable when explicitly exposed as static tools.
    if inspect.getmodule(func) == builtins and inspect.isbuiltin(func) and func not in static_tools.values():
        raise InterpreterError(
            f"Invoking a builtin function that has not been explicitly added as a tool is not allowed ({func_name})."
        )
    return func(*args, **kwargs)
def evaluate_subscript(
    subscript: ast.Subscript,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Evaluate ``value[index]``, with friendlier errors than a bare ``KeyError``/``IndexError``.

    The container is evaluated before the index expression, matching Python's
    evaluation order.

    Handling by container type:
    - pandas ``.loc``/``.iloc`` indexers, DataFrames/Series, numpy arrays and
      DataFrameGroupBy objects are indexed directly, as are slices;
    - lists, tuples and strings are bounds-checked first;
    - any other container must contain ``index`` (``index in value``).

    Raises:
        InterpreterError: for a string indexed by a string, an out-of-bounds position,
            or a missing key. For string keys into a mapping, close matches are suggested.
    """
    value = evaluate_ast(subscript.value, state, static_tools, custom_tools, authorized_imports)
    index = evaluate_ast(subscript.slice, state, static_tools, custom_tools, authorized_imports)

    if isinstance(value, str) and isinstance(index, str):
        raise InterpreterError("You're trying to subscript a string with a string index, which is impossible")
    if isinstance(value, pd.core.indexing._LocIndexer):
        parent_object = value.obj
        return parent_object.loc[index]
    if isinstance(value, pd.core.indexing._iLocIndexer):
        parent_object = value.obj
        return parent_object.iloc[index]
    if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
        return value[index]
    elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
        return value[index]
    elif isinstance(index, slice):
        return value[index]
    elif isinstance(value, (list, tuple)):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for list of length {len(value)}")
        return value[int(index)]
    elif isinstance(value, str):
        if not (-len(value) <= index < len(value)):
            raise InterpreterError(f"Index {index} out of bounds for string of length {len(value)}")
        return value[index]
    elif index in value:
        return value[index]
    else:
        error_message = f"Could not index {value} with '{index}'."
        if isinstance(index, str) and isinstance(value, Mapping):
            close_matches = difflib.get_close_matches(index, list(value.keys()))
            if len(close_matches) > 0:
                error_message += f" Maybe you meant one of these indexes instead: {str(close_matches)}"
        raise InterpreterError(error_message)
def evaluate_name(
    name: ast.Name,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Resolve a bare name.

    Lookup order: ``state``, ``static_tools``, ``custom_tools``, then the builtin
    exception classes (ERRORS).

    Raises:
        InterpreterError: if the name is defined nowhere. Similarly spelled variables
            from ``state`` are suggested in the message. A misspelled name is never
            silently resolved to a different variable, since that would compute
            wrong results without any error.
    """
    identifier = name.id
    if identifier in state:
        return state[identifier]
    if identifier in static_tools:
        return static_tools[identifier]
    if identifier in custom_tools:
        return custom_tools[identifier]
    if identifier in ERRORS:
        return ERRORS[identifier]

    message = f"The variable `{identifier}` is not defined."
    close_matches = difflib.get_close_matches(identifier, list(state.keys()))
    if close_matches:
        message += f" Maybe you meant one of these variables instead: {close_matches}"
    raise InterpreterError(message)
def evaluate_condition(
    condition: ast.Compare,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> bool | object:
    """Evaluate a possibly chained comparison such as ``a < b <= c``.

    Operands are evaluated left to right, each right operand once. Evaluation stops
    with ``False`` as soon as one comparison yields exactly ``False``. The first
    comparison's result is kept as-is, so element-wise results (e.g. pandas/numpy)
    survive a single comparison. Later results are combined with ``and``.

    Raises:
        InterpreterError: for an unsupported comparison operator.
    """
    comparisons = {
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
        ast.Lt: lambda a, b: a < b,
        ast.LtE: lambda a, b: a <= b,
        ast.Gt: lambda a, b: a > b,
        ast.GtE: lambda a, b: a >= b,
        ast.Is: lambda a, b: a is b,
        ast.IsNot: lambda a, b: a is not b,
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
    }
    outcome = True
    lhs = evaluate_ast(condition.left, state, static_tools, custom_tools, authorized_imports)
    for position, (op_node, comparator) in enumerate(zip(condition.ops, condition.comparators)):
        op_type = type(op_node)
        rhs = evaluate_ast(comparator, state, static_tools, custom_tools, authorized_imports)
        compare = comparisons.get(op_type)
        if compare is None:
            raise InterpreterError(f"Unsupported comparison operator: {op_type}")
        step = compare(lhs, rhs)
        if step is False:
            return False
        outcome = step if position == 0 else (outcome and step)
        lhs = rhs
    return outcome
def evaluate_if(
    if_statement: ast.If,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Execute an ``if`` statement: its body when the test is truthy, else ``orelse``.

    Returns:
        The last non-None value produced by a statement of the executed branch, or None.
    """
    condition = evaluate_ast(if_statement.test, state, static_tools, custom_tools, authorized_imports)
    branch = if_statement.body if condition else if_statement.orelse
    last_value = None
    for statement in branch:
        produced = evaluate_ast(statement, state, static_tools, custom_tools, authorized_imports)
        if produced is not None:
            last_value = produced
    return last_value
def evaluate_for(
    for_loop: ast.For,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Any:
    """Execute a ``for`` loop, including its optional ``else`` clause.

    Each item is bound to the loop target with set_value, then the body runs.
    ``break``/``continue`` surface as BreakException / ContinueException, even from
    nested statements such as an ``if``. ``continue`` skips the remainder of the body
    for the current item. ``break`` leaves the loop and skips ``else``. The ``else``
    clause runs only when the iterable is exhausted.

    Returns:
        The last non-None value produced by a body (or ``else``) statement, or None.
    """
    result = None
    iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools, authorized_imports)
    for counter in iterator:
        set_value(
            for_loop.target,
            counter,
            state,
            static_tools,
            custom_tools,
            authorized_imports,
        )
        try:
            for node in for_loop.body:
                line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
                if line_result is not None:
                    result = line_result
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for node in for_loop.orelse:
            line_result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
            if line_result is not None:
                result = line_result
    return result
def evaluate_listcomp(
    listcomp: ast.ListComp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> List[Any]:
    """Evaluate a list comprehension, including several ``for`` clauses and ``if`` filters.

    Generators nest left to right. Each level iterates its ``iter`` expression in the
    scope built by the enclosing levels, on a shallow copy of the state, so loop
    variables never leak into ``state``.

    Targets are destructured like Python does:
    - nested tuples/lists such as ``for i, (k, v) in ...``;
    - one ``*starred`` element, which receives a list;
    - strings are iterated;
    - other target kinds go through set_value.

    Raises:
        InterpreterError: when the number of unpacked values does not fit the target.
    """

    def bind_target(target: ast.AST, value: Any, scope: Dict[str, Any]) -> None:
        # Bind one comprehension target into `scope`.
        if isinstance(target, ast.Name):
            scope[target.id] = value
            return
        if not isinstance(target, (ast.Tuple, ast.List)):
            set_value(target, value, scope, static_tools, custom_tools, authorized_imports)
            return
        items = tuple(value)
        elts = target.elts
        starred = [i for i, elt in enumerate(elts) if isinstance(elt, ast.Starred)]
        if len(starred) > 1:
            raise InterpreterError("Multiple starred expressions in assignment")
        if not starred:
            if len(items) != len(elts):
                raise InterpreterError("Cannot unpack tuple of wrong size")
            for elt, item in zip(elts, items):
                bind_target(elt, item, scope)
            return
        star = starred[0]
        n_after = len(elts) - star - 1
        if len(items) < len(elts) - 1:
            raise InterpreterError("Cannot unpack tuple of wrong size")
        split = len(items) - n_after
        for elt, item in zip(elts[:star], items[:star]):
            bind_target(elt, item, scope)
        bind_target(elts[star].value, list(items[star:split]), scope)
        for elt, item in zip(elts[star + 1 :], items[split:]):
            bind_target(elt, item, scope)

    def inner_evaluate(generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]) -> List[Any]:
        # Past the last generator: produce one element in the fully bound scope.
        if index >= len(generators):
            return [evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools, authorized_imports)]
        generator = generators[index]
        iter_value = evaluate_ast(generator.iter, current_state, static_tools, custom_tools, authorized_imports)
        result = []
        for value in iter_value:
            new_state = current_state.copy()
            bind_target(generator.target, value, new_state)
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in generator.ifs
            ):
                result.extend(inner_evaluate(generators, index + 1, new_state))
        return result

    return inner_evaluate(listcomp.generators, 0, state)
def evaluate_setcomp(
    setcomp: ast.SetComp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Set[Any]:
    """Evaluate a set comprehension, including several ``for`` clauses and ``if`` filters.

    Generators nest left to right, as in Python:
    - ``{(x, y) for x in a for y in b}`` pairs every x with every y;
    - an inner ``iter``, filter or the element can refer to outer loop variables.

    Each level binds its target with set_value into a shallow copy of the enclosing
    scope, so loop variables never leak into ``state``.
    """
    result = set()

    def walk(index: int, scope: Dict[str, Any]) -> None:
        # Past the last generator: add one element computed in the fully bound scope.
        if index == len(setcomp.generators):
            result.add(evaluate_ast(setcomp.elt, scope, static_tools, custom_tools, authorized_imports))
            return
        generator = setcomp.generators[index]
        iter_value = evaluate_ast(generator.iter, scope, static_tools, custom_tools, authorized_imports)
        for value in iter_value:
            inner_scope = scope.copy()
            set_value(generator.target, value, inner_scope, static_tools, custom_tools, authorized_imports)
            if all(
                evaluate_ast(if_clause, inner_scope, static_tools, custom_tools, authorized_imports)
                for if_clause in generator.ifs
            ):
                walk(index + 1, inner_scope)

    walk(0, state)
    return result
def evaluate_try(
    try_node: ast.Try,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Execute a ``try`` statement with its ``except``/``else``/``finally`` clauses.

    Handlers are tried in order. A bare ``except`` or a handler whose evaluated type
    matches the exception runs, optionally binding the exception to its name. An
    unmatched exception propagates. ``else`` runs only if the body completed, and
    ``finally`` always runs.

    The interpreter's own control-flow signals (BreakException, ContinueException,
    ReturnException) pass straight through: a ``return``/``break``/``continue`` inside
    ``try`` must not be caught by the script's ``except Exception:`` clauses.
    """
    try:
        for stmt in try_node.body:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    except (BreakException, ContinueException, ReturnException):
        raise
    except Exception as e:
        for handler in try_node.handlers:
            if handler.type is None or isinstance(
                e,
                evaluate_ast(handler.type, state, static_tools, custom_tools, authorized_imports),
            ):
                if handler.name:
                    state[handler.name] = e
                for stmt in handler.body:
                    evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
                break
        else:
            # No handler matched: propagate the original exception.
            raise
    else:
        for stmt in try_node.orelse:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
    finally:
        for stmt in try_node.finalbody:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
def evaluate_raise(
    raise_node: ast.Raise,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Execute a ``raise`` statement, including ``raise exc from cause``.

    The exception expression is evaluated first, then the cause.

    Raises:
        The evaluated exception (chained to the cause when one is given).
        InterpreterError: for a bare ``raise``, since re-raising is not supported.
    """

    def evaluate_optional(node):
        # Absent sub-expressions evaluate to None.
        if node is None:
            return None
        return evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)

    exc = evaluate_optional(raise_node.exc)
    cause = evaluate_optional(raise_node.cause)
    if exc is None:
        raise InterpreterError("Re-raise is not supported without an active exception")
    if cause is None:
        raise exc
    raise exc from cause
def evaluate_assert(
    assert_node: ast.Assert,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Execute an ``assert`` statement.

    Raises:
        AssertionError: when the test is falsy. The message is the evaluated ``msg``
            expression if one is given, otherwise the unparsed source of the test.
    """
    if evaluate_ast(assert_node.test, state, static_tools, custom_tools, authorized_imports):
        return
    if assert_node.msg:
        raise AssertionError(evaluate_ast(assert_node.msg, state, static_tools, custom_tools, authorized_imports))
    failing_source = ast.unparse(assert_node.test)
    raise AssertionError(f"Assertion failed: {failing_source}")
def evaluate_with(
    with_node: ast.With,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """Execute a ``with`` statement using the context-manager protocol.

    Managers are entered left to right, and each ``as`` target is bound through ``set_value``.
    The body then runs, and managers are exited in reverse order. An exception raised inside the
    block is passed to ``__exit__``, which may suppress it by returning a truthy value. The
    interpreter's control-flow signals (break/continue/return/final_answer) leave the block as a
    normal exit, i.e. ``__exit__(None, None, None)``.

    Args:
        with_node: The ``ast.With`` node to execute.
        state: Variable namespace; receives the ``as`` targets.
        static_tools: Non-overridable tools, forwarded to ``evaluate_ast`` / ``set_value``.
        custom_tools: Overridable tools, forwarded to ``evaluate_ast`` / ``set_value``.
        authorized_imports: Modules importable by the evaluated code.
    """
    _evaluate_with_items(with_node.items, with_node.body, state, static_tools, custom_tools, authorized_imports)


def _evaluate_with_items(items, body, state, static_tools, custom_tools, authorized_imports) -> None:
    """Enter ``items[0]``, recurse into the remaining items (or run ``body``), then exit it."""
    if not items:
        for stmt in body:
            evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
        return

    item, remaining = items[0], items[1:]
    # Keep a handle on the manager itself: __exit__ belongs to it, not to what __enter__ returned.
    manager = evaluate_ast(item.context_expr, state, static_tools, custom_tools, authorized_imports)
    entered = manager.__enter__()
    try:
        if item.optional_vars is not None:
            set_value(item.optional_vars, entered, state, static_tools, custom_tools, authorized_imports)
        _evaluate_with_items(remaining, body, state, static_tools, custom_tools, authorized_imports)
    except (BreakException, ContinueException, ReturnException, FinalAnswerException):
        # Control flow leaving the block is a normal exit for the context manager.
        manager.__exit__(None, None, None)
        raise
    except BaseException as e:
        # A truthy return value from __exit__ means the manager suppressed the exception.
        if not manager.__exit__(type(e), e, e.__traceback__):
            raise
    else:
        manager.__exit__(None, None, None)
def get_safe_module(raw_module, authorized_imports, visited=None):
    """Return a sanitized copy of ``raw_module``, or ``raw_module`` itself if it is not a module.

    For a module, a new ``ModuleType`` with the same name is built and attributes are copied by
    reference. An attribute is skipped when its name, or any component of the module's dotted
    name, is in ``DANGEROUS_PATTERNS`` and not authorized. Attributes whose lookup raises
    ``ImportError``/``AttributeError`` are also skipped. Nested modules are sanitized recursively.

    Args:
        raw_module: Object to sanitize.
        authorized_imports: Authorized module names, forwarded to ``check_module_authorized``.
        visited: Internal memo shared across the recursion, mapping ``id(module)`` to its
            sanitized copy. Callers should leave it as ``None``.

    Returns:
        The sanitized module copy, or ``raw_module`` unchanged for non-module objects.
    """
    # Functions, classes and plain values are returned as-is.
    if not isinstance(raw_module, ModuleType):
        return raw_module

    if visited is None:
        visited = {}
    module_id = id(raw_module)
    if module_id in visited:
        # Reuse the sanitized copy (possibly still being filled for a circular reference).
        # Returning `raw_module` here would expose an unfiltered module whenever the same module
        # is reachable through a second attribute path.
        return visited[module_id]

    safe_module = ModuleType(raw_module.__name__)
    # Register before recursing so that cycles resolve to this sanitized copy.
    visited[module_id] = safe_module

    module_name_parts = raw_module.__name__.split(".")
    for attr_name in dir(raw_module):
        # Skip dangerous patterns at any level of the dotted path, or in the attribute name itself.
        candidate_names = module_name_parts + [attr_name]
        if any(
            pattern in candidate_names and not check_module_authorized(pattern, authorized_imports)
            for pattern in DANGEROUS_PATTERNS
        ):
            logger.info("Skipping dangerous attribute %s.%s", raw_module.__name__, attr_name)
            continue
        try:
            attr_value = getattr(raw_module, attr_name)
        except (ImportError, AttributeError) as e:
            # Lazy / dynamically loaded attribute that fails to resolve: log and skip it.
            logger.info(
                "Skipping import error while copying %s.%s: %s - %s",
                raw_module.__name__,
                attr_name,
                type(e).__name__,
                e,
            )
            continue
        if isinstance(attr_value, ModuleType):
            attr_value = get_safe_module(attr_value, authorized_imports, visited=visited)
        setattr(safe_module, attr_name, attr_value)
    return safe_module
def check_module_authorized(module_name, authorized_imports):
    """Tell whether ``module_name`` may be imported under ``authorized_imports``.

    A ``"*"`` entry authorizes everything. Otherwise the import is refused if any dotted
    component of the name is in ``DANGEROUS_PATTERNS`` without being explicitly listed. If not
    refused, the import is allowed when the name or one of its parent packages is listed
    (``"a.b.c"`` matches ``"a"``, ``"a.b"`` or ``"a.b.c"``).
    """
    if "*" in authorized_imports:
        return True

    parts = module_name.split(".")
    if any(part in DANGEROUS_PATTERNS and part not in authorized_imports for part in parts):
        return False

    prefixes = (".".join(parts[:depth]) for depth in range(1, len(parts) + 1))
    return any(prefix in authorized_imports for prefix in prefixes)
def import_modules(expression, state, authorized_imports):
    """Execute an ``import`` / ``from ... import`` statement inside the sandbox.

    Modules are imported only if ``check_module_authorized`` accepts them. They are bound into
    ``state`` as sanitized copies produced by ``get_safe_module``.

    Args:
        expression: An ``ast.Import`` or ``ast.ImportFrom`` node.
        state: Namespace dict that receives the imported names.
        authorized_imports: Module names that may be imported (``"*"`` allows everything).

    Returns:
        None.

    Raises:
        InterpreterError: If the import is unauthorized or relative, or if a requested name is
            missing from the sanitized module.
    """
    if isinstance(expression, ast.Import):
        for alias in expression.names:
            if not check_module_authorized(alias.name, authorized_imports):
                raise InterpreterError(
                    f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
                )
            raw_module = import_module(alias.name)
            state[alias.asname or alias.name] = get_safe_module(raw_module, authorized_imports)
        return None

    if isinstance(expression, ast.ImportFrom):
        # `from . import x` has module=None; `from .m import x` would otherwise be resolved as the
        # unrelated absolute module `m`. The sandbox has no parent package, so refuse both.
        if expression.level or expression.module is None:
            relative_target = "." * (expression.level or 0) + (expression.module or "")
            raise InterpreterError(f"Relative import from '{relative_target}' is not supported")
        if not check_module_authorized(expression.module, authorized_imports):
            raise InterpreterError(
                f"Import from {expression.module} is not allowed. Authorized imports are: {str(authorized_imports)}"
            )
        raw_module = __import__(expression.module, fromlist=[alias.name for alias in expression.names])
        module = get_safe_module(raw_module, authorized_imports)
        if expression.names[0].name == "*":  # Handle "from module import *"
            if hasattr(module, "__all__"):  # If module has __all__, import only those names
                public_names = module.__all__
            else:  # If no __all__, import all public names (those not starting with '_')
                public_names = [name for name in dir(module) if not name.startswith("_")]
            for name in public_names:
                # __all__ may list names that get_safe_module stripped as dangerous: skip them
                # instead of crashing the whole import.
                if hasattr(module, name):
                    state[name] = getattr(module, name)
        else:  # regular from imports
            for alias in expression.names:
                if not hasattr(module, alias.name):
                    raise InterpreterError(f"Module {expression.module} has no attribute {alias.name}")
                state[alias.asname or alias.name] = getattr(module, alias.name)
        return None

    return None
def evaluate_dictcomp(
    dictcomp: ast.DictComp,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> Dict[Any, Any]:
    """Evaluate a dict comprehension such as ``{k: v for x in xs for k, v in f(x) if cond}``.

    As in Python, ``for`` clauses are nested left to right. Each clause is iterated inside the
    scope produced by the previous one. Key and value are evaluated (key first) only once every
    clause has bound its target and passed its ``if`` filters. Targets are bound in copies of
    ``state``, so loop variables do not leak into the caller's namespace.

    Args:
        dictcomp: The ``ast.DictComp`` node to evaluate.
        state: Enclosing variable namespace (read, not modified by target binding).
        static_tools: Non-overridable tools, forwarded to ``evaluate_ast`` / ``set_value``.
        custom_tools: Overridable tools, forwarded to ``evaluate_ast`` / ``set_value``.
        authorized_imports: Modules importable by the evaluated code.

    Returns:
        The resulting dictionary.
    """
    result: Dict[Any, Any] = {}
    generators = dictcomp.generators

    def _fill(index: int, scope: Dict[str, Any]) -> None:
        # Every clause is bound: emit one key/value pair.
        if index == len(generators):
            key = evaluate_ast(dictcomp.key, scope, static_tools, custom_tools, authorized_imports)
            result[key] = evaluate_ast(dictcomp.value, scope, static_tools, custom_tools, authorized_imports)
            return
        gen = generators[index]
        # Later clauses may depend on names bound by earlier ones, hence `scope`, not `state`.
        iter_value = evaluate_ast(gen.iter, scope, static_tools, custom_tools, authorized_imports)
        for value in iter_value:
            new_state = scope.copy()
            set_value(gen.target, value, new_state, static_tools, custom_tools, authorized_imports)
            if all(
                evaluate_ast(if_clause, new_state, static_tools, custom_tools, authorized_imports)
                for if_clause in gen.ifs
            ):
                _fill(index + 1, new_state)

    _fill(0, state)
    return result
def evaluate_delete(
    delete_node: ast.Delete,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str],
) -> None:
    """
    Execute a ``del`` statement, handling plain names (``del x``) and subscripts (``del x[y]``).

    Targets are processed left to right. A name is removed from ``state``. For a subscript, the
    container and then the index/key are evaluated, and the entry is deleted from the container.

    Args:
        delete_node: The AST Delete node to execute
        state: The current state dictionary
        static_tools: Dictionary of static tools
        custom_tools: Dictionary of custom tools
        authorized_imports: List of authorized imports

    Raises:
        InterpreterError: For an undefined name, a failing subscript deletion, or an unsupported
            target kind.
    """
    for target in delete_node.targets:
        if isinstance(target, ast.Name):
            if target.id not in state:
                raise InterpreterError(f"Cannot delete name '{target.id}': name is not defined")
            del state[target.id]
            continue

        if not isinstance(target, ast.Subscript):
            raise InterpreterError(f"Deletion of {type(target).__name__} targets is not supported")

        container = evaluate_ast(target.value, state, static_tools, custom_tools, authorized_imports)
        key = evaluate_ast(target.slice, state, static_tools, custom_tools, authorized_imports)
        try:
            del container[key]
        except (TypeError, KeyError, IndexError) as e:
            raise InterpreterError(f"Cannot delete index/key: {str(e)}")
# f-string conversion flags (`!s`, `!r`, `!a`) as stored in `ast.FormattedValue.conversion`;
# -1 means "no conversion" and is deliberately absent.
_FSTRING_CONVERSIONS = {ord("s"): str, ord("r"): repr, ord("a"): ascii}


def _evaluate_elements(elts, common_params) -> List[Any]:
    """Evaluate the elements of a list/tuple/set display, splicing in ``*iterable`` entries."""
    values: List[Any] = []
    for elt in elts:
        if isinstance(elt, ast.Starred):
            # evaluate_ast's Starred branch returns the iterable itself; unpack it here.
            values.extend(evaluate_ast(elt, *common_params))
        else:
            values.append(evaluate_ast(elt, *common_params))
    return values


def evaluate_ast(
    expression: ast.AST,
    state: Dict[str, Any],
    static_tools: Dict[str, Callable],
    custom_tools: Dict[str, Callable],
    authorized_imports: List[str] = BASE_BUILTIN_MODULES,
):
    """
    Evaluate an abstract syntax tree using the content of the variables stored in a state and only evaluating a given
    set of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        expression (`ast.AST`):
            The code to evaluate, as an abstract syntax tree.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` is updated if need be when the evaluation
            encounters assignments.
        static_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. Trying to change one of these static_tools will raise an error.
        custom_tools (`Dict[str, Callable]`):
            Functions that may be called during the evaluation. These custom_tools can be overwritten.
        authorized_imports (`List[str]`):
            The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
            If it contains "*", it will authorize any import. Use this at your own risk!

    Returns:
        The value the node evaluates to (statement handlers may return None).

    Raises:
        InterpreterError: When the operation budget `MAX_OPERATIONS` is exhausted or the node type is unsupported.
        BreakException / ContinueException / ReturnException: For `break`, `continue` and `return` statements.
    """
    if state.setdefault("_operations_count", 0) >= MAX_OPERATIONS:
        raise InterpreterError(
            f"Reached the max number of operations of {MAX_OPERATIONS}. Maybe there is an infinite loop somewhere in the code, or you're just asking too many calculations."
        )
    state["_operations_count"] += 1
    common_params = (state, static_tools, custom_tools, authorized_imports)
    if isinstance(expression, ast.Assign):
        # Assignment -> we evaluate the assignment which should update the state
        # We return the variable assigned as it may be used to determine the final result.
        return evaluate_assign(expression, *common_params)
    elif isinstance(expression, ast.AugAssign):
        return evaluate_augassign(expression, *common_params)
    elif isinstance(expression, ast.Call):
        # Function call -> we return the value of the function call
        return evaluate_call(expression, *common_params)
    elif isinstance(expression, ast.Constant):
        # Constant -> just return the value
        return expression.value
    elif isinstance(expression, ast.Tuple):
        # Tuple display -> evaluate elements, unpacking `*iterable` entries
        return tuple(_evaluate_elements(expression.elts, common_params))
    elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
        return evaluate_listcomp(expression, *common_params)
    elif isinstance(expression, ast.DictComp):
        return evaluate_dictcomp(expression, *common_params)
    elif isinstance(expression, ast.SetComp):
        return evaluate_setcomp(expression, *common_params)
    elif isinstance(expression, ast.UnaryOp):
        return evaluate_unaryop(expression, *common_params)
    elif isinstance(expression, ast.Starred):
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.BoolOp):
        # Boolean operation -> evaluate the operation
        return evaluate_boolop(expression, *common_params)
    elif isinstance(expression, ast.Break):
        raise BreakException()
    elif isinstance(expression, ast.Continue):
        raise ContinueException()
    elif isinstance(expression, ast.BinOp):
        # Binary operation -> execute operation
        return evaluate_binop(expression, *common_params)
    elif isinstance(expression, ast.Compare):
        # Comparison -> evaluate the comparison
        return evaluate_condition(expression, *common_params)
    elif isinstance(expression, ast.Lambda):
        return evaluate_lambda(expression, *common_params)
    elif isinstance(expression, ast.FunctionDef):
        return evaluate_function_def(expression, *common_params)
    elif isinstance(expression, ast.Dict):
        # Dict display -> evaluate key then value for each entry, in source order.
        # A `None` key marks a `**mapping` entry whose items are merged in.
        result = {}
        for key_node, value_node in zip(expression.keys, expression.values):
            if key_node is None:
                result.update(evaluate_ast(value_node, *common_params))
            else:
                key = evaluate_ast(key_node, *common_params)
                result[key] = evaluate_ast(value_node, *common_params)
        return result
    elif isinstance(expression, ast.Expr):
        # Expression -> evaluate the content
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.For):
        # For loop -> execute the loop
        return evaluate_for(expression, *common_params)
    elif isinstance(expression, ast.FormattedValue):
        # Formatted value (part of f-string) -> apply `!s`/`!r`/`!a`, then the format spec,
        # exactly like Python's `format(conversion(value), spec)`.
        value = evaluate_ast(expression.value, *common_params)
        conversion = _FSTRING_CONVERSIONS.get(expression.conversion)
        if conversion is not None:
            value = conversion(value)
        format_spec = evaluate_ast(expression.format_spec, *common_params) if expression.format_spec else ""
        return format(value, format_spec)
    elif isinstance(expression, ast.If):
        # If -> execute the right branch
        return evaluate_if(expression, *common_params)
    elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
        return evaluate_ast(expression.value, *common_params)
    elif isinstance(expression, ast.JoinedStr):
        return "".join([str(evaluate_ast(v, *common_params)) for v in expression.values])
    elif isinstance(expression, ast.List):
        # List -> evaluate all elements, unpacking `*iterable` entries
        return _evaluate_elements(expression.elts, common_params)
    elif isinstance(expression, ast.Name):
        # Name -> pick up the value in the state
        return evaluate_name(expression, *common_params)
    elif isinstance(expression, ast.Subscript):
        # Subscript -> return the value of the indexing
        return evaluate_subscript(expression, *common_params)
    elif isinstance(expression, ast.IfExp):
        test_val = evaluate_ast(expression.test, *common_params)
        if test_val:
            return evaluate_ast(expression.body, *common_params)
        else:
            return evaluate_ast(expression.orelse, *common_params)
    elif isinstance(expression, ast.Attribute):
        value = evaluate_ast(expression.value, *common_params)
        return getattr(value, expression.attr)
    elif isinstance(expression, ast.Slice):
        return slice(
            evaluate_ast(expression.lower, *common_params) if expression.lower is not None else None,
            evaluate_ast(expression.upper, *common_params) if expression.upper is not None else None,
            evaluate_ast(expression.step, *common_params) if expression.step is not None else None,
        )
    elif isinstance(expression, ast.While):
        return evaluate_while(expression, *common_params)
    elif isinstance(expression, (ast.Import, ast.ImportFrom)):
        return import_modules(expression, state, authorized_imports)
    elif isinstance(expression, ast.ClassDef):
        return evaluate_class_def(expression, *common_params)
    elif isinstance(expression, ast.Try):
        return evaluate_try(expression, *common_params)
    elif isinstance(expression, ast.Raise):
        return evaluate_raise(expression, *common_params)
    elif isinstance(expression, ast.Assert):
        return evaluate_assert(expression, *common_params)
    elif isinstance(expression, ast.With):
        return evaluate_with(expression, *common_params)
    elif isinstance(expression, ast.Set):
        # Set display -> evaluate elements, unpacking `*iterable` entries
        return set(_evaluate_elements(expression.elts, common_params))
    elif isinstance(expression, ast.Return):
        raise ReturnException(evaluate_ast(expression.value, *common_params) if expression.value else None)
    elif isinstance(expression, ast.Pass):
        return None
    elif isinstance(expression, ast.Delete):
        return evaluate_delete(expression, *common_params)
    else:
        # For now we refuse anything else. Let's add things as we need them.
        raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
class FinalAnswerException(Exception):
    """Raised by the wrapped ``final_answer`` tool to stop execution and carry the final answer.

    Attributes:
        value: The value returned by the underlying ``final_answer`` tool.
    """

    def __init__(self, value):
        # Forward `value` to Exception so `args` is populated: unpickling calls `cls(*args)`, which
        # would otherwise fail with a missing-argument TypeError, and repr/str show the answer.
        super().__init__(value)
        self.value = value
def evaluate_python_code(
    code: str,
    static_tools: Optional[Dict[str, Callable]] = None,
    custom_tools: Optional[Dict[str, Callable]] = None,
    state: Optional[Dict[str, Any]] = None,
    authorized_imports: List[str] = BASE_BUILTIN_MODULES,
    max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
):
    """
    Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
    of functions.

    This function will recurse through the nodes of the tree provided.

    Args:
        code (`str`):
            The code to evaluate.
        static_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation. These can also be agents in a multiagent setting.
            These tools cannot be overwritten in the code: any assignment to their name will raise an error.
        custom_tools (`Dict[str, Callable]`):
            The functions that may be called during the evaluation.
            These tools can be overwritten in the code: any assignment to their name will overwrite them.
        state (`Dict[str, Any]`):
            A dictionary mapping variable names to values. The `state` should contain the initial inputs but will be
            updated by this function to contain all variables as they are evaluated.
            The print outputs will be stored in the state under the key "_print_outputs".
        authorized_imports (`List[str]`):
            Modules the code may import ("*" authorizes any import).
        max_print_outputs_length (`int`):
            Maximum length the captured print outputs are truncated to.

    Returns:
        `(result, is_final_answer)`: the value of the last top-level statement and `False`, or the value passed to
        `final_answer` and `True` if the code called it.

    Raises:
        InterpreterError: If the code fails to parse or raises during execution.
    """
    try:
        expression = ast.parse(code)
    except SyntaxError as e:
        # `text` may be None or lack its trailing newline; `offset` is 1-based, so the caret needs offset - 1 spaces.
        source_line = (e.text or "").rstrip("\r\n")
        caret_indent = " " * max((e.offset or 1) - 1, 0)
        raise InterpreterError(
            f"Code parsing failed on line {e.lineno} due to: {type(e).__name__}\n"
            f"{source_line}\n"
            f"{caret_indent}^\n"
            f"Error: {str(e)}"
        ) from e

    if state is None:
        state = {}
    # static_tools is copied so the final_answer wrapper below never leaks into the caller's dict.
    static_tools = static_tools.copy() if static_tools is not None else {}
    custom_tools = custom_tools if custom_tools is not None else {}
    result = None
    state["_print_outputs"] = PrintContainer()

    if "final_answer" in static_tools:
        previous_final_answer = static_tools["final_answer"]

        def final_answer(*args, **kwargs):
            # Forward any call shape (e.g. `final_answer(answer=...)`) and stop execution with its result.
            raise FinalAnswerException(previous_final_answer(*args, **kwargs))

        static_tools["final_answer"] = final_answer

    try:
        for node in expression.body:
            result = evaluate_ast(node, state, static_tools, custom_tools, authorized_imports)
    except FinalAnswerException as e:
        return e.value, True
    except Exception as e:
        raise InterpreterError(
            f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {type(e).__name__}: {e}"
        ) from e
    finally:
        # Whatever the outcome, leave the captured print outputs bounded in size.
        state["_print_outputs"].value = truncate_content(
            str(state["_print_outputs"]), max_length=max_print_outputs_length
        )
    return result, False
class PythonExecutor:
    """Base class for Python code executors; ``LocalPythonExecutor`` derives from it."""
class LocalPythonExecutor(PythonExecutor):
    """Run code snippets in-process with ``evaluate_python_code``, keeping state across calls.

    Variables produced by one call are kept in ``self.state`` and are visible to the next call.

    Args:
        additional_authorized_imports: Modules allowed on top of ``BASE_BUILTIN_MODULES``.
        max_print_outputs_length: Maximum length of the captured print outputs; ``None`` selects
            ``DEFAULT_MAX_LEN_OUTPUT``.
    """

    def __init__(
        self,
        additional_authorized_imports: List[str],
        max_print_outputs_length: Optional[int] = None,
    ):
        self.custom_tools = {}
        self.state = {}
        self.max_print_outputs_length = (
            DEFAULT_MAX_LEN_OUTPUT if max_print_outputs_length is None else max_print_outputs_length
        )
        self.additional_authorized_imports = additional_authorized_imports
        base_modules = set(BASE_BUILTIN_MODULES)
        self.authorized_imports = list(base_modules | set(self.additional_authorized_imports))
        # TODO: assert self.authorized imports are all installed locally
        # Populated by `send_tools`; evaluate_python_code treats None as "no static tools".
        self.static_tools = None

    def __call__(self, code_action: str) -> Tuple[Any, str, bool]:
        """Execute ``code_action`` and return ``(output, logs, is_final_answer)``."""
        output, is_final_answer = evaluate_python_code(
            code_action,
            static_tools=self.static_tools,
            custom_tools=self.custom_tools,
            state=self.state,
            authorized_imports=self.authorized_imports,
            max_print_outputs_length=self.max_print_outputs_length,
        )
        return output, str(self.state["_print_outputs"]), is_final_answer

    def send_variables(self, variables: dict):
        """Merge ``variables`` into the persistent execution state."""
        self.state.update(variables)

    def send_tools(self, tools: Dict[str, Tool]):
        """Install ``tools`` as static tools; BASE_PYTHON_TOOLS entries win on name clashes."""
        merged_tools = dict(tools)
        merged_tools.update(BASE_PYTHON_TOOLS)
        self.static_tools = merged_tools
# Public API of this module.
__all__ = [
    "evaluate_python_code",
    "LocalPythonExecutor",
]