Spaces:
Configuration error
Configuration error
| from typing import Any, Dict, Optional, Set, Union | |
| import numpy as np | |
| from networkx import DiGraph | |
| from inference.core.utils.image_utils import ImageType | |
| from inference.enterprise.workflows.complier.steps_executors.constants import ( | |
| IMAGE_TYPE_KEY, | |
| IMAGE_VALUE_KEY, | |
| PARENT_ID_KEY, | |
| ) | |
| from inference.enterprise.workflows.complier.utils import ( | |
| get_nodes_of_specific_kind, | |
| is_input_selector, | |
| ) | |
| from inference.enterprise.workflows.constants import INPUT_NODE_KIND, STEP_NODE_KIND | |
| from inference.enterprise.workflows.entities.validators import get_last_selector_chunk | |
| from inference.enterprise.workflows.errors import ( | |
| InvalidStepInputDetected, | |
| RuntimeParameterMissingError, | |
| ) | |
def prepare_runtime_parameters(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Validate and normalise the runtime parameters of a workflow execution.

    Steps, in order:
      1. fail if an input without a default value was not supplied;
      2. overlay the supplied values on the declared defaults (new dict);
      3. wrap every ``InferenceImage`` input into a list of image dicts;
      4. validate each step's selector-bound inputs against its definition.

    :param execution_graph: compiled workflow graph whose nodes carry a
        ``"definition"`` attribute.
    :param runtime_parameters: values supplied by the caller, keyed by input name.
    :return: the prepared parameter mapping.
    :raises RuntimeParameterMissingError: if a required input is missing.
    :raises InvalidStepInputDetected: if an image input has an unsupported type.
    Validation failures from a step's ``validate_field_binding`` also propagate.
    """
    ensure_all_parameters_filled(
        execution_graph=execution_graph,
        runtime_parameters=runtime_parameters,
    )
    # Defaults are merged first (argument evaluation), then images are wrapped.
    prepared = assembly_input_images(
        execution_graph=execution_graph,
        runtime_parameters=fill_runtime_parameters_with_defaults(
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
        ),
    )
    validate_inputs_binding(
        execution_graph=execution_graph,
        runtime_parameters=prepared,
    )
    return prepared
def ensure_all_parameters_filled(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Raise if any workflow input that has no default value was not supplied.

    :param execution_graph: compiled workflow graph.
    :param runtime_parameters: values supplied by the caller, keyed by input name.
    :raises RuntimeParameterMissingError: listing every missing required input.
    """
    required_parameters = get_input_parameters_without_default_values(
        execution_graph=execution_graph,
    )
    # ``required_parameters`` is a set; its iteration order for strings varies
    # between interpreter runs (hash randomisation). Sorting keeps the error
    # message stable and reproducible.
    missing_parameters = sorted(
        name for name in required_parameters if name not in runtime_parameters
    )
    if missing_parameters:
        raise RuntimeParameterMissingError(
            f"Parameters passed to execution runtime do not define required inputs: {missing_parameters}"
        )
def get_input_parameters_without_default_values(execution_graph: DiGraph) -> Set[str]:
    """Collect the names of workflow inputs the caller must always supply.

    Every ``InferenceImage`` input is required. An ``InferenceParameter`` input
    is required only when its ``default_value`` is ``None``.
    """
    definitions = (
        execution_graph.nodes[node]["definition"]
        for node in get_nodes_of_specific_kind(
            execution_graph=execution_graph,
            kind=INPUT_NODE_KIND,
        )
    )
    return {
        definition.name
        for definition in definitions
        if definition.type == "InferenceImage"
        or (
            definition.type == "InferenceParameter"
            and definition.default_value is None
        )
    }
def fill_runtime_parameters_with_defaults(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Return the declared input defaults overlaid with the supplied values.

    Supplied values take precedence over defaults. A fresh dict is returned, so
    ``runtime_parameters`` itself is not modified. The values are shared with
    the input mapping, not copied.
    """
    defaults = get_input_parameters_default_values(execution_graph=execution_graph)
    return {**defaults, **runtime_parameters}
def get_input_parameters_default_values(execution_graph: DiGraph) -> Dict[str, Any]:
    """Map each ``InferenceParameter`` input name to its default value.

    Inputs whose ``default_value`` is ``None`` are omitted, as are inputs of
    any other type.
    """
    definitions = [
        execution_graph.nodes[node]["definition"]
        for node in get_nodes_of_specific_kind(
            execution_graph=execution_graph,
            kind=INPUT_NODE_KIND,
        )
    ]
    return {
        definition.name: definition.default_value
        for definition in definitions
        if definition.type == "InferenceParameter"
        and definition.default_value is not None
    }
def assembly_input_images(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> Dict[str, Any]:
    """Replace every ``InferenceImage`` input value with a list of image dicts.

    A ``list`` value is treated as a batch: each element is wrapped with its
    position as the identifier. Any other value is wrapped alone, without an
    identifier, inside a one-element list. ``runtime_parameters`` is updated
    in place and then returned.
    """
    input_nodes = get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=INPUT_NODE_KIND,
    )
    for node in input_nodes:
        definition = execution_graph.nodes[node]["definition"]
        if definition.type == "InferenceImage":
            supplied = runtime_parameters[definition.name]
            if isinstance(supplied, list):
                wrapped = [
                    assembly_input_image(parameter=node, image=item, identifier=index)
                    for index, item in enumerate(supplied)
                ]
            else:
                wrapped = [assembly_input_image(parameter=node, image=supplied)]
            runtime_parameters[definition.name] = wrapped
    return runtime_parameters
def assembly_input_image(
    parameter: str, image: Any, identifier: Optional[int] = None
) -> Dict[str, Union[str, np.ndarray]]:
    """Wrap a single runtime image into the dict form used by workflow steps.

    :param parameter: name of the input node the image came from; used as the
        parent id.
    :param image: either an image dict or a ``np.ndarray``.
    :param identifier: position of the image within a batch input, if any. When
        given, the parent id becomes ``"<parameter>.[<identifier>]"``.
    :return: an image dict that carries ``PARENT_ID_KEY``.
    :raises InvalidStepInputDetected: if ``image`` is neither a dict nor an ndarray.
    """
    parent = parameter if identifier is None else f"{parameter}.[{identifier}]"
    if isinstance(image, dict):
        # Return a shallow copy instead of writing into the caller's dict. The
        # same dict object may appear several times in a batch list; mutating it
        # in place would give every occurrence the last parent id and leak the
        # key back to the caller.
        return {**image, PARENT_ID_KEY: parent}
    if isinstance(image, np.ndarray):
        return {
            IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value,
            IMAGE_VALUE_KEY: image,
            PARENT_ID_KEY: parent,
        }
    raise InvalidStepInputDetected(
        f"Detected runtime parameter `{parameter}` defined as `InferenceImage` with type {type(image)} that is invalid."
    )
def validate_inputs_binding(
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validate the selector-bound inputs of every step node in the graph.

    Delegates to ``validate_step_input_bindings`` for each step node, in the
    order ``get_nodes_of_specific_kind`` yields them.
    """
    for step_node in get_nodes_of_specific_kind(
        execution_graph=execution_graph,
        kind=STEP_NODE_KIND,
    ):
        validate_step_input_bindings(
            step=step_node,
            execution_graph=execution_graph,
            runtime_parameters=runtime_parameters,
        )
def validate_step_input_bindings(
    step: str,
    execution_graph: DiGraph,
    runtime_parameters: Dict[str, Any],
) -> None:
    """Validate each input of ``step`` that is bound to a workflow-input selector.

    Inputs holding literal values (not input selectors) are skipped. For a
    selector, ``get_last_selector_chunk`` gives the runtime parameter name; the
    resolved value is passed to the step definition's
    ``validate_field_binding``. A selector naming an absent parameter raises
    ``KeyError``.
    """
    definition = execution_graph.nodes[step]["definition"]
    for field_name in definition.get_input_names():
        bound = getattr(definition, field_name)
        if is_input_selector(selector_or_value=bound):
            parameter_name = get_last_selector_chunk(selector=bound)
            value = runtime_parameters[parameter_name]
            definition.validate_field_binding(field_name=field_name, value=value)