import ast import json from typing import Any from langchain.agents import Tool from langchain_core.tools import StructuredTool from loguru import logger from pydantic.v1 import Field, create_model from pydantic.v1.fields import Undefined from typing_extensions import override from langflow.base.langchain_utilities.model import LCToolComponent from langflow.inputs.inputs import ( BoolInput, DropdownInput, FieldTypes, HandleInput, MessageTextInput, MultilineInput, ) from langflow.io import Output from langflow.schema import Data from langflow.schema.dotdict import dotdict class PythonCodeStructuredTool(LCToolComponent): DEFAULT_KEYS = [ "code", "_type", "text_key", "tool_code", "tool_name", "tool_description", "return_direct", "tool_function", "global_variables", "_classes", "_functions", ] display_name = "Python Code Structured" description = "structuredtool dataclass code to tool" documentation = "https://python.langchain.com/docs/modules/tools/custom_tools/#structuredtool-dataclass" name = "PythonCodeStructuredTool" icon = "Python" field_order = ["name", "description", "tool_code", "return_direct", "tool_function"] legacy: bool = True inputs = [ MultilineInput( name="tool_code", display_name="Tool Code", info="Enter the dataclass code.", placeholder="def my_function(args):\n pass", required=True, real_time_refresh=True, refresh_button=True, ), MessageTextInput( name="tool_name", display_name="Tool Name", info="Enter the name of the tool.", required=True, ), MessageTextInput( name="tool_description", display_name="Description", info="Enter the description of the tool.", required=True, ), BoolInput( name="return_direct", display_name="Return Directly", info="Should the tool return the function output directly?", ), DropdownInput( name="tool_function", display_name="Tool Function", info="Select the function for additional expressions.", options=[], required=True, real_time_refresh=True, refresh_button=True, ), HandleInput( name="global_variables", display_name="Global Variables", info="Enter the global variables or Create Data Component.", input_types=["Data"], field_type=FieldTypes.DICT, is_list=True, ), MessageTextInput(name="_classes", display_name="Classes", advanced=True), MessageTextInput(name="_functions", display_name="Functions", advanced=True), ] outputs = [ Output(display_name="Tool", name="result_tool", method="build_tool"), ] @override def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None) -> dotdict: if field_name is None: return build_config if field_name not in {"tool_code", "tool_function"}: return build_config try: named_functions = {} [classes, functions] = self._parse_code(build_config["tool_code"]["value"]) existing_fields = {} if len(build_config) > len(self.DEFAULT_KEYS): for key in build_config.copy(): if key not in self.DEFAULT_KEYS: existing_fields[key] = build_config.pop(key) names = [] for func in functions: named_functions[func["name"]] = func names.append(func["name"]) for arg in func["args"]: field_name = f"{func['name']}|{arg['name']}" if field_name in existing_fields: build_config[field_name] = existing_fields[field_name] continue field = MessageTextInput( display_name=f"{arg['name']}: Description", name=field_name, info=f"Enter the description for {arg['name']}", required=True, ) build_config[field_name] = field.to_dict() build_config["_functions"]["value"] = json.dumps(named_functions) build_config["_classes"]["value"] = json.dumps(classes) build_config["tool_function"]["options"] = names except Exception as e: # noqa: BLE001 self.status = f"Failed to extract names: {e}" logger.opt(exception=True).debug(self.status) build_config["tool_function"]["options"] = ["Failed to parse", str(e)] return build_config async def build_tool(self) -> Tool: local_namespace = {} # type: ignore[var-annotated] modules = self._find_imports(self.tool_code) import_code = "" for module in modules["imports"]: import_code += f"global {module}\nimport {module}\n" for from_module in modules["from_imports"]: for alias in from_module.names: import_code += f"global {alias.name}\n" import_code += ( f"from {from_module.module} import {', '.join([alias.name for alias in from_module.names])}\n" ) exec(import_code, globals()) exec(self.tool_code, globals(), local_namespace) class PythonCodeToolFunc: params: dict = {} def run(**kwargs): for key, arg in kwargs.items(): if key not in PythonCodeToolFunc.params: PythonCodeToolFunc.params[key] = arg return local_namespace[self.tool_function](**PythonCodeToolFunc.params) globals_ = globals() local = {} local[self.tool_function] = PythonCodeToolFunc globals_.update(local) if isinstance(self.global_variables, list): for data in self.global_variables: if isinstance(data, Data): globals_.update(data.data) elif isinstance(self.global_variables, dict): globals_.update(self.global_variables) classes = json.loads(self._attributes["_classes"]) for class_dict in classes: exec("\n".join(class_dict["code"]), globals_) named_functions = json.loads(self._attributes["_functions"]) schema_fields = {} for attr in self._attributes: if attr in self.DEFAULT_KEYS: continue func_name = attr.split("|")[0] field_name = attr.split("|")[1] func_arg = self._find_arg(named_functions, func_name, field_name) if func_arg is None: msg = f"Failed to find arg: {field_name}" raise ValueError(msg) field_annotation = func_arg["annotation"] field_description = self._get_value(self._attributes[attr], str) if field_annotation: exec(f"temp_annotation_type = {field_annotation}", globals_) schema_annotation = globals_["temp_annotation_type"] else: schema_annotation = Any schema_fields[field_name] = ( schema_annotation, Field( default=func_arg.get("default", Undefined), description=field_description, ), ) if "temp_annotation_type" in globals_: globals_.pop("temp_annotation_type") python_code_tool_schema = None if schema_fields: python_code_tool_schema = create_model("PythonCodeToolSchema", **schema_fields) return StructuredTool.from_function( func=local[self.tool_function].run, args_schema=python_code_tool_schema, name=self.tool_name, description=self.tool_description, return_direct=self.return_direct, ) def post_code_processing(self, new_frontend_node: dict, current_frontend_node: dict): """This function is called after the code validation is done.""" frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) frontend_node["template"] = self.update_build_config( frontend_node["template"], frontend_node["template"]["tool_code"]["value"], "tool_code", ) frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) for key in frontend_node["template"]: if key in self.DEFAULT_KEYS: continue frontend_node["template"] = self.update_build_config( frontend_node["template"], frontend_node["template"][key]["value"], key ) frontend_node = super().post_code_processing(new_frontend_node, current_frontend_node) return frontend_node def _parse_code(self, code: str) -> tuple[list[dict], list[dict]]: parsed_code = ast.parse(code) lines = code.split("\n") classes = [] functions = [] for node in parsed_code.body: if isinstance(node, ast.ClassDef): class_lines = lines[node.lineno - 1 : node.end_lineno] class_lines[-1] = class_lines[-1][: node.end_col_offset] class_lines[0] = class_lines[0][node.col_offset :] classes.append( { "name": node.name, "code": class_lines, } ) continue if not isinstance(node, ast.FunctionDef): continue func = {"name": node.name, "args": []} for arg in node.args.args: if arg.lineno != arg.end_lineno: msg = "Multiline arguments are not supported" raise ValueError(msg) func_arg = { "name": arg.arg, "annotation": None, } for default in node.args.defaults: if ( arg.lineno > default.lineno or arg.col_offset > default.col_offset or ( arg.end_lineno is not None and default.end_lineno is not None and arg.end_lineno < default.end_lineno ) or ( arg.end_col_offset is not None and default.end_col_offset is not None and arg.end_col_offset < default.end_col_offset ) ): continue if isinstance(default, ast.Name): func_arg["default"] = default.id elif isinstance(default, ast.Constant): func_arg["default"] = default.value if arg.annotation: annotation_line = lines[arg.annotation.lineno - 1] annotation_line = annotation_line[: arg.annotation.end_col_offset] annotation_line = annotation_line[arg.annotation.col_offset :] func_arg["annotation"] = annotation_line if isinstance(func_arg["annotation"], str) and func_arg["annotation"].count("=") > 0: func_arg["annotation"] = "=".join(func_arg["annotation"].split("=")[:-1]).strip() if isinstance(func["args"], list): func["args"].append(func_arg) functions.append(func) return classes, functions def _find_imports(self, code: str) -> dotdict: imports: list[str] = [] from_imports = [] parsed_code = ast.parse(code) for node in parsed_code.body: if isinstance(node, ast.Import): imports.extend(alias.name for alias in node.names) elif isinstance(node, ast.ImportFrom): from_imports.append(node) return dotdict({"imports": imports, "from_imports": from_imports}) def _get_value(self, value: Any, annotation: Any) -> Any: return value if isinstance(value, annotation) else value["value"] def _find_arg(self, named_functions: dict, func_name: str, arg_name: str) -> dict | None: for arg in named_functions[func_name]["args"]: if arg["name"] == arg_name: return arg return None