Spaces:
Running
Running
| from __future__ import annotations | |
| import ast | |
| import asyncio | |
| import inspect | |
| import os | |
| import traceback | |
| import types | |
| from collections.abc import AsyncIterator, Callable, Iterator, Mapping | |
| from enum import Enum | |
| from typing import TYPE_CHECKING, Any | |
| import pandas as pd | |
| from loguru import logger | |
| from langflow.exceptions.component import ComponentBuildError | |
| from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData | |
| from langflow.graph.utils import UnbuiltObject, UnbuiltResult, log_transaction | |
| from langflow.interface import initialize | |
| from langflow.interface.listing import lazy_load_dict | |
| from langflow.schema.artifact import ArtifactType | |
| from langflow.schema.data import Data | |
| from langflow.schema.message import Message | |
| from langflow.schema.schema import INPUT_FIELD_NAME, OutputValue, build_output_logs | |
| from langflow.services.deps import get_storage_service | |
| from langflow.utils.constants import DIRECT_TYPES | |
| from langflow.utils.schemas import ChatOutputResponse | |
| from langflow.utils.util import sync_to_async, unescape_string | |
| if TYPE_CHECKING: | |
| from uuid import UUID | |
| from langflow.custom import Component | |
| from langflow.events.event_manager import EventManager | |
| from langflow.graph.edge.base import CycleEdge, Edge | |
| from langflow.graph.graph.base import Graph | |
| from langflow.graph.vertex.schema import NodeData | |
| from langflow.services.tracing.schema import Log | |
class VertexStates(str, Enum):
    """State of a vertex: active, inactive, or in an error state."""

    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"
    ERROR = "ERROR"
| class Vertex: | |
def __init__(
    self,
    data: NodeData,
    graph: Graph,
    *,
    base_type: str | None = None,
    is_task: bool = False,
    params: dict | None = None,
) -> None:
    """Create a vertex from its node payload and initialise its build state.

    Args:
        data: Node payload. ``data["id"]`` is read here and ``data["data"]`` is
            read by ``parse_data``.
        graph: Graph this vertex belongs to.
        base_type: Optional base type; when ``None``, ``parse_data`` tries to
            resolve it.
        is_task: Stored as ``self.is_task``.
        params: Initial params; an empty dict is used when falsy.
    """
    # is_external means that the Vertex send or receives data from
    # an external source (e.g the chat)
    self._lock = asyncio.Lock()
    self.will_stream = False
    self.updated_raw_params = False
    self.id: str = data["id"]
    # The id prefix before the first "-" is used as the base name.
    self.base_name = self.id.split("-")[0]
    self.is_state = False
    # Substring match of known input/output component names against the id;
    # parse_data() may later override these from the node payload.
    self.is_input = any(input_component_name in self.id for input_component_name in INPUT_COMPONENTS)
    self.is_output = any(output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS)
    self.has_session_id = None
    self.custom_component = None
    self.has_external_input = False
    self.has_external_output = False
    self.graph = graph
    self.full_data = data.copy()
    self.base_type: str | None = base_type
    self.outputs: list[dict] = []
    # parse_data() relies on id, is_input, is_output, full_data and base_type
    # being set above, and itself sets vertex_type, display_name, etc.
    self.parse_data()
    self.built_object: Any = UnbuiltObject()
    self.built_result: Any = None
    self.built = False
    self._successors_ids: list[str] | None = None
    self.artifacts: dict[str, Any] = {}
    self.artifacts_raw: dict[str, Any] = {}
    self.artifacts_type: dict[str, str] = {}
    # Build steps executed by build(); steps_ran tracks the ones already run.
    self.steps: list[Callable] = [self._build]
    self.steps_ran: list[Callable] = []
    self.task_id: str | None = None
    self.is_task = is_task
    self.params = params or {}
    self.parent_node_id: str | None = self.full_data.get("parent_node_id")
    self.load_from_db_fields: list[str] = []
    self.parent_is_top_level = False
    self.layer = None
    self.result: ResultData | None = None
    self.results: dict[str, Any] = {}
    self.outputs_logs: dict[str, OutputValue] = {}
    self.logs: dict[str, list[Log]] = {}
    self.has_cycle_edges = False
    # The membership test may raise ValueError (caught below); in that case
    # the vertex is treated as a non-interface component.
    try:
        self.is_interface_component = self.vertex_type in InterfaceComponentTypes
    except ValueError:
        self.is_interface_component = False
    self.use_result = False
    self.build_times: list[float] = []
    self.state = VertexStates.ACTIVE
    # References to in-flight logging tasks; see _log_transaction_async().
    self.log_transaction_tasks: set[asyncio.Task] = set()
