|
|
from __future__ import annotations |
|
|
|
|
|
from typing import TYPE_CHECKING, Any, cast |
|
|
|
|
|
from loguru import logger |
|
|
|
|
|
from langflow.graph.edge.schema import EdgeData, SourceHandle, TargetHandle, TargetHandleDict |
|
|
from langflow.schema.schema import INPUT_FIELD_NAME |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langflow.graph.vertex.base import Vertex |
|
|
|
|
|
|
|
|
class Edge:
    """Connection between a source vertex and a target vertex.

    The edge is built from a raw edge mapping. During construction the source
    and target handles are read from that mapping and the edge is validated;
    a ``ValueError`` is raised when the handles or the source/target types do
    not line up.
    """

    def __init__(self, source: Vertex, target: Vertex, edge: EdgeData):
        """Parse the handles of ``edge`` and validate the connection.

        Args:
            source: Vertex the edge starts from. A falsy value yields an empty
                ``source_id``.
            target: Vertex the edge points to. A falsy value yields an empty
                ``target_id``.
            edge: Raw edge mapping. When it has a non-empty ``"data"`` entry the
                handles are read from ``data["sourceHandle"]`` and
                ``data["targetHandle"]``; otherwise the top-level
                ``"sourceHandle"``/``"targetHandle"`` string entries are used.

        Raises:
            ValueError: If the target handle has the wrong type or shape, if
                the target handle cannot be built while its ``inputTypes`` is
                ``None``, or if handle/edge validation fails.
        """
        self.source_id: str = source.id if source else ""
        self.target_id: str = target.id if target else ""
        self.valid_handles: bool = False
        self.target_param: str | None = None
        self._target_handle: TargetHandleDict | str | None = None
        # Shallow copy of the raw mapping; this is what `to_data` returns.
        self._data = edge.copy()
        self.is_cycle = False
        if data := edge.get("data", {}):
            self._source_handle = data.get("sourceHandle", {})
            self._target_handle = cast("TargetHandleDict", data.get("targetHandle", {}))
            self.source_handle: SourceHandle = SourceHandle(**self._source_handle)
            if isinstance(self._target_handle, dict):
                try:
                    self.target_handle: TargetHandle = TargetHandle(**self._target_handle)
                except Exception as e:
                    # A handle whose "inputTypes" is explicitly None gets a
                    # friendlier message naming the offending field; any other
                    # failure is re-raised untouched.
                    if "inputTypes" in self._target_handle and self._target_handle["inputTypes"] is None:
                        if hasattr(target, "custom_component"):
                            display_name = getattr(target.custom_component, "display_name", "")
                            msg = (
                                f"Component {display_name} field '{self._target_handle['fieldName']}' "
                                "might not be a valid input."
                            )
                            raise ValueError(msg) from e
                        msg = (
                            f"Field '{self._target_handle['fieldName']}' on {target.display_name} "
                            "might not be a valid input."
                        )
                        raise ValueError(msg) from e
                    raise
            else:
                msg = "Target handle is not a dictionary"
                raise ValueError(msg)
            self.target_param = self.target_handle.field_name
            self.validate_handles(source, target)
        else:
            logger.error("Edge data is empty")
            self._source_handle = edge.get("sourceHandle", "")
            self._target_handle = edge.get("targetHandle", "")
            if not isinstance(self._target_handle, str):
                msg = "Target handle is not a string"
                raise ValueError(msg)
            # The target parameter is the second "|"-separated segment of the
            # handle string. A handle without that segment (including the ""
            # default) used to surface as a bare IndexError; report it with the
            # same exception type as the other malformed-handle cases.
            handle_parts = self._target_handle.split("|")
            if len(handle_parts) < 2:  # noqa: PLR2004
                msg = f"Target handle '{self._target_handle}' does not contain a '|'-separated field name"
                raise ValueError(msg)
            self.target_param = handle_parts[1]
            self.source_handle = None
            self.target_handle = None

        # Runs for both branches so an incompatible edge fails at construction.
        self.validate_edge(source, target)

    def to_data(self):
        """Return the copy of the raw edge mapping taken at construction."""
        return self._data

    def validate_handles(self, source, target) -> None:
        """Check handle compatibility, choosing the legacy or current rules.

        The legacy rules are used when the raw source handle is a string or the
        parsed source handle has non-empty ``base_classes``.

        Raises:
            ValueError: If the handles are not compatible.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_handles(source, target)
        else:
            self._validate_handles(source, target)

    def _validate_handles(self, source, target) -> None:
        """Set ``valid_handles`` by comparing source output types with the target handle.

        Raises:
            ValueError: If ``valid_handles`` is still false after the comparison.
        """
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.output_types
        elif self.source_handle.output_types is not None:
            self.valid_handles = (
                any(output_type in self.target_handle.input_types for output_type in self.source_handle.output_types)
                or self.target_handle.type in self.source_handle.output_types
            )
        # When input_types is set but output_types is None neither branch runs,
        # so valid_handles keeps its previous value (False after __init__).

        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.display_name} and {target.display_name} has invalid handles"
            raise ValueError(msg)

    def _legacy_validate_handles(self, source, target) -> None:
        """Set ``valid_handles`` by comparing source ``base_classes`` with the target handle.

        Raises:
            ValueError: If no base class matches the target's input types or type.
        """
        if self.target_handle.input_types is None:
            self.valid_handles = self.target_handle.type in self.source_handle.base_classes
        else:
            self.valid_handles = (
                any(baseClass in self.target_handle.input_types for baseClass in self.source_handle.base_classes)
                or self.target_handle.type in self.source_handle.base_classes
            )
        if not self.valid_handles:
            logger.debug(self.source_handle)
            logger.debug(self.target_handle)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has invalid handles"
            raise ValueError(msg)

    def __setstate__(self, state):
        """Restore the edge from a pickled state mapping.

        ``source_id``, ``target_id`` and ``target_param`` are required keys
        (``KeyError`` otherwise); the other known attributes fall back to
        ``None`` when absent from ``state``.
        """
        # Defining __setstate__ replaces the default "update __dict__" restore,
        # so do that explicitly first. Without it, attributes not listed below
        # (`valid`, and anything a subclass adds) are silently dropped.
        self.__dict__.update(state)
        self.source_id = state["source_id"]
        self.target_id = state["target_id"]
        self.target_param = state["target_param"]
        self.source_handle = state.get("source_handle")
        self.target_handle = state.get("target_handle")
        self._source_handle = state.get("_source_handle")
        self._target_handle = state.get("_target_handle")
        self._data = state.get("_data")
        self.valid_handles = state.get("valid_handles")
        self.source_types = state.get("source_types")
        self.target_reqs = state.get("target_reqs")
        self.matched_type = state.get("matched_type")
        # Same default as __init__, so `edge.is_cycle` never raises
        # AttributeError on a restored edge.
        self.is_cycle = state.get("is_cycle", False)

    def validate_edge(self, source, target) -> None:
        """Check source outputs against target inputs, choosing legacy or current rules.

        The legacy rules are used when the raw source handle is a string or the
        parsed source handle has non-empty ``base_classes``.

        Raises:
            ValueError: If no source type matches a target requirement.
        """
        if isinstance(self._source_handle, str) or self.source_handle.base_classes:
            self._legacy_validate_edge(source, target)
        else:
            self._validate_edge(source, target)

    def _validate_edge(self, source, target) -> None:
        """Match the types of the selected source output against the target's inputs.

        Sets ``source_types`` (the entries of ``source.outputs`` whose ``"name"``
        equals the source handle name), ``target_reqs`` (required plus optional
        inputs of the target), ``valid`` and ``matched_type``.

        Raises:
            ValueError: If no output type is contained in any target requirement.
        """
        self.source_types = [output for output in source.outputs if output["name"] == self.source_handle.name]
        self.target_reqs = target.required_inputs + target.optional_inputs

        # `in` is applied to each individual requirement, so when requirements
        # are strings this is a substring test rather than an equality test.
        self.valid = any(
            any(output_type in target_req for output_type in output["types"])
            for output in self.source_types
            for target_req in self.target_reqs
        )

        # First output type that is contained in some requirement, if any.
        self.matched_type = next(
            (
                output_type
                for output in self.source_types
                for output_type in output["types"]
                for target_req in self.target_reqs
                if output_type in target_req
            ),
            None,
        )
        no_matched_type = self.matched_type is None
        if no_matched_type:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type."
            raise ValueError(msg)

    def _legacy_validate_edge(self, source, target) -> None:
        """Match ``source.output`` against the target's inputs using the legacy rules.

        Sets ``source_types``, ``target_reqs``, ``valid`` and ``matched_type``.

        Raises:
            ValueError: If no source output is a member of the target requirements.
        """
        self.source_types = source.output
        self.target_reqs = target.required_inputs + target.optional_inputs

        # NOTE(review): `valid` tests containment inside each requirement,
        # whereas `matched_type` below needs the output to be an element of the
        # requirement list, so `valid` can be True while the edge is rejected.
        # Left as-is to keep the accepted set of edges unchanged - confirm intent.
        self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)

        self.matched_type = next(
            (output for output in self.source_types if output in self.target_reqs),
            None,
        )
        no_matched_type = self.matched_type is None
        if no_matched_type:
            logger.debug(self.source_types)
            logger.debug(self.target_reqs)
            msg = f"Edge between {source.vertex_type} and {target.vertex_type} has no matched type"
            raise ValueError(msg)

    def __repr__(self) -> str:
        """Return ``source -[output->field]-> target``, or the short form without handles."""
        if (hasattr(self, "source_handle") and self.source_handle) and (
            hasattr(self, "target_handle") and self.target_handle
        ):
            return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.field_name}]-> {self.target_id}"
        return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"

    def __hash__(self) -> int:
        """Hash the edge by its textual representation."""
        return hash(self.__repr__())

    def __eq__(self, /, other: object) -> bool:
        """Edges are equal when their raw handles and target parameter are equal."""
        if not isinstance(other, Edge):
            return False
        return (
            self._source_handle == other._source_handle
            and self._target_handle == other._target_handle
            and self.target_param == other.target_param
        )

    def __str__(self) -> str:
        """Same text as ``__repr__``."""
        return self.__repr__()
|
|
|
|
|
|
|
|
class CycleEdge(Edge):
    """Edge that is part of a cycle and hands the source's result to the target.

    In addition to the base edge state it tracks whether the hand-over already
    happened (``is_fulfilled``) and the value that was handed over (``result``).
    """

    def __init__(self, source: Vertex, target: Vertex, raw_edge: EdgeData):
        """Build the edge and flag both vertices as having cycle edges.

        Args:
            source: Vertex the edge starts from; its ``has_cycle_edges`` is set to ``True``.
            target: Vertex the edge points to; its ``has_cycle_edges`` is set to ``True``.
            raw_edge: Raw edge mapping forwarded to ``Edge.__init__``.
        """
        super().__init__(source, target, raw_edge)
        self.is_fulfilled = False
        self.result: Any = None
        self.is_cycle = True
        source.has_cycle_edges = True
        target.has_cycle_edges = True

    def __setstate__(self, state):
        """Restore the edge from a pickled state mapping, including cycle-specific fields.

        The inherited ``__setstate__`` assigns an explicit list of attributes;
        the fields added by this class are restored here so that ``honor`` and
        ``get_result_from_source`` do not hit ``AttributeError`` on a restored
        edge. Missing keys fall back to the values ``__init__`` assigns.
        """
        super().__setstate__(state)
        self.is_fulfilled = state.get("is_fulfilled", False)
        self.result = state.get("result")
        self.is_cycle = state.get("is_cycle", True)

    async def honor(self, source: Vertex, target: Vertex) -> None:
        """Copy the built result of ``source`` into the target's parameters.

        Does nothing if the edge was already fulfilled. Otherwise the result is
        ``source.built_result`` when the matched type is ``"Text"`` and
        ``source.built_object`` for any other matched type; it is stored in
        ``self.result`` and in ``target.params[self.target_param]``, and the
        edge is marked as fulfilled.

        Args:
            source: Vertex whose built value is read.
            target: Vertex whose ``params`` mapping receives the value.

        Raises:
            ValueError: If ``source`` has not been built yet.
        """
        if self.is_fulfilled:
            return

        if not source.built:
            msg = f"Source vertex {source.id} is not built."
            raise ValueError(msg)

        if self.matched_type == "Text":
            self.result = source.built_result
        else:
            self.result = source.built_object

        target.params[self.target_param] = self.result
        self.is_fulfilled = True

    async def get_result_from_source(self, source: Vertex, target: Vertex):
        """Return the value carried by this edge, fulfilling the edge first if needed.

        Args:
            source: Vertex passed on to ``honor`` when the edge is not yet fulfilled.
            target: Vertex passed on to ``honor`` when the edge is not yet fulfilled.

        Returns:
            The value stored in ``self.result``.

        Raises:
            ValueError: Propagated from ``honor`` if ``source`` is not built.
        """
        if not self.is_fulfilled:
            await self.honor(source, target)

        # NOTE(review): both outcomes of this check return `self.result`, so it
        # currently has no effect on the returned value. Kept as-is because
        # dropping it would leave the module's INPUT_FIELD_NAME import unused;
        # confirm whether the ChatOutput case was meant to return something else.
        if (
            target.vertex_type == "ChatOutput"
            and isinstance(target.params.get(INPUT_FIELD_NAME), str | dict)
            and target.params.get("message") == ""
        ):
            return self.result
        return self.result

    def __repr__(self) -> str:
        """Return the base edge representation followed by a cycle marker."""
        str_repr = super().__repr__()

        return f"{str_repr} 🔄"
|
|
|