Spaces:
Running
Running
| from __future__ import annotations | |
| from typing import TYPE_CHECKING, Any, cast | |
| from loguru import logger | |
| from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict | |
| from langflow.schema.schema import INPUT_FIELD_NAME | |
| if TYPE_CHECKING: | |
| from langflow.graph.vertex.base import Vertex | |
class Edge:
    """A connection from an output of ``source`` to an input field of ``target``.

    Two raw edge formats are accepted:

    * the current format, where ``edge["data"]`` holds ``sourceHandle`` and
      ``targetHandle`` dictionaries that are parsed into ``SourceHandle`` /
      ``TargetHandle`` objects; and
    * a legacy format, where ``edge["targetHandle"]`` is a ``|``-separated
      string whose second segment is the target parameter name.

    Handles and types are validated inside ``__init__`` so an invalid edge
    fails fast with ``ValueError``.
    """

    def __init__(self, source: Vertex, target: Vertex, edge: EdgeData):
        """Build the edge from raw ``edge`` data and validate it immediately.

        Args:
            source: Vertex the edge starts from; its ``id`` is stored as
                ``source_id`` (empty string if ``source`` is falsy).
            target: Vertex the edge points to; its ``id`` is stored as
                ``target_id`` (empty string if ``target`` is falsy).
            edge: Raw edge mapping. A non-empty ``"data"`` entry selects the
                current format; otherwise the top-level legacy handles are used.

        Raises:
            ValueError: If the target handle has the wrong type for the edge
                format, if it cannot be parsed while ``inputTypes`` is ``None``,
                or if handle/type validation fails.
        """
        self.source_id: str = source.id if source else ""
        self.target_id: str = target.id if target else ""
        self.valid_handles: bool = False
        self.target_param: str | None = None
        self._target_handle: TargetHandleDict | str | None = None
        # Shallow copy of the raw edge; returned unchanged by to_data().
        self._data = edge.copy()
        self.is_cycle = False
        if data := edge.get("data", {}):
            self._source_handle = data.get("sourceHandle", {})
            self._target_handle = cast("TargetHandleDict", data.get("targetHandle", {}))
            self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
            if isinstance(self._target_handle, dict):
                try:
                    self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
                except Exception as e:
                    # With an explicit ``inputTypes: None`` the parse failure is reported
                    # as a field that "might not be a valid input"; anything else re-raises.
                    if "inputTypes" in self._target_handle and self._target_handle["inputTypes"] is None:
                        if hasattr(target, "custom_component"):
                            display_name = getattr(target.custom_component, "display_name", "")
                            msg = (
                                f"Component {display_name} field '{self._target_handle['fieldName']}' "
                                "might not be a valid input."
                            )
                            raise ValueError(msg) from e
                        msg = (
                            f"Field '{self._target_handle['fieldName']}' on {target.display_name} "
                            "might not be a valid input."
                        )
                        raise ValueError(msg) from e
                    raise
            else:
                msg = "Target handle is not a dictionary"
                raise ValueError(msg)
            self.target_param = self.target_handle.field_name
            # validate handles
            self.validate_handles(source, target)
        else:
            # Logging here because this is a breaking change
            logger.error("Edge data is empty")
            self._source_handle = edge.get("sourceHandle", "")  # type: ignore[assignment]
            self._target_handle = edge.get("targetHandle", "")  # type: ignore[assignment]
            # e.g. 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
            # -> target_param is "documents"
            if isinstance(self._target_handle, str):
                # NOTE(review): a handle without "|" raises IndexError here rather than ValueError.
                self.target_param = self._target_handle.split("|")[1]
                self.source_handle = None
                self.target_handle = None
            else:
                msg = "Target handle is not a string"
                raise ValueError(msg)
        # Validate in __init__ to fail fast
        self.validate_edge(source, target)

    def to_data(self):
        """Return the copy of the raw edge data taken in ``__init__``."""
        return self._data

    def validate_handles(self, source, target) -> None:
        """Check that the source handle's types are compatible with the target handle.

        The legacy check is used when the raw source handle is a string or the
        parsed source handle declares ``base_classes``; otherwise the check is
        based on ``output_types``.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_handles(source, target)
        else:
            self._validate_handles(source, target)

    def _validate_handles(self, source, target) -> None:
        """Set ``valid_handles`` from ``output_types``; raise ``ValueError`` if invalid.

        If the target declares no ``input_types``, its ``type`` must appear in
        the source ``output_types``. Otherwise either one of the source output
        types must be accepted by the target, or the target ``type`` must be
        among the source outputs. If the source has no ``output_types`` but
        the target has ``input_types``, ``valid_handles`` keeps its prior value.
        """
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.output_types
        elif self.source_handle.output_types is not None:
            self.valid_handles = (
                any(output_type in self.target_handle.input_types for output_type in self.source_handle.output_types)
                or self.target_handle.type in self.source_handle.output_types
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.display_name} and {target.display_name} has invalid handles"
            raise ValueError(msg)

    def _legacy_validate_handles(self, source, target) -> None:
        """Like ``_validate_handles`` but compares against the source ``base_classes``."""
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.base_classes
        else:
            self.valid_handles = (
                any(baseClass in self.target_handle.input_types for baseClass in self.source_handle.base_classes)
                or self.target_handle.type in self.source_handle.base_classes
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has invalid handles"
            raise ValueError(msg)

    def __setstate__(self, state):
        """Restore an edge from pickled ``state`` (the instance ``__dict__``).

        ``source_id``, ``target_id`` and ``target_param`` are required (a
        missing one raises ``KeyError``). The handle and validation attributes
        default to ``None`` and ``is_cycle`` defaults to ``False``. Any other
        pickled attribute, such as ``valid`` or fields added by subclasses, is
        carried over unchanged so it survives a pickle round trip.
        """
        self.source_id = state["source_id"]
        self.target_id = state["target_id"]
        self.target_param = state["target_param"]
        self.source_handle = state.get("source_handle")
        self.target_handle = state.get("target_handle")
        self._source_handle = state.get("_source_handle")
        self._target_handle = state.get("_target_handle")
        self._data = state.get("_data")
        self.valid_handles = state.get("valid_handles")
        self.source_types = state.get("source_types")
        self.target_reqs = state.get("target_reqs")
        self.matched_type = state.get("matched_type")
        self.is_cycle = state.get("is_cycle", False)
        # Keep everything not handled explicitly above (e.g. ``valid`` or subclass state).
        for key, value in state.items():
            self.__dict__.setdefault(key, value)

    def validate_edge(self, source, target) -> None:
        """Check that a source output type is accepted by the target's inputs.

        The legacy check is used when the raw source handle is a string or the
        parsed source handle declares ``base_classes``.
        """
        # NOTE(review): in the legacy branch of __init__ ``source_handle`` is None, so a
        # non-string legacy ``sourceHandle`` would raise AttributeError here — confirm inputs.
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_edge(source, target)
        else:
            self._validate_edge(source, target)

    def _validate_edge(self, source, target) -> None:
        """Match the types of the named source output against the target inputs.

        Sets ``source_types`` (the source output dicts whose ``"name"`` equals
        the source handle name), ``target_reqs``, ``valid`` and ``matched_type``.
        A type matches when it is a substring of a target requirement.

        Raises:
            ValueError: If no output type matches any target requirement.
        """
        # ``source.outputs`` entries are dictionaries with "name" and "types" keys.
        self.source_types = [output for output in source.outputs if output["name"] == self.source_handle.name]
        self.target_reqs = target.required_inputs + target.optional_inputs
        # Substring matching: e.g. coming_out=["Chain"] matches target_reqs=["LLMChain"].
        self.valid = any(
            any(output_type in target_req for output_type in output["types"])
            for output in self.source_types
            for target_req in self.target_reqs
        )
        # First output type that matches any target requirement.
        self.matched_type = next(
            (
                output_type
                for output in self.source_types
                for output_type in output["types"]
                for target_req in self.target_reqs
                if output_type in target_req
            ),
            None,
        )
        no_matched_type = self.matched_type is None
        if no_matched_type:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type."
            raise ValueError(msg)

    def _legacy_validate_edge(self, source, target) -> None:
        """Match ``source.output`` against the target inputs (legacy format).

        Sets ``source_types``, ``target_reqs``, ``valid`` and ``matched_type``.

        Raises:
            ValueError: If no source output is an exact entry of ``target_reqs``.
        """
        self.source_types = source.output
        self.target_reqs = target.required_inputs + target.optional_inputs
        # Substring matching: e.g. coming_out=["Chain"] matches target_reqs=["LLMChain"].
        self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
        # NOTE(review): unlike ``valid`` above, this requires an exact entry in ``target_reqs``,
        # so a substring-only match is rejected below. Confirm whether that is intended.
        self.matched_type = next(
            (output for output in self.source_types if output in self.target_reqs),
            None,
        )
        no_matched_type = self.matched_type is None
        if no_matched_type:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type"
            raise ValueError(msg)

    def __repr__(self) -> str:
        """Show ``source -[output->field]-> target``, or ``source -[param]-> target`` without parsed handles."""
        if (hasattr(self, "source_handle") and self.source_handle) and (
            hasattr(self, "target_handle") and self.target_handle
        ):
            return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
        return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"

    def __hash__(self) -> int:
        """Hash the ``repr`` string."""
        return hash(self.__repr__())

    def __eq__(self, /, other: object) -> bool:
        """Compare raw handles and ``target_param``; non-edges are never equal."""
        if not isinstance(other, Edge):
            return False
        return (
            self._source_handle == other._source_handle
            and self._target_handle == other._target_handle
            and self.target_param == other.target_param
        )

    def __str__(self) -> str:
        """Same as ``repr``."""
        return self.__repr__()
class CycleEdge(Edge):
    """An ``Edge`` that closes a cycle in the graph.

    In addition to the validation performed by ``Edge``, it holds the value
    passed along the edge (``result``) and whether that value has already been
    delivered to the target (``is_fulfilled``). Both endpoint vertices are
    flagged with ``has_cycle_edges = True``.
    """

    def __init__(self, source: Vertex, target: Vertex, raw_edge: EdgeData):
        super().__init__(source, target, raw_edge)
        self.is_fulfilled = False  # Whether the contract has been fulfilled.
        self.result: Any = None
        self.is_cycle = True
        source.has_cycle_edges = True
        target.has_cycle_edges = True

    async def honor(self, source: Vertex, target: Vertex) -> None:
        """Fulfill the edge by copying the built value of ``source`` into ``target``.

        When ``matched_type`` is ``"Text"`` the value is ``source.built_result``;
        otherwise it is ``source.built_object``. The value is stored in
        ``self.result`` and in ``target.params[self.target_param]``, and the edge
        is marked fulfilled. Calls on an already fulfilled edge do nothing.

        Args:
            source: Vertex the value is read from; must already be built.
            target: Vertex whose ``params`` receive the value.

        Raises:
            ValueError: If ``source`` has not been built.
        """
        if self.is_fulfilled:
            return
        if not source.built:
            # The system should be read-only, so we should not be building vertices
            # that are not already built.
            msg = f"Source vertex {source.id} is not built."
            raise ValueError(msg)
        if self.matched_type == "Text":
            self.result = source.built_result
        else:
            self.result = source.built_object
        target.params[self.target_param] = self.result
        self.is_fulfilled = True

    async def get_result_from_source(self, source: Vertex, target: Vertex):
        """Return the value carried by this edge, honoring the edge first if needed.

        Args:
            source: Vertex the value is read from.
            target: Vertex that receives the value.

        Returns:
            ``self.result`` after the edge has been fulfilled.
        """
        if not self.is_fulfilled:
            await self.honor(source, target)
        return self.result

    def __repr__(self) -> str:
        """Return the ``Edge`` representation with a trailing cycle marker."""
        str_repr = super().__repr__()
        # Add a symbol to show this is a cycle edge
        return f"{str_repr} 🔄"