def set_input_value(self, name: str, value: Any) -> None:
    """Forward an input value to the attached component instance.

    Raises:
        ValueError: If no component instance is attached to this vertex.
    """
    component = self.custom_component
    if component is not None:
        component._set_input_value(name, value)
        return
    msg = f"Vertex {self.id} does not have a component instance."
    raise ValueError(msg)
def to_data(self):
    """Return the vertex's ``full_data`` mapping (the same object, not a copy)."""
    payload = self.full_data
    return payload
def add_component_instance(self, component_instance: Component) -> None:
    """Attach ``component_instance`` to this vertex.

    The component is first told about the vertex through ``set_vertex`` and is
    then stored as ``self.custom_component``.
    """
    link_back = component_instance.set_vertex
    link_back(self)
    self.custom_component = component_instance
def add_result(self, name: str, result: Any) -> None:
    """Store ``result`` under ``name`` in ``self.results``, replacing any previous entry."""
    self.results.update({name: result})
def update_graph_state(self, key, new_state, *, append: bool) -> None:
    """Write ``new_state`` to the graph state under ``key``, identifying this vertex as caller.

    Uses ``graph.append_state`` when ``append`` is true, otherwise
    ``graph.update_state``.
    """
    apply_state = self.graph.append_state if append else self.graph.update_state
    apply_state(key, new_state, caller=self.id)
def set_state(self, state: str) -> None:
    """Set the vertex state by member name and sync the graph's inactivated set.

    Args:
        state: Name of a ``VertexStates`` member (looked up with ``VertexStates[state]``).
    """
    self.state = VertexStates[state]
    became_inactive = self.state == VertexStates.INACTIVE
    if became_inactive and self.graph.in_degree_map[self.id] <= 1:
        # If the vertex is inactive and has only one in degree
        # it means that it is not a merge point in the graph
        self.graph.inactivated_vertices.add(self.id)
        return
    if self.state == VertexStates.ACTIVE and self.id in self.graph.inactivated_vertices:
        self.graph.inactivated_vertices.remove(self.id)
def is_active(self):
    """Return True when the vertex state equals ``VertexStates.ACTIVE``."""
    current = self.state
    return current == VertexStates.ACTIVE
def avg_build_time(self):
    """Return the mean of the recorded build times, or 0 when none were recorded."""
    times = self.build_times
    if not times:
        return 0
    return sum(times) / len(times)
def add_build_time(self, time) -> None:
    """Record one more build duration in ``self.build_times``."""
    recorded = self.build_times
    recorded.append(time)
def set_result(self, result: ResultData) -> None:
    """Store ``result`` as the vertex's ``result`` attribute."""
    setattr(self, "result", result)
def get_built_result(self):
    """Return the value to expose as this vertex's build result.

    For interface components that have been built, the built object itself is
    returned, or its ``content`` attribute when it is neither a dict nor a
    string and has one. Otherwise the stored built result is returned: ``{}``
    while it is still an ``UnbuiltResult``, the dict itself when it is a dict,
    and ``{"result": value}`` for anything else. A string built object is
    first copied into ``self.built_result``.
    """
    # If the Vertex.type is a power component
    # then we need to return the built object
    # instead of the result dict
    if self.is_interface_component and not isinstance(self.built_object, UnbuiltObject):
        built = self.built_object
        if isinstance(built, dict | str) or not hasattr(built, "content"):
            return built
        return built.content

    if isinstance(self.built_object, str):
        self.built_result = self.built_object

    outcome = self.built_result
    if isinstance(outcome, UnbuiltResult):
        return {}
    if isinstance(outcome, dict):
        return outcome
    return {"result": outcome}
def set_artifacts(self) -> None:
    """Do nothing; called by ``finalize_build`` before artifacts are read."""
    return
| def edges(self) -> list[CycleEdge]: | |
| return self.graph.get_vertex_edges(self.id) | |
def outgoing_edges(self) -> list[CycleEdge]:
    """Return the edges whose ``source_id`` is this vertex's id.

    NOTE(review): ``self.edges`` is read as an attribute here; confirm it is
    exposed as a property.
    """
    leaving = []
    for edge in self.edges:
        if edge.source_id == self.id:
            leaving.append(edge)
    return leaving
def incoming_edges(self) -> list[CycleEdge]:
    """Return the edges whose ``target_id`` is this vertex's id.

    NOTE(review): ``self.edges`` is read as an attribute here; confirm it is
    exposed as a property.
    """
    arriving = []
    for edge in self.edges:
        if edge.target_id == self.id:
            arriving.append(edge)
    return arriving
def edges_source_names(self) -> set[str | None]:
    """Return the set of ``source_handle.name`` values across this vertex's edges."""
    names = set()
    for edge in self.edges:
        names.add(edge.source_handle.name)
    return names
def predecessors(self) -> list[Vertex]:
    """Return this vertex's predecessors as reported by ``graph.get_predecessors``."""
    owner = self.graph
    return owner.get_predecessors(self)
