| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
|
|
| |
| |
| |
| |
| |
| |
@dataclass(frozen=True)
class SelectiveBuildOperator:
    """Immutable record describing one operator entry of a selective build.

    Instances are created from a YAML-style mapping with ``from_yaml_dict``
    or from a bare operator name with
    ``from_legacy_operator_name_without_overload``, and are turned back into
    a plain mapping with ``to_dict``.
    """

    # Operator name; ``from_yaml_dict`` takes it from its ``op_name`` argument.
    name: str

    # Value of the "is_root_operator" key (True when the key is absent).
    is_root_operator: bool

    # Value of the "is_used_for_training" key (True when the key is absent).
    is_used_for_training: bool

    # Value of the "include_all_overloads" key (True when the key is absent).
    include_all_overloads: bool

    # Stringified entries of the "debug_info" key, or None when it is absent.
    _debug_info: tuple[str, ...] | None

    @staticmethod
    def _read_bool_flag(op_info: dict[str, object], key: str) -> bool:
        """Return ``op_info[key]`` (default True), insisting that it is a bool.

        Raises:
            AssertionError: If the stored value is not a ``bool``.
        """
        value = op_info.get(key, True)
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        if not isinstance(value, bool):
            raise AssertionError(
                f"Expected '{key}' to be a bool, but got {value!r}"
            )
        return value

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: dict[str, object]
    ) -> SelectiveBuildOperator:
        """Build an operator from its name and its attribute mapping.

        Args:
            op_name: Name given to the resulting operator.
            op_info: Attribute mapping. Recognized keys are "name",
                "is_root_operator", "is_used_for_training",
                "include_all_overloads" and "debug_info". The three boolean
                flags default to True when missing. "debug_info" may be a
                list or a tuple (the form ``to_dict`` emits); its items are
                converted with ``str``.

        Returns:
            The operator described by ``op_name`` and ``op_info``.

        Raises:
            Exception: If ``op_info`` has a key outside the recognized set.
            AssertionError: If "name" is present but differs from
                ``op_name``, if a flag is not a ``bool``, or if "debug_info"
                is neither a list nor a tuple.
        """
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }

        unexpected_keys = set(op_info) - allowed_keys
        if unexpected_keys:
            # Sorted so the message is identical from run to run.
            raise Exception(
                "Got unexpected top level keys: {}".format(
                    ",".join(sorted(map(str, unexpected_keys))),
                )
            )

        if "name" in op_info and op_name != op_info["name"]:
            raise AssertionError(
                f"Operator name mismatch: '{op_name}' vs '{op_info['name']}'"
            )

        is_root_operator = SelectiveBuildOperator._read_bool_flag(
            op_info, "is_root_operator"
        )
        is_used_for_training = SelectiveBuildOperator._read_bool_flag(
            op_info, "is_used_for_training"
        )
        include_all_overloads = SelectiveBuildOperator._read_bool_flag(
            op_info, "include_all_overloads"
        )

        debug_info: tuple[str, ...] | None = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            # Tuples are accepted as well as lists so that the mapping
            # produced by ``to_dict`` can be fed straight back in.
            if not isinstance(di_list, (list, tuple)):
                raise AssertionError(
                    f"Expected 'debug_info' to be a list or tuple, but got {di_list!r}"
                )
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> SelectiveBuildOperator:
        """Build an operator from a bare name with every flag set to True.

        Args:
            name: Name given to the resulting operator.

        Returns:
            An operator with all three flags True and no debug info.
        """
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> dict[str, object]:
        """Return the operator's attributes as a plain mapping.

        The name is not part of the result. "debug_info" is included only
        when debug info is present, and holds the stored tuple.
        """
        ret: dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret
|
|
|
|
def merge_debug_info(
    lhs: tuple[str, ...] | None,
    rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
    """Combine two optional debug-info tuples into one without duplicates.

    Args:
        lhs: First tuple of debug strings, or None.
        rhs: Second tuple of debug strings, or None.

    Returns:
        None when both inputs are None; otherwise the distinct strings of
        both inputs in sorted order (a None input counts as empty).
    """
    if lhs is None and rhs is None:
        return None

    # Sorting makes the result independent of set iteration order, which
    # otherwise changes between interpreter runs, and of argument order.
    return tuple(sorted(set((lhs or ()) + (rhs or ()))))
|
|
|
|
def combine_operators(
    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
    """Merge two records of the same operator into a single record.

    Args:
        lhs: First operator record.
        rhs: Second operator record; must carry the same name as ``lhs``.

    Returns:
        A new operator named after ``lhs`` in which each boolean flag is the
        ``or`` of the two inputs' flags and the debug info is the result of
        ``merge_debug_info`` on both inputs.

    Raises:
        Exception: If the two operators have different names.
    """
    if str(lhs.name) != str(rhs.name):
        raise Exception(
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
        )

    # A flag is set on the result as soon as either input has it set.
    root_operator = lhs.is_root_operator or rhs.is_root_operator
    used_for_training = lhs.is_used_for_training or rhs.is_used_for_training
    all_overloads = lhs.include_all_overloads or rhs.include_all_overloads
    debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)

    return SelectiveBuildOperator(
        name=lhs.name,
        is_root_operator=root_operator,
        is_used_for_training=used_for_training,
        include_all_overloads=all_overloads,
        _debug_info=debug_info,
    )
|
|
|
|
def merge_operator_dicts(
    lhs: dict[str, SelectiveBuildOperator],
    rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
    """Merge two name-to-operator mappings into a new mapping.

    Args:
        lhs: First mapping; its entries come first in the result.
        rhs: Second mapping.

    Returns:
        A new dict holding every key of both inputs. A key present in both
        maps to ``combine_operators(lhs_value, rhs_value)``; any other key
        keeps its original value. Neither input is modified.
    """
    # Start from a shallow copy of lhs, then fold the rhs entries into it.
    merged: dict[str, SelectiveBuildOperator] = dict(lhs)
    for op_name, incoming in rhs.items():
        if op_name in merged:
            merged[op_name] = combine_operators(merged[op_name], incoming)
        else:
            merged[op_name] = incoming

    return merged
|
|
|
|
def strip_operator_overload_name(op_name: str) -> str:
    """Return the part of ``op_name`` that precedes its first ".".

    A name without any "." is returned unchanged.
    """
    base_name, _, _ = op_name.partition(".")
    return base_name
|
|