Spaces:
Running
Running
| import ast | |
| import contextlib | |
| import inspect | |
| import traceback | |
| from itertools import starmap | |
| from pathlib import Path | |
| from typing import Any | |
| from cachetools import TTLCache, keys | |
| from fastapi import HTTPException | |
| from loguru import logger | |
| from langflow.custom.eval import eval_custom_component_code | |
| from langflow.custom.schema import CallableCodeDetails, ClassCodeDetails, MissingDefault | |
class CodeSyntaxError(HTTPException):
    """HTTP error signalling that submitted source code could not be parsed."""
def get_data_type():
    """Return the ``Data`` type exported by ``langflow.field_typing``.

    The import happens at call time instead of at module import time
    (presumably to avoid an import cycle -- TODO confirm).
    """
    from langflow.field_typing import Data as data_type

    return data_type
def find_class_ast_node(class_obj):
    """Locate the AST definition of ``class_obj`` in the file that defines it.

    Returns:
        A ``(class_node, import_nodes)`` tuple. ``class_node`` is the last
        ``ast.ClassDef`` visited whose name equals ``class_obj.__name__`` (or
        ``None`` when no such node exists) and ``import_nodes`` lists every
        ``import`` / ``from ... import`` node found anywhere in that file.
        ``(None, [])`` is returned when no source file can be determined.
    """
    path = inspect.getsourcefile(class_obj)
    if not path:
        return None, []

    # Parse the whole defining file so its imports can be collected too.
    module_ast = ast.parse(Path(path).read_text(encoding="utf-8"))
    wanted_name = class_obj.__name__

    found = None
    imports = []
    for candidate in ast.walk(module_ast):
        if isinstance(candidate, ast.ClassDef):
            if candidate.name == wanted_name:
                found = candidate
        elif isinstance(candidate, (ast.Import, ast.ImportFrom)):
            imports.append(candidate)
    return found, imports
def imports_key(*args, **kwargs):
    """Build a cache key that also accounts for the ``imports`` keyword argument.

    ``imports`` is removed from ``kwargs`` (a ``KeyError`` is raised when it is
    absent), the remaining arguments are passed to ``keys.methodkey``, and the
    items of ``imports`` are appended to the resulting key.
    """
    extra = tuple(kwargs.pop("imports"))
    return keys.methodkey(*args, **kwargs) + extra