def successors(self) -> list[Vertex]:
    """Return this vertex's successors as reported by ``graph.get_successors``."""
    owner = self.graph
    return owner.get_successors(self)
def successors_ids(self) -> list[str]:
    """Return the successor ids in ``graph.successor_map`` for this vertex, or ``[]`` if absent."""
    successor_map = self.graph.successor_map
    return successor_map.get(self.id, [])
def __getstate__(self):
    """Return a picklable copy of the instance dict.

    The lock is replaced with ``None`` and the ``UnbuiltObject`` /
    ``UnbuiltResult`` placeholders are stored as ``None``.
    """
    snapshot = dict(self.__dict__)
    snapshot["_lock"] = None  # Locks are not serializable
    built_object = self.built_object
    snapshot["built_object"] = None if isinstance(built_object, UnbuiltObject) else built_object
    built_result = self.built_result
    snapshot["built_result"] = None if isinstance(built_result, UnbuiltResult) else built_result
    return snapshot
| def __setstate__(self, state): | |
| self.__dict__.update(state) | |
| self._lock = asyncio.Lock() # Reinitialize the lock | |
| self.built_object = state.get("built_object") or UnbuiltObject() | |
| self.built_result = state.get("built_result") or UnbuiltResult() | |
def set_top_level(self, top_level_vertices: list[str]) -> None:
    """Set ``parent_is_top_level`` to whether ``parent_node_id`` is in ``top_level_vertices``."""
    parent_id = self.parent_node_id
    self.parent_is_top_level = parent_id in top_level_vertices
def parse_data(self) -> None:
    """Populate vertex attributes from ``self.full_data["data"]``.

    Sets ``data``, ``display_name``, ``outputs`` (and ``output`` for
    non-"Component" templates), ``icon``, ``description``, ``frozen``,
    ``is_input``, ``is_output``, ``has_session_id``, ``required_inputs``,
    ``optional_inputs`` and ``vertex_type``. When ``base_type`` is still
    ``None`` it is looked up in ``lazy_load_dict.all_types_dict``.

    Raises:
        ValueError: If the template ``_type`` is ``"Component"`` and the node
            has no ``"outputs"`` entry.
    """
    self.data = self.full_data["data"]
    node = self.data["node"]
    template = node["template"]
    id_prefix = self.id.split("-")[0]

    # Assigned before the outputs check: the error message below uses it, and
    # on first use (from __init__) the attribute would not exist yet.
    self.display_name: str = node.get("display_name", id_prefix)

    if template["_type"] == "Component":
        if "outputs" not in node:
            msg = f"Outputs not found for {self.display_name}"
            raise ValueError(msg)
        self.outputs = node["outputs"]
    else:
        self.outputs = node.get("outputs", [])
        self.output = node["base_classes"]

    self.icon: str = node.get("icon", id_prefix)
    self.description: str = node.get("description", "")
    self.frozen: bool = node.get("frozen", False)

    # Explicit flags in the payload win; otherwise keep the current values.
    self.is_input = node.get("is_input") or self.is_input
    self.is_output = node.get("is_output") or self.is_output

    template_dicts = {key: value for key, value in template.items() if isinstance(value, dict)}
    self.has_session_id = "session_id" in template_dicts

    self.required_inputs: list[str] = []
    self.optional_inputs: list[str] = []
    for value_dict in template_dicts.values():
        list_to_append = self.required_inputs if value_dict.get("required") else self.optional_inputs
        if "type" in value_dict:
            list_to_append.append(value_dict["type"])
        if "input_types" in value_dict:
            list_to_append.extend(value_dict["input_types"])

    # The template "_type" is used only when some output has a "Tool" type and
    # the template "_type" is not all lowercase; otherwise data["type"] is used.
    output_types = [type_ for out in self.outputs for type_ in out["types"]]
    self.vertex_type = (
        self.data["type"] if "Tool" not in output_types or template["_type"].islower() else template["_type"]
    )

    if self.base_type is None:
        for base_type, value in lazy_load_dict.all_types_dict.items():
            if self.vertex_type in value:
                self.base_type = base_type
                break
def get_value_from_template_dict(self, key: str):
    """Return the ``"value"`` of template field ``key``.

    Raises:
        ValueError: If ``key`` is not present in the node template.
    """
    template_dict = self.data.get("node", {}).get("template", {})
    if key in template_dict:
        return template_dict[key].get("value")
    msg = f"Key {key} not found in template dict"
    raise ValueError(msg)
def get_task(self):
    """Return a celery ``AsyncResult`` handle for ``self.task_id``.

    celery is imported lazily, so it is only needed when this method is called.
    """
    from celery.result import AsyncResult

    task_handle = AsyncResult(self.task_id)
    return task_handle
