Spaces:
Sleeping
Sleeping
| import ast | |
| import builtins | |
| from itertools import zip_longest | |
| from typing import Set | |
| from .utils import BASE_BUILTIN_MODULES, get_source | |
# Every name exposed by the `builtins` module (functions, exceptions, constants);
# such names are always treated as defined when checking method bodies.
_BUILTIN_NAMES = set(builtins.__dict__)
class MethodChecker(ast.NodeVisitor):
    """
    Checks that a method only uses defined names.

    A name is considered defined when it is a builtin, one of the base builtin
    modules, a parameter, `self`, a class attribute, an imported name, a name
    bound inside the method (assignment, loop/comprehension target, `with`/`except`
    alias, walrus target, nested function) or one of `typing_names`.

    Each undefined name is recorded once in `undefined_names` and reported once
    in `errors` as "Name '<name>' is undefined.".

    Note: `check_imports` is stored on the instance but is not consulted by this
    class; imports are only recorded in `imports` / `from_imports`.

    Args:
        class_attributes: Names defined at class level, treated as defined.
        check_imports: Flag kept on the instance for callers; not enforced here.
    """

    def __init__(self, class_attributes: Set[str], check_imports: bool = True):
        self.undefined_names = set()
        self.imports = {}
        self.from_imports = {}
        self.assigned_names = set()
        self.arg_names = set()
        self.class_attributes = class_attributes
        self.errors = []
        self.check_imports = check_imports
        self.typing_names = {"Any"}
        # Nesting depth of function definitions currently being visited; used to
        # tell the checked method itself (depth 0) from functions nested in it.
        self._function_depth = 0

    def _bind_target(self, target):
        """Register every plain name bound by an assignment-like target."""
        if isinstance(target, ast.Name):
            self.assigned_names.add(target.id)
        elif isinstance(target, (ast.Tuple, ast.List)):
            # Unpacking targets, possibly nested: `a, (b, c) = ...`
            for element in target.elts:
                self._bind_target(element)
        elif isinstance(target, ast.Starred):
            # `head, *rest = ...`
            self._bind_target(target.value)

    def _is_defined(self, name):
        """Return True when `name` resolves to something known to the checker."""
        return (
            name == "self"
            or name in _BUILTIN_NAMES
            or name in BASE_BUILTIN_MODULES
            or name in self.arg_names
            or name in self.class_attributes
            or name in self.imports
            or name in self.from_imports
            or name in self.assigned_names
            or name in self.typing_names
        )

    def visit_FunctionDef(self, node):
        """Register nested function names, then check the function body."""
        # The method under inspection (depth 0) is not a bare name inside its own
        # body, so only functions nested within it are registered.
        if self._function_depth > 0:
            self.assigned_names.add(node.name)
        self._function_depth += 1
        try:
            self.generic_visit(node)
        finally:
            self._function_depth -= 1

    def visit_arguments(self, node):
        """Collect function arguments"""
        # Accumulate instead of overwriting: the arguments of a lambda or nested
        # function must not discard the parameters of the enclosing method.
        positional_only = getattr(node, "posonlyargs", [])
        for arg in (*positional_only, *node.args, *node.kwonlyargs):
            self.arg_names.add(arg.arg)
        if node.kwarg:
            self.arg_names.add(node.kwarg.arg)
        if node.vararg:
            self.arg_names.add(node.vararg.arg)

    def visit_Import(self, node):
        """Record `import x` / `import x as y` under the name visible in the method."""
        for alias in node.names:
            self.imports[alias.asname or alias.name] = alias.name

    def visit_ImportFrom(self, node):
        """Record `from m import x` / `... as y` under the name visible in the method."""
        module = node.module or ""
        for alias in node.names:
            self.from_imports[alias.asname or alias.name] = (module, alias.name)

    def visit_Assign(self, node):
        """Track assigned names (including unpacking) and check the assigned value."""
        for target in node.targets:
            self._bind_target(target)
        self.visit(node.value)

    def visit_NamedExpr(self, node):
        """Track the target of a walrus expression (`name := value`)."""
        self._bind_target(node.target)
        self.visit(node.value)

    def visit_With(self, node):
        """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
        for item in node.items:
            if item.optional_vars:
                self._bind_target(item.optional_vars)
        self.generic_visit(node)

    def visit_ExceptHandler(self, node):
        """Track exception aliases (the 'e' in 'except Exception as e')"""
        if node.name:
            self.assigned_names.add(node.name)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        """Track annotated assignments."""
        if isinstance(node.target, ast.Name):
            self.assigned_names.add(node.target.id)
        if node.value:
            self.visit(node.value)

    def visit_For(self, node):
        """Track loop variables, then check the iterable and the loop body."""
        self._bind_target(node.target)
        self.generic_visit(node)

    def _handle_comprehension_generators(self, generators):
        """Helper method to handle generators in all types of comprehensions"""
        # Targets are registered up front because the element expression is
        # visited before the generators during traversal.
        for generator in generators:
            self._bind_target(generator.target)

    def visit_ListComp(self, node):
        """Track variables in list comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_DictComp(self, node):
        """Track variables in dictionary comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_SetComp(self, node):
        """Track variables in set comprehensions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_GeneratorExp(self, node):
        """Track variables in generator expressions"""
        self._handle_comprehension_generators(node.generators)
        self.generic_visit(node)

    def visit_Attribute(self, node):
        """Skip `self.<attr>` accesses; check the base of any other attribute access."""
        is_self_access = isinstance(node.value, ast.Name) and node.value.id == "self"
        if not is_self_access:
            self.generic_visit(node)

    def visit_Name(self, node):
        """Report a name that is read but not defined (once per distinct name)."""
        if not isinstance(node.ctx, ast.Load):
            return
        name = node.id
        if self._is_defined(name) or name in self.undefined_names:
            return
        self.undefined_names.add(name)
        self.errors.append(f"Name '{name}' is undefined.")

    def visit_Call(self, node):
        """Check the callee and arguments of a call."""
        # A bare-name callee is an ast.Name in Load context, so visit_Name already
        # validates it; checking it here as well would report it twice.
        self.generic_visit(node)
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validates that a Tool class follows the proper patterns:
    0. Any argument of __init__ should have a default.
    Args chosen at init are not traceable, so we cannot rebuild the source code for them, thus any important arg should be defined as a class attribute.
    1. About the class:
    - Class attributes should only be simple literals (constants, dicts, lists, sets)
    - Class attributes cannot be complex attributes
    2. About all class methods:
    - Imports must be from packages, not local files
    - All methods must be self-contained

    Args:
        cls: The class to validate; its source is obtained with `get_source(cls)`.
        check_imports: Forwarded unchanged to each `MethodChecker`.

    Raises:
        ValueError: If the source does not start with a class definition, or if any
            validation error is found (all errors are gathered into one message).

    Returns None when no error is found.
    """
    # AST node types accepted as simple literal values. Only ast.Constant is listed
    # for scalars: ast.Str / ast.Num are deprecated aliases of it and are no longer
    # available on interpreters where they have been removed.
    literal_nodes = (ast.Constant, ast.Dict, ast.List, ast.Set)

    class ClassLevelChecker(ast.NodeVisitor):
        """Collects class attributes and inspects the parameters of `__init__`."""

        def __init__(self):
            self.imported_names = set()
            self.complex_attributes = set()
            self.class_attributes = set()
            self.non_defaults = set()
            self.non_literal_defaults = set()
            self.in_method = False

        def visit_FunctionDef(self, node):
            if node.name == "__init__":
                self._check_init_function_parameters(node)
            # Assignments inside a method body are not class attributes.
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        def visit_Assign(self, node):
            if self.in_method:
                return
            # Track class attributes
            attribute_names = [target.id for target in node.targets if isinstance(target, ast.Name)]
            self.class_attributes.update(attribute_names)
            # Check if the assignment is more complex than simple literals.
            # ast.walk also yields expression-context nodes (the Load attached to
            # every list); they carry no value and must not count as complex.
            is_simple_literal = all(
                isinstance(val, literal_nodes)
                for val in ast.walk(node.value)
                if not isinstance(val, ast.expr_context)
            )
            if not is_simple_literal:
                self.complex_attributes.update(attribute_names)

        def _record_parameter(self, arg, default):
            """Classify one `__init__` parameter given its default node (None if absent)."""
            if default is None:
                if arg.arg != "self":
                    self.non_defaults.add(arg.arg)
            elif not isinstance(default, literal_nodes):
                self.non_literal_defaults.add(arg.arg)

        def _check_init_function_parameters(self, node):
            # Check defaults in parameters.
            # `defaults` is right-aligned over positional-only + regular parameters.
            positional = [*getattr(node.args, "posonlyargs", []), *node.args.args]
            defaults = node.args.defaults
            first_with_default = len(positional) - len(defaults)
            for index, arg in enumerate(positional):
                default = defaults[index - first_with_default] if index >= first_with_default else None
                self._record_parameter(arg, default)
            # `kw_defaults` is aligned with `kwonlyargs`; None marks a required parameter.
            for arg, default in zip(node.args.kwonlyargs, node.args.kw_defaults):
                self._record_parameter(arg, default)

    class_level_checker = ClassLevelChecker()
    source = get_source(cls)
    tree = ast.parse(source)
    # Guard against empty source before indexing the module body.
    if not tree.body or not isinstance(tree.body[0], ast.ClassDef):
        raise ValueError("Source code must define a class")
    class_node = tree.body[0]
    class_level_checker.visit(class_node)

    # Names are sorted so that the error message is deterministic across runs.
    errors = []
    if class_level_checker.complex_attributes:
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(sorted(class_level_checker.complex_attributes))}"
        )
    if class_level_checker.non_defaults:
        errors.append(
            f"Parameters in __init__ must have default values, found required parameters: "
            f"{', '.join(sorted(class_level_checker.non_defaults))}"
        )
    if class_level_checker.non_literal_defaults:
        errors.append(
            f"Parameters in __init__ must have literal default values, found non-literal defaults: "
            f"{', '.join(sorted(class_level_checker.non_literal_defaults))}"
        )

    # Run checks on all methods
    for node in class_node.body:
        if isinstance(node, ast.FunctionDef):
            method_checker = MethodChecker(class_level_checker.class_attributes, check_imports=check_imports)
            method_checker.visit(node)
            errors.extend(f"- {node.name}: {error}" for error in method_checker.errors)

    if errors:
        raise ValueError(f"Tool validation failed for {cls.__name__}:\n" + "\n".join(errors))