| class CodeParser: | |
| """A parser for Python source code, extracting code details.""" | |
| def __init__(self, code: str | type) -> None: | |
| """Initializes the parser with the provided code.""" | |
| self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60) | |
| if isinstance(code, type): | |
| if not inspect.isclass(code): | |
| msg = "The provided code must be a class." | |
| raise ValueError(msg) | |
| # If the code is a class, get its source code | |
| code = inspect.getsource(code) | |
| self.code = code | |
| self.data: dict[str, Any] = { | |
| "imports": [], | |
| "functions": [], | |
| "classes": [], | |
| "global_vars": [], | |
| } | |
| self.handlers = { | |
| ast.Import: self.parse_imports, | |
| ast.ImportFrom: self.parse_imports, | |
| ast.FunctionDef: self.parse_functions, | |
| ast.ClassDef: self.parse_classes, | |
| ast.Assign: self.parse_global_vars, | |
| } | |
| def get_tree(self): | |
| """Parses the provided code to validate its syntax. | |
| It tries to parse the code into an abstract syntax tree (AST). | |
| """ | |
| try: | |
| tree = ast.parse(self.code) | |
| except SyntaxError as err: | |
| raise CodeSyntaxError( | |
| status_code=400, | |
| detail={"error": err.msg, "traceback": traceback.format_exc()}, | |
| ) from err | |
| return tree | |
| def parse_node(self, node: ast.stmt | ast.AST) -> None: | |
| """Parses an AST node and updates the data dictionary with the relevant information.""" | |
| if handler := self.handlers.get(type(node)): | |
| handler(node) # type: ignore[operator] | |
| def parse_imports(self, node: ast.Import | ast.ImportFrom) -> None: | |
| """Extracts "imports" from the code, including aliases.""" | |
| if isinstance(node, ast.Import): | |
| for alias in node.names: | |
| if alias.asname: | |
| self.data["imports"].append(f"{alias.name} as {alias.asname}") | |
| else: | |
| self.data["imports"].append(alias.name) | |
| elif isinstance(node, ast.ImportFrom): | |
| for alias in node.names: | |
| if alias.asname: | |
| self.data["imports"].append((node.module, f"{alias.name} as {alias.asname}")) | |
| else: | |
| self.data["imports"].append((node.module, alias.name)) | |
| def parse_functions(self, node: ast.FunctionDef) -> None: | |
| """Extracts "functions" from the code.""" | |
| self.data["functions"].append(self.parse_callable_details(node)) | |
| def parse_arg(self, arg, default): | |
| """Parses an argument and its default value.""" | |
| arg_dict = {"name": arg.arg, "default": default} | |
| if arg.annotation: | |
| arg_dict["type"] = ast.unparse(arg.annotation) | |
| return arg_dict | |
| # @cachedmethod(operator.attrgetter("cache")) | |
| def construct_eval_env(self, return_type_str: str, imports) -> dict: | |
| """Constructs an evaluation environment. | |
| Constructs an evaluation environment with the necessary imports for the return type, | |
| taking into account module aliases. | |
| """ | |
| eval_env: dict = {} | |
| for import_entry in imports: | |
| if isinstance(import_entry, tuple): # from module import name | |
| module, name = import_entry | |
| if name in return_type_str: | |
| exec(f"import {module}", eval_env) | |
| exec(f"from {module} import {name}", eval_env) | |
| else: # import module | |
| module = import_entry | |
| alias = None | |
| if " as " in module: | |
| module, alias = module.split(" as ") | |
| if module in return_type_str or (alias and alias in return_type_str): | |
| exec(f"import {module} as {alias or module}", eval_env) | |
| return eval_env | |
| def parse_callable_details(self, node: ast.FunctionDef) -> dict[str, Any]: | |
| """Extracts details from a single function or method node.""" | |
| return_type = None | |
| if node.returns: | |
| return_type_str = ast.unparse(node.returns) | |
| eval_env = self.construct_eval_env(return_type_str, tuple(self.data["imports"])) | |
| # Handle cases where the type is not found in the constructed environment | |
| with contextlib.suppress(NameError): | |
| return_type = eval(return_type_str, eval_env) # noqa: S307 | |
| func = CallableCodeDetails( | |
| name=node.name, | |
| doc=ast.get_docstring(node), | |
| args=self.parse_function_args(node), | |
| body=self.parse_function_body(node), | |
| return_type=return_type, | |
| has_return=self.parse_return_statement(node), | |
| ) | |
| return func.model_dump() | |
| def parse_function_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]: | |
| """Parses the arguments of a function or method node.""" | |
| args = [] | |
| args += self.parse_positional_args(node) | |
| args += self.parse_varargs(node) | |
| args += self.parse_keyword_args(node) | |
| # Commented out because we don't want kwargs | |
| # showing up as fields in the frontend | |
| args += self.parse_kwargs(node) | |
| return args | |
| def parse_positional_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]: | |
| """Parses the positional arguments of a function or method node.""" | |
| num_args = len(node.args.args) | |
| num_defaults = len(node.args.defaults) | |
| num_missing_defaults = num_args - num_defaults | |
| missing_defaults = [MissingDefault()] * num_missing_defaults | |
| default_values = [ast.unparse(default).strip("'") if default else None for default in node.args.defaults] | |
| # Now check all default values to see if there | |
| # are any "None" values in the middle | |
| default_values = [None if value == "None" else value for value in default_values] | |
| defaults = missing_defaults + default_values | |
| return list(starmap(self.parse_arg, zip(node.args.args, defaults, strict=True))) | |
| def parse_varargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]: | |
| """Parses the *args argument of a function or method node.""" | |
| args = [] | |
| if node.args.vararg: | |
| args.append(self.parse_arg(node.args.vararg, None)) | |
| return args | |
| def parse_keyword_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]: | |
| """Parses the keyword-only arguments of a function or method node.""" | |
| kw_defaults = [None] * (len(node.args.kwonlyargs) - len(node.args.kw_defaults)) + [ | |
| ast.unparse(default) if default else None for default in node.args.kw_defaults | |
| ] | |
| return list(starmap(self.parse_arg, zip(node.args.kwonlyargs, kw_defaults, strict=True))) | |
| def parse_kwargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]: | |
| """Parses the **kwargs argument of a function or method node.""" | |
| args = [] | |
| if node.args.kwarg: | |
| args.append(self.parse_arg(node.args.kwarg, None)) | |
| return args | |
| def parse_function_body(self, node: ast.FunctionDef) -> list[str]: | |
| """Parses the body of a function or method node.""" | |
| return [ast.unparse(line) for line in node.body] | |
| def parse_return_statement(self, node: ast.FunctionDef) -> bool: | |
| """Parses the return statement of a function or method node, including nested returns.""" | |
| def has_return(node): | |
| if isinstance(node, ast.Return): | |
| return True | |
| if isinstance(node, ast.If): | |
| return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse) | |
| if isinstance(node, ast.Try): | |
| return ( | |
| any(has_return(child) for child in node.body) | |
| or any(has_return(child) for child in node.handlers) | |
| or any(has_return(child) for child in node.finalbody) | |
| ) | |
| if isinstance(node, ast.For | ast.While): | |
| return any(has_return(child) for child in node.body) or any(has_return(child) for child in node.orelse) | |
| if isinstance(node, ast.With): | |
| return any(has_return(child) for child in node.body) | |
| return False | |
| return any(has_return(child) for child in node.body) | |
| def parse_assign(self, stmt): | |
| """Parses an Assign statement and returns a dictionary with the target's name and value.""" | |
| for target in stmt.targets: | |
| if isinstance(target, ast.Name): | |
| return {"name": target.id, "value": ast.unparse(stmt.value)} | |
| return None | |
| def parse_ann_assign(self, stmt): | |
| """Parses an AnnAssign statement and returns a dictionary with the target's name, value, and annotation.""" | |
| if isinstance(stmt.target, ast.Name): | |
| return { | |
| "name": stmt.target.id, | |
| "value": ast.unparse(stmt.value) if stmt.value else None, | |
| "annotation": ast.unparse(stmt.annotation), | |
| } | |
| return None | |
| def parse_function_def(self, stmt): | |
| """Parse a FunctionDef statement. | |
| Parse a FunctionDef statement and return the parsed method and a boolean indicating if it's an __init__ method. | |
| """ | |
| method = self.parse_callable_details(stmt) | |
| return (method, True) if stmt.name == "__init__" else (method, False) | |
| def get_base_classes(self): | |
| """Returns the base classes of the custom component class.""" | |
| try: | |
| bases = self.execute_and_inspect_classes(self.code) | |
| except Exception: | |
| # If the code cannot be executed, return an empty list | |
| bases = [] | |
| raise | |
| return bases | |
| def parse_classes(self, node: ast.ClassDef) -> None: | |
| """Extracts "classes" from the code, including inheritance and init methods.""" | |
| bases = self.get_base_classes() | |
| nodes = [] | |
| for base in bases: | |
| if base.__name__ == node.name or base.__name__ in {"CustomComponent", "Component", "BaseComponent"}: | |
| continue | |
| try: | |
| class_node, import_nodes = find_class_ast_node(base) | |
| if class_node is None: | |
| continue | |
| for import_node in import_nodes: | |
| self.parse_imports(import_node) | |
| nodes.append(class_node) | |
| except Exception: # noqa: BLE001 | |
| logger.exception("Error finding base class node") | |
| nodes.insert(0, node) | |
| class_details = ClassCodeDetails( | |
| name=node.name, | |
| doc=ast.get_docstring(node), | |
| bases=[b.__name__ for b in bases], | |
| attributes=[], | |
| methods=[], | |
| init=None, | |
| ) | |
| for _node in nodes: | |
| self.process_class_node(_node, class_details) | |
| self.data["classes"].append(class_details.model_dump()) | |
| def process_class_node(self, node, class_details) -> None: | |
| for stmt in node.body: | |
| if isinstance(stmt, ast.Assign): | |
| if attr := self.parse_assign(stmt): | |
| class_details.attributes.append(attr) | |
| elif isinstance(stmt, ast.AnnAssign): | |
| if attr := self.parse_ann_assign(stmt): | |
| class_details.attributes.append(attr) | |
| elif isinstance(stmt, ast.FunctionDef | ast.AsyncFunctionDef): | |
| method, is_init = self.parse_function_def(stmt) | |
| if is_init: | |
| class_details.init = method | |
| else: | |
| class_details.methods.append(method) | |
| def parse_global_vars(self, node: ast.Assign) -> None: | |
| """Extracts global variables from the code.""" | |
| global_var = { | |
| "targets": [t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets], | |
| "value": ast.unparse(node.value), | |
| } | |
| self.data["global_vars"].append(global_var) | |
| def execute_and_inspect_classes(self, code: str): | |
| custom_component_class = eval_custom_component_code(code) | |
| custom_component = custom_component_class(_code=code) | |
| dunder_class = custom_component.__class__ | |
| # Get the base classes at two levels of inheritance | |
| bases = [] | |
| for base in dunder_class.__bases__: | |
| bases.append(base) | |
| bases.extend(base.__bases__) | |
| return bases | |
| def parse_code(self) -> dict[str, Any]: | |
| """Runs all parsing operations and returns the resulting data.""" | |
| tree = self.get_tree() | |
| for node in ast.walk(tree): | |
| self.parse_node(node) | |
| return self.data | |