| def _set_params_from_normal_edge(self, params: dict, edge: Edge, template_dict: dict): | |
| param_key = edge.target_param | |
| # If the param_key is in the template_dict and the edge.target_id is the current node | |
| # We check this to make sure params with the same name but different target_id | |
| # don't get overwritten | |
| if param_key in template_dict and edge.target_id == self.id: | |
| if template_dict[param_key].get("list"): | |
| if param_key not in params: | |
| params[param_key] = [] | |
| params[param_key].append(self.graph.get_vertex(edge.source_id)) | |
| elif edge.target_id == self.id: | |
| if isinstance(template_dict[param_key].get("value"), dict): | |
| # we don't know the key of the dict but we need to set the value | |
| # to the vertex that is the source of the edge | |
| param_dict = template_dict[param_key]["value"] | |
| if not param_dict or len(param_dict) != 1: | |
| params[param_key] = self.graph.get_vertex(edge.source_id) | |
| else: | |
| params[param_key] = {key: self.graph.get_vertex(edge.source_id) for key in param_dict} | |
| else: | |
| params[param_key] = self.graph.get_vertex(edge.source_id) | |
| return params | |
def build_params(self) -> None:
    """Rebuild ``self.params`` and ``self.raw_params`` from edges and the node template.

    Fields connected by an edge are filled first with their source vertices
    (via ``_set_params_from_normal_edge``). Each remaining template field is
    resolved from the template itself: ``file`` fields become a storage path,
    and fields whose type is in ``DIRECT_TYPES`` are converted according to
    their type (code, dict/NestedDict, int, float, str, bool, table). Optional
    fields that end up ``None`` take their ``default`` or are dropped.

    If ``updated_raw_params`` is set, the flag is cleared and the existing
    params are left untouched.

    Raises:
        ValueError: If the vertex has no graph, or a ``table`` field holds
            something other than a list of dicts.
    """
    # sourcery skip: merge-list-append, remove-redundant-if
    # Some params are required, some are optional
    # but most importantly, some params are python base classes
    # like str and others are LangChain objects like LLMChain, BasePromptTemplate
    # so we need to be able to distinguish between the two
    if self.graph is None:
        msg = "Graph not found"
        raise ValueError(msg)

    if self.updated_raw_params:
        self.updated_raw_params = False
        return

    template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)}
    params: dict = {}

    for edge in self.edges:
        if not hasattr(edge, "target_param"):
            continue
        params = self._set_params_from_normal_edge(params, edge, template_dict)

    load_from_db_fields = []
    for field_name, field in template_dict.items():
        # Edge-provided values take precedence over template values.
        if field_name in params:
            continue
        # Skip _type and any value that has show == False and is not code
        # If we don't want to show code but we want to use it
        if field_name == "_type" or (not field.get("show") and field_name != "code"):
            continue

        field_type = field.get("type")
        if field_type == "file":
            # Load the type in value.get('fileTypes') using
            # what is inside value.get('content')
            # value.get('value') is the file name
            if file_path := field.get("file_path"):
                storage_service = get_storage_service()
                try:
                    flow_id, file_name = os.path.split(file_path)
                    full_path = storage_service.build_full_path(flow_id, file_name)
                except ValueError as e:
                    if "too many values to unpack" in str(e):
                        full_path = file_path
                    else:
                        raise
                params[field_name] = full_path
            elif field.get("required"):
                field_display_name = field.get("display_name")
                logger.warning(
                    f"File path not found for {field_display_name} in component {self.display_name}. "
                    "Setting to None."
                )
                params[field_name] = None
            elif field.get("list"):
                # .get() so a file field without a "list" key does not raise KeyError.
                params[field_name] = []
            else:
                params[field_name] = None

        elif field_type in DIRECT_TYPES and params.get(field_name) is None:
            val = field.get("value")
            if field_type == "code":
                try:
                    if field_name == "code":
                        # The component source itself is passed through verbatim.
                        params[field_name] = val
                    else:
                        params[field_name] = ast.literal_eval(val) if val else None
                except Exception:  # noqa: BLE001
                    logger.debug(f"Error evaluating code for {field_name}")
                    params[field_name] = val
            elif field_type in {"dict", "NestedDict"}:
                # When dict comes from the frontend it comes as a
                # list of dicts, so we need to convert it to a dict
                # before passing it to the build method
                if isinstance(val, list):
                    params[field_name] = {k: v for item in val for k, v in item.items()}
                elif isinstance(val, dict):
                    params[field_name] = val
            elif field_type == "int" and val is not None:
                try:
                    params[field_name] = int(val)
                except ValueError:
                    params[field_name] = val
            elif field_type == "float" and val is not None:
                # Keep the converted value; fall back to the raw value only
                # when the conversion fails.
                try:
                    params[field_name] = float(val)
                except ValueError:
                    params[field_name] = val
            elif field_type == "str" and val is not None:
                # val may contain escaped \n, \t, etc.
                # so we need to unescape it
                if isinstance(val, list):
                    params[field_name] = [unescape_string(v) for v in val]
                elif isinstance(val, str):
                    params[field_name] = unescape_string(val)
                elif isinstance(val, Data):
                    params[field_name] = unescape_string(val.get_text())
            elif field_type == "bool" and val is not None:
                if isinstance(val, bool):
                    params[field_name] = val
                elif isinstance(val, str):
                    # NOTE(review): bool("false") is True — any non-empty string
                    # enables the flag. Confirm this is intended.
                    params[field_name] = bool(val)
            elif field_type == "table" and val is not None:
                # check if the value is a list of dicts
                # if it is, create a pandas dataframe from it
                if isinstance(val, list) and all(isinstance(item, dict) for item in val):
                    params[field_name] = pd.DataFrame(val)
                else:
                    msg = f"Invalid value type {type(val)} for field {field_name}"
                    raise ValueError(msg)
            elif val is not None and val != "":
                params[field_name] = val

            if field.get("load_from_db"):
                load_from_db_fields.append(field_name)

        # Optional fields without a value take their default or are dropped.
        if not field.get("required") and params.get(field_name) is None:
            if field.get("default"):
                params[field_name] = field.get("default")
            else:
                params.pop(field_name, None)

    self.params = params
    self.load_from_db_fields = load_from_db_fields
    self.raw_params = params.copy()
def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False) -> None:
    """Merge ``new_params`` into the raw parameters of the vertex.

    Nothing happens when ``new_params`` is empty, or when any of its keys
    currently maps to a vertex in ``self.raw_params``.

    Args:
        new_params: The parameters to merge. This mapping is never modified.
        overwrite: When False (default), keys not already present in
            ``self.raw_params`` are ignored. When True, all keys are applied.

    After a merge, ``self.params`` becomes a copy of ``self.raw_params`` and
    ``self.updated_raw_params`` is set to True.
    """
    if not new_params:
        return
    # Do not override an input that is fed by another vertex.
    if any(self._is_vertex(self.raw_params.get(key)) for key in new_params):
        return
    if overwrite:
        accepted = dict(new_params)
    else:
        # Filter into a new dict rather than popping from the caller's mapping.
        accepted = {key: value for key, value in new_params.items() if key in self.raw_params}
    self.raw_params.update(accepted)
    self.params = self.raw_params.copy()
    self.updated_raw_params = True
def instantiate_component(self, user_id=None) -> None:
    """Instantiate and attach the component unless one is already attached."""
    if self.custom_component:
        return
    instance, _ = initialize.loading.instantiate_class(
        user_id=user_id,
        vertex=self,
    )
    self.custom_component = instance
async def _build(
    self,
    fallback_to_env_vars,
    user_id=None,
    event_manager: EventManager | None = None,
) -> None:
    """Initiate the build process.

    Resolves vertex-valued params, obtains the component (reusing the attached
    one or instantiating a new one), builds its results, validates the built
    object and marks the vertex as built.

    Raises:
        ValueError: If ``self.base_type`` is ``None``.
    """
    logger.debug(f"Building {self.display_name}")
    await self._build_each_vertex_in_params_dict()

    if self.base_type is None:
        msg = f"Base type for vertex {self.display_name} not found"
        raise ValueError(msg)

    if self.custom_component:
        component = self.custom_component
        # Reused components get the current event manager when they support it.
        if hasattr(self.custom_component, "set_event_manager"):
            self.custom_component.set_event_manager(event_manager)
        component_params = initialize.loading.get_params(self.params)
    else:
        component, component_params = initialize.loading.instantiate_class(
            user_id=user_id, vertex=self, event_manager=event_manager
        )

    await self._build_results(
        custom_component=component,
        custom_params=component_params,
        fallback_to_env_vars=fallback_to_env_vars,
        base_type=self.base_type,
    )
    self._validate_built_object()
    self.built = True
def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dict]:
    """Extract a chat message from the artifacts.

    Args:
        artifacts (Dict[str, Any]): The artifacts to extract messages from.

    Returns:
        List[dict]: A single-item list holding the dumped ``ChatOutputResponse``,
        or an empty list when a ``KeyError`` occurs (e.g. ``"text"`` is missing).
    """
    try:
        text = artifacts["text"]
        sender_name = artifacts.get("sender_name")

        files = []
        for file in artifacts.get("files", []):
            # Bare path strings are wrapped in a dict; other entries pass through.
            files.append({"path": file} if isinstance(file, str) else file)

        if isinstance(sender_name, Data | Message):
            sender_name = sender_name.get_text()

        response = ChatOutputResponse(
            message=text,
            sender=artifacts.get("sender"),
            sender_name=sender_name,
            session_id=artifacts.get("session_id"),
            stream_url=artifacts.get("stream_url"),
            files=files,
            component_id=self.id,
            type=self.artifacts_type,
        )
        return [response.model_dump(exclude_none=True)]
    except KeyError:
        return []
def finalize_build(self) -> None:
    """Assemble the ``ResultData`` for this vertex and store it with ``set_result``."""
    built_result = self.get_built_result()
    # We need to set the artifacts to pass information
    # to the frontend
    self.set_artifacts()
    raw_artifacts = self.artifacts_raw
    if isinstance(raw_artifacts, dict):
        messages = self.extract_messages_from_artifacts(raw_artifacts)
    else:
        messages = []
    self.set_result(
        ResultData(
            results=built_result,
            artifacts=raw_artifacts,
            outputs=self.outputs_logs,
            logs=self.logs,
            messages=messages,
            component_display_name=self.display_name,
            component_id=self.id,
        )
    )
| async def _build_each_vertex_in_params_dict(self) -> None: | |
| """Iterates over each vertex in the params dictionary and builds it.""" | |
| for key, value in self.raw_params.items(): | |
| if self._is_vertex(value): | |
| if value == self: | |
| del self.params[key] | |
| continue | |
| await self._build_vertex_and_update_params( | |
| key, | |
| value, | |
| ) | |
| elif isinstance(value, list) and self._is_list_of_vertices(value): | |
| await self._build_list_of_vertices_and_update_params(key, value) | |
| elif isinstance(value, dict): | |
| await self._build_dict_and_update_params( | |
| key, | |
| value, | |
| ) | |
| elif key not in self.params or self.updated_raw_params: | |
| self.params[key] = value | |
| async def _build_dict_and_update_params( | |
| self, | |
| key, | |
| vertices_dict: dict[str, Vertex], | |
| ) -> None: | |
| """Iterates over a dictionary of vertices, builds each and updates the params dictionary.""" | |
| for sub_key, value in vertices_dict.items(): | |
| if not self._is_vertex(value): | |
| self.params[key][sub_key] = value | |
| else: | |
| result = await value.get_result(self, target_handle_name=key) | |
| self.params[key][sub_key] = result | |
| def _is_vertex(self, value): | |
| """Checks if the provided value is an instance of Vertex.""" | |
| return isinstance(value, Vertex) | |
| def _is_list_of_vertices(self, value): | |
| """Checks if the provided value is a list of Vertex instances.""" | |
| return all(self._is_vertex(vertex) for vertex in value) | |
async def get_result(self, requester: Vertex, target_handle_name: str | None = None) -> Any:
    """Retrieve the result of the vertex while holding the vertex lock.

    This is a read-only method: it delegates to ``_get_result``, which raises
    if the vertex has not been built yet.

    Returns:
        The result of the vertex.
    """
    async with self._lock:
        outcome = await self._get_result(requester, target_handle_name)
    return outcome
def _log_transaction_async(
    self, flow_id: str | UUID, source: Vertex, status, target: Vertex | None = None, error=None
) -> None:
    """Schedule ``log_transaction`` as a background task.

    The task is kept in ``self.log_transaction_tasks`` until it finishes, so a
    reference to it exists while it is running.
    """
    pending_log = asyncio.create_task(log_transaction(flow_id, source, status, target, error))
    tracked = self.log_transaction_tasks
    tracked.add(pending_log)
    pending_log.add_done_callback(tracked.discard)
| async def _get_result( | |
| self, | |
| requester: Vertex, | |
| target_handle_name: str | None = None, # noqa: ARG002 | |
| ) -> Any: | |
| """Retrieves the result of the built component. | |
| If the component has not been built yet, a ValueError is raised. | |
| Returns: | |
| The built result if use_result is True, else the built object. | |
| """ | |
| flow_id = self.graph.flow_id | |
| if not self.built: | |
| if flow_id: | |
| self._log_transaction_async(str(flow_id), source=self, target=requester, status="error") | |
| msg = f"Component {self.display_name} has not been built yet" | |
| raise ValueError(msg) | |
| result = self.built_result if self.use_result else self.built_object | |
| if flow_id: | |
| self._log_transaction_async(str(flow_id), source=self, target=requester, status="success") | |
| return result | |
| async def _build_vertex_and_update_params(self, key, vertex: Vertex) -> None: | |
| """Builds a given vertex and updates the params dictionary accordingly.""" | |
| result = await vertex.get_result(self, target_handle_name=key) | |
| self._handle_func(key, result) | |
| if isinstance(result, list): | |
| self._extend_params_list_with_result(key, result) | |
| self.params[key] = result | |
| async def _build_list_of_vertices_and_update_params( | |
| self, | |
| key, | |
| vertices: list[Vertex], | |
| ) -> None: | |
| """Iterates over a list of vertices, builds each and updates the params dictionary.""" | |
| self.params[key] = [] | |
| for vertex in vertices: | |
| result = await vertex.get_result(self, target_handle_name=key) | |
| # Weird check to see if the params[key] is a list | |
| # because sometimes it is a Data and breaks the code | |
| if not isinstance(self.params[key], list): | |
| self.params[key] = [self.params[key]] | |
| if isinstance(result, list): | |
| self.params[key].extend(result) | |
| else: | |
| try: | |
| if self.params[key] == result: | |
| continue | |
| self.params[key].append(result) | |
| except AttributeError as e: | |
| logger.exception(e) | |
| msg = ( | |
| f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {result}" | |
| f"Error building Component {self.display_name}: \n\n{e}" | |
| ) | |
| raise ValueError(msg) from e | |
def _handle_func(self, key, result) -> None:
    """Register ``params["coroutine"]`` when the ``func`` key resolves to a plain function.

    Only acts when ``key == "func"``. A ``types.FunctionType`` result is stored
    under ``self.params["coroutine"]``, wrapped with ``sync_to_async`` unless
    it is already a coroutine function.
    """
    # NOTE(review): block nesting reconstructed from a copy of the file that
    # lost its indentation — confirm which `if` the last `elif`/`else` belong to.
    if key == "func":
        if not isinstance(result, types.FunctionType):
            # NOTE(review): `result` is only rebound locally in this branch and
            # is neither stored nor returned, so it has no visible effect here.
            if hasattr(result, "run"):
                result = result.run
            elif hasattr(result, "get_function"):
                result = result.get_function()
        elif inspect.iscoroutinefunction(result):
            self.params["coroutine"] = result
        else:
            self.params["coroutine"] = sync_to_async(result)
| def _extend_params_list_with_result(self, key, result) -> None: | |
| """Extends a list in the params dictionary with the given result if it exists.""" | |
| if isinstance(self.params[key], list): | |
| self.params[key].extend(result) | |
async def _build_results(
    self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False
) -> None:
    """Run the component and record its outputs, built object and artifacts.

    Raises:
        ComponentBuildError: Wraps any exception raised while building; the
            formatted traceback is passed along.
    """
    build_kwargs = {
        "custom_component": custom_component,
        "custom_params": custom_params,
        "vertex": self,
        "fallback_to_env_vars": fallback_to_env_vars,
        "base_type": base_type,
    }
    try:
        built = await initialize.loading.get_instance_results(**build_kwargs)
        self.outputs_logs = build_output_logs(self, built)
        self._update_built_object_and_artifacts(built)
    except Exception as exc:
        tb = traceback.format_exc()
        logger.exception(exc)
        msg = f"Error building Component {self.display_name}: \n\n{exc}"
        raise ComponentBuildError(msg, tb) from exc
def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]) -> None:
    """Updates the built object and its artifacts.

    Args:
        result: Either the built object alone, a ``(built_object, artifacts)``
            pair, or a ``(component, built_object, artifacts)`` triple.
    """
    # NOTE(review): nesting reconstructed from a copy of the file that lost its
    # indentation. The logs/artifacts post-processing is placed inside the
    # 3-tuple branch — confirm against the original.
    if isinstance(result, tuple):
        if len(result) == 2:  # noqa: PLR2004
            self.built_object, self.artifacts = result
        elif len(result) == 3:  # noqa: PLR2004
            self.custom_component, self.built_object, self.artifacts = result
            self.logs = self.custom_component._output_logs
            self.artifacts_raw = self.artifacts.get("raw", None)
            # Artifacts and their type are keyed by the name of the first output.
            self.artifacts_type = {
                self.outputs[0]["name"]: self.artifacts.get("type", None) or ArtifactType.UNKNOWN.value
            }
            self.artifacts = {self.outputs[0]["name"]: self.artifacts}
    else:
        self.built_object = result
def _validate_built_object(self) -> None:
    """Validate the built object after a build.

    A ``None`` built object only produces a warning.

    Raises:
        TypeError: If the built object is still an ``UnbuiltObject``.
        ValueError: If the built object is an iterator or async iterator and
            the vertex display name is "Text Output".
    """
    if isinstance(self.built_object, UnbuiltObject):
        msg = f"{self.display_name}: {self.built_object_repr()}"
        raise TypeError(msg)

    built = self.built_object
    if built is None:
        message = f"{self.display_name} returned None."
        if self.base_type == "custom_components":
            message += " Make sure your build method returns a component."
        logger.warning(message)
        return

    if isinstance(built, Iterator | AsyncIterator) and self.display_name == "Text Output":
        msg = f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead."
        raise ValueError(msg)
def _reset(self) -> None:
    """Clear the build state and rebuild the params for a fresh build."""
    self.built = False
    self.built_object, self.built_result = UnbuiltObject(), UnbuiltResult()
    self.artifacts = {}
    self.steps_ran = []
    self.build_params()
| def _is_chat_input(self) -> bool: | |
| return False | |
def build_inactive(self) -> None:
    """Mark the vertex as built with ``None`` for both built object and built result."""
    # Just set the results to None
    self.built = True
    self.built_object = self.built_result = None
async def build(
    self,
    user_id=None,
    inputs: dict[str, Any] | None = None,
    files: list[str] | None = None,
    requester: Vertex | None = None,
    event_manager: EventManager | None = None,
    **kwargs,
) -> Any:
    """Build the vertex (or reuse a previous build) and return the requester's result.

    Args:
        user_id: Forwarded to every build step.
        inputs: Optional run inputs; the ``"session"`` and ``"input_value"``
            entries are read here.
        files: Optional files, passed to chat-input vertices as the ``files`` param.
        requester: Vertex asking for the result; ``None`` returns the built object.
        event_manager: Forwarded to every build step.
        **kwargs: Forwarded to every build step.

    Returns:
        ``None`` for an inactive vertex, otherwise the value of
        ``get_requester_result(requester)``.
    """
    # NOTE(review): nesting reconstructed from a copy of the file that lost its
    # indentation. The final return is placed outside the lock — confirm
    # against the original.
    async with self._lock:
        if self.state == VertexStates.INACTIVE:
            # If the vertex is inactive, return None
            self.build_inactive()
            return None

        # A frozen vertex that is already built is never rebuilt.
        if self.frozen and self.built:
            return await self.get_requester_result(requester)
        if self.built and requester is not None:
            # This means that the vertex has already been built
            # and we are just getting the result for the requester
            return await self.get_requester_result(requester)
        self._reset()
        # inject session_id if it is not None
        if inputs is not None and "session" in inputs and inputs["session"] is not None and self.has_session_id:
            session_id_value = self.get_value_from_template_dict("session_id")
            # Only fill in the session when the template value is an empty string.
            if session_id_value == "":
                self.update_raw_params({"session_id": inputs["session"]}, overwrite=True)
        if self._is_chat_input() and (inputs or files):
            chat_input = {}
            if (
                inputs
                and isinstance(inputs, dict)
                and "input_value" in inputs
                and inputs.get("input_value") is not None
            ):
                chat_input.update({"input_value": inputs.get(INPUT_FIELD_NAME, "")})
            if files:
                chat_input.update({"files": files})

            self.update_raw_params(chat_input, overwrite=True)

        # Run steps
        for step in self.steps:
            if step not in self.steps_ran:
                await step(user_id=user_id, event_manager=event_manager, **kwargs)
                self.steps_ran.append(step)

        self.finalize_build()

    return await self.get_requester_result(requester)
async def get_requester_result(self, requester: Vertex | None):
    """Return what ``requester`` should receive from this vertex.

    With no requester the built object is returned. Otherwise the first edge
    targeting the requester supplies the result through
    ``get_result_from_source``; ``None`` is returned when no such edge exists.
    """
    # If the requester is None, this means that
    # the Vertex is the root of the graph
    if requester is None:
        return self.built_object

    # Get the requester edge and return its result
    for edge in self.edges:
        if edge.target_id == requester.id:
            return await edge.get_result_from_source(source=self, target=requester)
    return None
def add_edge(self, edge: CycleEdge) -> None:
    """Append ``edge`` to this vertex's edges unless it is already present."""
    if edge in self.edges:
        return
    self.edges.append(edge)
| def __repr__(self) -> str: | |
| return f"Vertex(display_name={self.display_name}, id={self.id}, data={self.data})" | |
| def __eq__(self, /, other: object) -> bool: | |
| try: | |
| if not isinstance(other, Vertex): | |
| return False | |
| # We should create a more robust comparison | |
| # for the Vertex class | |
| ids_are_equal = self.id == other.id | |
| # self.data is a dict and we need to compare them | |
| # to check if they are equal | |
| data_are_equal = self.data == other.data | |
| except AttributeError: | |
| return False | |
| else: | |
| return ids_are_equal and data_are_equal | |
| def __hash__(self) -> int: | |
| return id(self) | |
def built_object_repr(self) -> str:
    """Return a short status message describing whether a built object is present."""
    # Add a message with an emoji, stars for success,
    if self.built_object is None:
        return "Failed to build 😵💫"
    return "Built successfully ✨"
def apply_on_outputs(self, func: Callable[[Any], Any]) -> None:
    """Call ``func`` on every output of the attached component.

    Does nothing when there is no component or the component has no outputs.
    """
    component = self.custom_component
    if not component or not component.outputs:
        return
    # Apply the function to each output
    for output in component._outputs_map.values():
        func(output)