Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/example_parser_configuration.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/function_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/is_in_graph_mode.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/keras_deps.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/keyword_args.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/lazy_loader.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/lock_util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/module_wrapper.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/nest.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/nest_util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/numpy_compat.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/object_identity.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/serialization.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_contextlib.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_decorator.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_decorator_export.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_export.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_inspect.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_should_use.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_stack.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/traceback_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/type_annotations.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/variable_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_checkpoint_reader.pyi +25 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_checkpoint_reader.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_determinism.pyi +17 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_determinism.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_kernel_registry.pyi +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_kernel_registry.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_nest.pyi +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_nest.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_stat_summarizer.pyi +26 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_tensor_float_32_execution.pyi +17 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_tfprof.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_transform_graph.pyi +16 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_transform_graph.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_utils.pyi +35 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_tf_stack.pyi +64 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_tf_stack.so +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/all_util.py +117 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/compat.py +226 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/custom_nest_protocol.py +120 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/decorator_utils.py +203 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecated_module.py +24 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecated_module_new.py +22 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py +763 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/dispatch.py +1302 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/example_parser_configuration.py +206 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/fast_module_type.pyi +16 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/example_parser_configuration.cpython-310.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/function_utils.cpython-310.pyc
ADDED
|
Binary file (3.36 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/is_in_graph_mode.cpython-310.pyc
ADDED
|
Binary file (401 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/keras_deps.cpython-310.pyc
ADDED
|
Binary file (2.12 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/keyword_args.cpython-310.pyc
ADDED
|
Binary file (1.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/lazy_loader.cpython-310.pyc
ADDED
|
Binary file (5.84 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/lock_util.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/module_wrapper.cpython-310.pyc
ADDED
|
Binary file (7.54 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/nest.cpython-310.pyc
ADDED
|
Binary file (47.2 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/nest_util.cpython-310.pyc
ADDED
|
Binary file (52.5 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/numpy_compat.cpython-310.pyc
ADDED
|
Binary file (5.54 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/object_identity.cpython-310.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/serialization.cpython-310.pyc
ADDED
|
Binary file (1.52 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_contextlib.cpython-310.pyc
ADDED
|
Binary file (1.02 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_decorator.cpython-310.pyc
ADDED
|
Binary file (9.48 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_decorator_export.cpython-310.pyc
ADDED
|
Binary file (528 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_export.cpython-310.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_inspect.cpython-310.pyc
ADDED
|
Binary file (11.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_should_use.cpython-310.pyc
ADDED
|
Binary file (8.52 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/tf_stack.cpython-310.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/traceback_utils.cpython-310.pyc
ADDED
|
Binary file (5.22 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/type_annotations.cpython-310.pyc
ADDED
|
Binary file (1.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/__pycache__/variable_utils.cpython-310.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_checkpoint_reader.pyi
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
class CheckpointReader:
|
| 19 |
+
def __init__(self, arg0: str) -> None: ...
|
| 20 |
+
@classmethod
|
| 21 |
+
def CheckpointReader_GetTensor(cls, arg0: CheckpointReader, arg1: str) -> object: ...
|
| 22 |
+
def _GetVariableToDataTypeMap(self, *args, **kwargs) -> Any: ...
|
| 23 |
+
def _HasTensor(self, arg0: str) -> bool: ...
|
| 24 |
+
def debug_string(self) -> bytes: ...
|
| 25 |
+
def get_variable_to_shape_map(self, *args, **kwargs) -> Any: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_checkpoint_reader.so
ADDED
|
Binary file (344 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_determinism.pyi
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def enable(arg0: bool) -> None: ...
|
| 17 |
+
def is_enabled() -> bool: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_determinism.so
ADDED
|
Binary file (143 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_kernel_registry.pyi
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def TryFindKernelClass(arg0: str) -> bytes: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_kernel_registry.so
ADDED
|
Binary file (198 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_nest.pyi
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def FlattenDictItems(arg0: object) -> object: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_nest.so
ADDED
|
Binary file (141 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_stat_summarizer.pyi
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
from typing import overload
|
| 17 |
+
|
| 18 |
+
class StatSummarizer:
|
| 19 |
+
@overload
|
| 20 |
+
def __init__(self, arg0: str) -> None: ...
|
| 21 |
+
@overload
|
| 22 |
+
def __init__(self) -> None: ...
|
| 23 |
+
def GetOutputString(self) -> str: ...
|
| 24 |
+
def PrintStepStats(self) -> None: ...
|
| 25 |
+
def ProcessStepStats(self, arg0) -> None: ...
|
| 26 |
+
def ProcessStepStatsStr(self, arg0: str) -> None: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_tensor_float_32_execution.pyi
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def enable(arg0: bool) -> None: ...
|
| 17 |
+
def is_enabled() -> bool: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_tfprof.so
ADDED
|
Binary file (230 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_transform_graph.pyi
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def TransformGraphWithStringInputs(arg0: object, arg1: object, arg2: object, arg3: object) -> bytes: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_transform_graph.so
ADDED
|
Binary file (271 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_pywrap_utils.pyi
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def AssertSameStructure(arg0: object, arg1: object, arg2: bool, arg3: bool) -> bool: ...
|
| 17 |
+
def AssertSameStructureForData(arg0: object, arg1: object, arg2: bool) -> bool: ...
|
| 18 |
+
def Flatten(arg0: object, arg1: bool) -> object: ...
|
| 19 |
+
def FlattenForData(arg0: object) -> object: ...
|
| 20 |
+
def IsAttrs(arg0: object) -> bool: ...
|
| 21 |
+
def IsCompositeTensor(arg0: object) -> bool: ...
|
| 22 |
+
def IsDataTypeSupportedByOneDNNOnThisCPU(arg0) -> bool: ...
|
| 23 |
+
def IsMapping(arg0: object) -> bool: ...
|
| 24 |
+
def IsMappingView(arg0: object) -> bool: ...
|
| 25 |
+
def IsMutableMapping(arg0: object) -> bool: ...
|
| 26 |
+
def IsNamedtuple(arg0: object, arg1: bool) -> object: ...
|
| 27 |
+
def IsNested(arg0: object) -> bool: ...
|
| 28 |
+
def IsNestedForData(arg0: object) -> bool: ...
|
| 29 |
+
def IsNestedOrComposite(arg0: object) -> bool: ...
|
| 30 |
+
def IsResourceVariable(arg0: object) -> bool: ...
|
| 31 |
+
def IsTensor(arg0: object) -> bool: ...
|
| 32 |
+
def IsTypeSpec(arg0: object) -> bool: ...
|
| 33 |
+
def IsVariable(arg0: object) -> bool: ...
|
| 34 |
+
def RegisterPyObject(arg0: object, arg1: object) -> object: ...
|
| 35 |
+
def SameNamedtuples(arg0: object, arg1: object) -> object: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_tf_stack.pyi
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
from typing import Iterator
|
| 17 |
+
|
| 18 |
+
from typing import overload
|
| 19 |
+
|
| 20 |
+
class GraphDebugInfoBuilder:
|
| 21 |
+
def __init__(self) -> None: ...
|
| 22 |
+
def AccumulateStackTrace(self, function: str, op: str, trace) -> None: ...
|
| 23 |
+
def AppendGraphDebugInfo(self, prefix: str, debug_info: bytes) -> None: ...
|
| 24 |
+
def Build(self) -> bytes: ...
|
| 25 |
+
|
| 26 |
+
class PyBindFileSet:
|
| 27 |
+
def __init__(self) -> None: ...
|
| 28 |
+
def update_to(self, arg0: set) -> None: ...
|
| 29 |
+
|
| 30 |
+
class PyBindSourceMap:
|
| 31 |
+
def __init__(self) -> None: ...
|
| 32 |
+
def update_to(self, arg0: tuple) -> None: ...
|
| 33 |
+
|
| 34 |
+
class StackFrame:
|
| 35 |
+
def __init__(self, *args, **kwargs) -> None: ...
|
| 36 |
+
def __eq__(self, arg0: StackFrame) -> bool: ...
|
| 37 |
+
def __getitem__(self, arg0: object) -> object: ...
|
| 38 |
+
def __hash__(self) -> int: ...
|
| 39 |
+
def __iter__(self) -> Iterator: ...
|
| 40 |
+
def __len__(self) -> int: ...
|
| 41 |
+
def __ne__(self, arg0: StackFrame) -> bool: ...
|
| 42 |
+
@property
|
| 43 |
+
def filename(self) -> str: ...
|
| 44 |
+
@property
|
| 45 |
+
def line(self) -> str: ...
|
| 46 |
+
@property
|
| 47 |
+
def lineno(self) -> int: ...
|
| 48 |
+
@property
|
| 49 |
+
def name(self) -> str: ...
|
| 50 |
+
|
| 51 |
+
class StackTrace:
|
| 52 |
+
def __init__(self, *args, **kwargs) -> None: ...
|
| 53 |
+
def get_user_frames(self) -> StackTrace: ...
|
| 54 |
+
def last_user_frame(self) -> StackFrame: ...
|
| 55 |
+
def __eq__(self, arg0: StackTrace) -> bool: ...
|
| 56 |
+
@overload
|
| 57 |
+
def __getitem__(self, arg0: int) -> StackFrame: ...
|
| 58 |
+
@overload
|
| 59 |
+
def __getitem__(self, arg0: slice) -> StackTrace: ...
|
| 60 |
+
def __hash__(self) -> int: ...
|
| 61 |
+
def __len__(self) -> int: ...
|
| 62 |
+
|
| 63 |
+
def LoadTracesFromDebugInfo(debug_info_proto: bytes) -> dict[str,StackTrace]: ...
|
| 64 |
+
def extract_stack(source_map: PyBindSourceMap, file_set: PyBindFileSet, stacklevel: int = ...) -> StackTrace: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/_tf_stack.so
ADDED
|
Binary file (655 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/all_util.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Generate __all__ from a module docstring."""
|
| 17 |
+
import re as _re
|
| 18 |
+
import sys as _sys
|
| 19 |
+
|
| 20 |
+
from tensorflow.python.util import tf_inspect as _tf_inspect
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
_reference_pattern = _re.compile(r'^@@(\w+)$', flags=_re.MULTILINE)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def make_all(module_name, doc_string_modules=None):
|
| 27 |
+
"""Generates `__all__` from the docstring of one or more modules.
|
| 28 |
+
|
| 29 |
+
Usage: `make_all(__name__)` or
|
| 30 |
+
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
|
| 31 |
+
modules must each a docstring, and `__all__` will contain all symbols with
|
| 32 |
+
`@@` references, where that symbol currently exists in the module named
|
| 33 |
+
`module_name`.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
module_name: The name of the module (usually `__name__`).
|
| 37 |
+
doc_string_modules: a list of modules from which to take docstring.
|
| 38 |
+
If None, then a list containing only the module named `module_name` is used.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A list suitable for use as `__all__`.
|
| 42 |
+
"""
|
| 43 |
+
if doc_string_modules is None:
|
| 44 |
+
doc_string_modules = [_sys.modules[module_name]]
|
| 45 |
+
cur_members = set(
|
| 46 |
+
name for name, _ in _tf_inspect.getmembers(_sys.modules[module_name]))
|
| 47 |
+
|
| 48 |
+
results = set()
|
| 49 |
+
for doc_module in doc_string_modules:
|
| 50 |
+
results.update([m.group(1)
|
| 51 |
+
for m in _reference_pattern.finditer(doc_module.__doc__)
|
| 52 |
+
if m.group(1) in cur_members])
|
| 53 |
+
return list(results)
|
| 54 |
+
|
| 55 |
+
# Hidden attributes are attributes that have been hidden by
|
| 56 |
+
# `remove_undocumented`. They can be re-instated by `reveal_undocumented`.
|
| 57 |
+
# This maps symbol names to a tuple, containing:
|
| 58 |
+
# (module object, attribute value)
|
| 59 |
+
_HIDDEN_ATTRIBUTES = {}
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def reveal_undocumented(symbol_name, target_module=None):
|
| 63 |
+
"""Reveals a symbol that was previously removed by `remove_undocumented`.
|
| 64 |
+
|
| 65 |
+
This should be used by tensorflow internal tests only. It explicitly
|
| 66 |
+
defeats the encapsulation afforded by `remove_undocumented`.
|
| 67 |
+
|
| 68 |
+
It throws an exception when the symbol was not hidden in the first place.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
symbol_name: a string representing the full absolute path of the symbol.
|
| 72 |
+
target_module: if specified, the module in which to restore the symbol.
|
| 73 |
+
"""
|
| 74 |
+
if symbol_name not in _HIDDEN_ATTRIBUTES:
|
| 75 |
+
raise LookupError('Symbol %s is not a hidden symbol' % symbol_name)
|
| 76 |
+
symbol_basename = symbol_name.split('.')[-1]
|
| 77 |
+
(original_module, attr_value) = _HIDDEN_ATTRIBUTES[symbol_name]
|
| 78 |
+
if not target_module: target_module = original_module
|
| 79 |
+
setattr(target_module, symbol_basename, attr_value)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def remove_undocumented(module_name, allowed_exception_list=None,
|
| 83 |
+
doc_string_modules=None):
|
| 84 |
+
"""Removes symbols in a module that are not referenced by a docstring.
|
| 85 |
+
|
| 86 |
+
Args:
|
| 87 |
+
module_name: the name of the module (usually `__name__`).
|
| 88 |
+
allowed_exception_list: a list of names that should not be removed.
|
| 89 |
+
doc_string_modules: a list of modules from which to take the docstrings.
|
| 90 |
+
If None, then a list containing only the module named `module_name` is used.
|
| 91 |
+
|
| 92 |
+
Furthermore, if a symbol previously added with `add_to_global_allowlist`,
|
| 93 |
+
then it will always be allowed. This is useful for internal tests.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
None
|
| 97 |
+
"""
|
| 98 |
+
current_symbols = set(dir(_sys.modules[module_name]))
|
| 99 |
+
should_have = make_all(module_name, doc_string_modules)
|
| 100 |
+
should_have += allowed_exception_list or []
|
| 101 |
+
extra_symbols = current_symbols - set(should_have)
|
| 102 |
+
target_module = _sys.modules[module_name]
|
| 103 |
+
for extra_symbol in extra_symbols:
|
| 104 |
+
# Skip over __file__, etc. Also preserves internal symbols.
|
| 105 |
+
if extra_symbol.startswith('_'): continue
|
| 106 |
+
fully_qualified_name = module_name + '.' + extra_symbol
|
| 107 |
+
_HIDDEN_ATTRIBUTES[fully_qualified_name] = (target_module,
|
| 108 |
+
getattr(target_module,
|
| 109 |
+
extra_symbol))
|
| 110 |
+
delattr(target_module, extra_symbol)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
__all__ = [
|
| 114 |
+
'make_all',
|
| 115 |
+
'remove_undocumented',
|
| 116 |
+
'reveal_undocumented',
|
| 117 |
+
]
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/compat.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Compatibility functions.
|
| 16 |
+
|
| 17 |
+
The `tf.compat` module contains two sets of compatibility functions.
|
| 18 |
+
|
| 19 |
+
## Tensorflow 1.x and 2.x APIs
|
| 20 |
+
|
| 21 |
+
The `compat.v1` and `compat.v2` submodules provide a complete copy of both the
|
| 22 |
+
`v1` and `v2` APIs for backwards and forwards compatibility across TensorFlow
|
| 23 |
+
versions 1.x and 2.x. See the
|
| 24 |
+
[migration guide](https://www.tensorflow.org/guide/migrate) for details.
|
| 25 |
+
|
| 26 |
+
## Utilities for writing compatible code
|
| 27 |
+
|
| 28 |
+
Aside from the `compat.v1` and `compat.v2` submodules, `tf.compat` also contains
|
| 29 |
+
a set of helper functions for writing code that works in both:
|
| 30 |
+
|
| 31 |
+
* TensorFlow 1.x and 2.x
|
| 32 |
+
* Python 2 and 3
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## Type collections
|
| 36 |
+
|
| 37 |
+
The compatibility module also provides the following aliases for common
|
| 38 |
+
sets of python types:
|
| 39 |
+
|
| 40 |
+
* `bytes_or_text_types`
|
| 41 |
+
* `complex_types`
|
| 42 |
+
* `integral_types`
|
| 43 |
+
* `real_types`
|
| 44 |
+
|
| 45 |
+
API docstring: tensorflow.compat
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
import codecs
|
| 49 |
+
import collections.abc as collections_abc # pylint: disable=unused-import
|
| 50 |
+
import numbers as _numbers
|
| 51 |
+
|
| 52 |
+
import numpy as _np
|
| 53 |
+
|
| 54 |
+
from tensorflow.python.util.tf_export import tf_export
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def as_bytes(bytes_or_text, encoding='utf-8'):
|
| 58 |
+
"""Converts `bytearray`, `bytes`, or unicode python input types to `bytes`.
|
| 59 |
+
|
| 60 |
+
Uses utf-8 encoding for text by default.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
bytes_or_text: A `bytearray`, `bytes`, `str`, or `unicode` object.
|
| 64 |
+
encoding: A string indicating the charset for encoding unicode.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
A `bytes` object.
|
| 68 |
+
|
| 69 |
+
Raises:
|
| 70 |
+
TypeError: If `bytes_or_text` is not a binary or unicode string.
|
| 71 |
+
"""
|
| 72 |
+
# Validate encoding, a LookupError will be raised if invalid.
|
| 73 |
+
encoding = codecs.lookup(encoding).name
|
| 74 |
+
if isinstance(bytes_or_text, bytearray):
|
| 75 |
+
return bytes(bytes_or_text)
|
| 76 |
+
elif isinstance(bytes_or_text, str):
|
| 77 |
+
return bytes_or_text.encode(encoding)
|
| 78 |
+
elif isinstance(bytes_or_text, bytes):
|
| 79 |
+
return bytes_or_text
|
| 80 |
+
else:
|
| 81 |
+
raise TypeError('Expected binary or unicode string, got %r' %
|
| 82 |
+
(bytes_or_text,))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def as_text(bytes_or_text, encoding='utf-8'):
|
| 86 |
+
"""Converts any string-like python input types to unicode.
|
| 87 |
+
|
| 88 |
+
Returns the input as a unicode string. Uses utf-8 encoding for text
|
| 89 |
+
by default.
|
| 90 |
+
|
| 91 |
+
Args:
|
| 92 |
+
bytes_or_text: A `bytes`, `str`, or `unicode` object.
|
| 93 |
+
encoding: A string indicating the charset for decoding unicode.
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
A `unicode` (Python 2) or `str` (Python 3) object.
|
| 97 |
+
|
| 98 |
+
Raises:
|
| 99 |
+
TypeError: If `bytes_or_text` is not a binary or unicode string.
|
| 100 |
+
"""
|
| 101 |
+
# Validate encoding, a LookupError will be raised if invalid.
|
| 102 |
+
encoding = codecs.lookup(encoding).name
|
| 103 |
+
if isinstance(bytes_or_text, str):
|
| 104 |
+
return bytes_or_text
|
| 105 |
+
elif isinstance(bytes_or_text, bytes):
|
| 106 |
+
return bytes_or_text.decode(encoding)
|
| 107 |
+
else:
|
| 108 |
+
raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def as_str(bytes_or_text, encoding='utf-8'):
|
| 112 |
+
"""Acts as an alias for the `as_text` function..
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
bytes_or_text: The input value to be converted. A bytes or unicode object.
|
| 116 |
+
encoding: Optional string. The encoding to use if bytes_or_text is a bytes
|
| 117 |
+
object. Defaults to 'utf-8'.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
A unicode string.
|
| 121 |
+
|
| 122 |
+
Raises:
|
| 123 |
+
TypeError: If bytes_or_text is not a bytes or unicode object.
|
| 124 |
+
UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be
|
| 125 |
+
decoded using the specified encoding.
|
| 126 |
+
"""
|
| 127 |
+
return as_text(bytes_or_text, encoding)
|
| 128 |
+
|
| 129 |
+
tf_export('compat.as_text')(as_text)
|
| 130 |
+
tf_export('compat.as_bytes')(as_bytes)
|
| 131 |
+
tf_export('compat.as_str')(as_str)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
@tf_export('compat.as_str_any')
|
| 135 |
+
def as_str_any(value, encoding='utf-8'):
|
| 136 |
+
"""Converts input to `str` type.
|
| 137 |
+
|
| 138 |
+
Uses `str(value)`, except for `bytes` typed inputs, which are converted
|
| 139 |
+
using `as_str`.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
value: A object that can be converted to `str`.
|
| 143 |
+
encoding: Encoding for `bytes` typed inputs.
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
A `str` object.
|
| 147 |
+
"""
|
| 148 |
+
if isinstance(value, bytes):
|
| 149 |
+
return as_str(value, encoding=encoding)
|
| 150 |
+
else:
|
| 151 |
+
return str(value)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@tf_export('compat.path_to_str')
|
| 155 |
+
def path_to_str(path):
|
| 156 |
+
r"""Converts input which is a `PathLike` object to `str` type.
|
| 157 |
+
|
| 158 |
+
Converts from any python constant representation of a `PathLike` object to
|
| 159 |
+
a string. If the input is not a `PathLike` object, simply returns the input.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
path: An object that can be converted to path representation.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
A `str` object.
|
| 166 |
+
|
| 167 |
+
Usage:
|
| 168 |
+
In case a simplified `str` version of the path is needed from an
|
| 169 |
+
`os.PathLike` object.
|
| 170 |
+
|
| 171 |
+
Examples:
|
| 172 |
+
```python
|
| 173 |
+
$ tf.compat.path_to_str('C:\XYZ\tensorflow\./.././tensorflow')
|
| 174 |
+
'C:\XYZ\tensorflow\./.././tensorflow' # Windows OS
|
| 175 |
+
$ tf.compat.path_to_str(Path('C:\XYZ\tensorflow\./.././tensorflow'))
|
| 176 |
+
'C:\XYZ\tensorflow\..\tensorflow' # Windows OS
|
| 177 |
+
$ tf.compat.path_to_str(Path('./corpus'))
|
| 178 |
+
'corpus' # Linux OS
|
| 179 |
+
$ tf.compat.path_to_str('./.././Corpus')
|
| 180 |
+
'./.././Corpus' # Linux OS
|
| 181 |
+
$ tf.compat.path_to_str(Path('./.././Corpus'))
|
| 182 |
+
'../Corpus' # Linux OS
|
| 183 |
+
$ tf.compat.path_to_str(Path('./..////../'))
|
| 184 |
+
'../..' # Linux OS
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
"""
|
| 188 |
+
if hasattr(path, '__fspath__'):
|
| 189 |
+
path = as_str_any(path.__fspath__())
|
| 190 |
+
return path
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def path_to_bytes(path):
|
| 194 |
+
r"""Converts input which is a `PathLike` object to `bytes`.
|
| 195 |
+
|
| 196 |
+
Converts from any python constant representation of a `PathLike` object
|
| 197 |
+
or `str` to bytes.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
path: An object that can be converted to path representation.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
A `bytes` object.
|
| 204 |
+
|
| 205 |
+
Usage:
|
| 206 |
+
In case a simplified `bytes` version of the path is needed from an
|
| 207 |
+
`os.PathLike` object.
|
| 208 |
+
"""
|
| 209 |
+
if hasattr(path, '__fspath__'):
|
| 210 |
+
path = path.__fspath__()
|
| 211 |
+
return as_bytes(path)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we
|
| 215 |
+
# need to check them specifically. The same goes from Real and Complex.
|
| 216 |
+
integral_types = (_numbers.Integral, _np.integer)
|
| 217 |
+
tf_export('compat.integral_types').export_constant(__name__, 'integral_types')
|
| 218 |
+
real_types = (_numbers.Real, _np.integer, _np.floating)
|
| 219 |
+
tf_export('compat.real_types').export_constant(__name__, 'real_types')
|
| 220 |
+
complex_types = (_numbers.Complex, _np.number)
|
| 221 |
+
tf_export('compat.complex_types').export_constant(__name__, 'complex_types')
|
| 222 |
+
|
| 223 |
+
# Either bytes or text.
|
| 224 |
+
bytes_or_text_types = (bytes, str)
|
| 225 |
+
tf_export('compat.bytes_or_text_types').export_constant(__name__,
|
| 226 |
+
'bytes_or_text_types')
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/custom_nest_protocol.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Protocol class for custom tf.nest support."""
|
| 17 |
+
|
| 18 |
+
import typing
|
| 19 |
+
from typing import Protocol
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@typing.runtime_checkable
|
| 23 |
+
class CustomNestProtocol(Protocol):
|
| 24 |
+
"""Protocol for adding custom tf.nest support in user-defined classes.
|
| 25 |
+
|
| 26 |
+
User classes should implement the two methods defined in this protocol in
|
| 27 |
+
order to be supported by nest functions.
|
| 28 |
+
- `__tf_flatten__` for generating the flattened components and the metadata
|
| 29 |
+
of the current object.
|
| 30 |
+
- `__tf_unflatten__` for creating a new object based on the input metadata
|
| 31 |
+
and the components.
|
| 32 |
+
See the method doc for details.
|
| 33 |
+
|
| 34 |
+
In terms of support level, classes implementing this protocol
|
| 35 |
+
- are supported by tf.nest and tf.data functions.
|
| 36 |
+
- have limited support from tf.function, which requires writing a custom
|
| 37 |
+
TraceType subclass to be used as the input or output of a tf.function.
|
| 38 |
+
- are NOT supported by SavedModel.
|
| 39 |
+
|
| 40 |
+
Code Examples:
|
| 41 |
+
|
| 42 |
+
>>> import dataclasses
|
| 43 |
+
>>> @dataclasses.dataclass
|
| 44 |
+
... class MaskedTensor:
|
| 45 |
+
... mask: bool
|
| 46 |
+
... value: tf.Tensor
|
| 47 |
+
...
|
| 48 |
+
... def __tf_flatten__(self):
|
| 49 |
+
... metadata = (self.mask,) # static config.
|
| 50 |
+
... components = (self.value,) # dynamic values.
|
| 51 |
+
... return metadata, components
|
| 52 |
+
...
|
| 53 |
+
... @classmethod
|
| 54 |
+
... def __tf_unflatten__(cls, metadata, components):
|
| 55 |
+
... mask = metadata[0]
|
| 56 |
+
... value = components[0]
|
| 57 |
+
... return MaskedTensor(mask=mask, value=value)
|
| 58 |
+
...
|
| 59 |
+
>>> mt = MaskedTensor(mask=True, value=tf.constant([1]))
|
| 60 |
+
>>> mt
|
| 61 |
+
MaskedTensor(mask=True, value=<tf.Tensor: ... numpy=array([1], dtype=int32)>)
|
| 62 |
+
>>> tf.nest.is_nested(mt)
|
| 63 |
+
True
|
| 64 |
+
>>> mt2 = MaskedTensor(mask=False, value=tf.constant([2]))
|
| 65 |
+
>>> tf.nest.assert_same_structure(mt, mt2)
|
| 66 |
+
|
| 67 |
+
>>> leaves = tf.nest.flatten(mt)
|
| 68 |
+
>>> leaves
|
| 69 |
+
[<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>]
|
| 70 |
+
|
| 71 |
+
>>> mt3 = tf.nest.pack_sequence_as(mt, leaves)
|
| 72 |
+
>>> mt3
|
| 73 |
+
MaskedTensor(mask=True, value=<tf.Tensor: ... numpy=array([1], dtype=int32)>)
|
| 74 |
+
>>> bool(mt == mt3)
|
| 75 |
+
True
|
| 76 |
+
|
| 77 |
+
>>> tf.nest.map_structure(lambda x: x * 2, mt)
|
| 78 |
+
MaskedTensor(mask=True, value=<tf.Tensor: ... numpy=array([2], dtype=int32)>)
|
| 79 |
+
|
| 80 |
+
More examples are available in the unit tests (nest_test.py).
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __tf_flatten__(self):
|
| 84 |
+
"""Flatten current object into (metadata, components).
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
A `tuple` of (metadata, components), where
|
| 88 |
+
- metadata is a custom Python object that stands for the static config
|
| 89 |
+
of the current object, which is supposed to be fixed and not affected
|
| 90 |
+
by data transformation.
|
| 91 |
+
- components is a `tuple` that contains the modifiable fields of the
|
| 92 |
+
current object.
|
| 93 |
+
|
| 94 |
+
Implementation Note:
|
| 95 |
+
- This method should not invoke any TensorFlow ops.
|
| 96 |
+
- This method only needs to flatten the current level. If current object has
|
| 97 |
+
an attribute that also need custom flattening, nest functions (such as
|
| 98 |
+
`nest.flatten`) will utilize this method to do recursive flattening.
|
| 99 |
+
- Components must ba a `tuple`, not a `list`
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
@classmethod
|
| 103 |
+
def __tf_unflatten__(cls, metadata, components):
|
| 104 |
+
"""Create a user-defined object from (metadata, components).
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
metadata: a custom Python objet that stands for the static config for
|
| 108 |
+
reconstructing a new object of the current class.
|
| 109 |
+
components: a `tuple` that contains the dynamic data fields of the current
|
| 110 |
+
class, for object reconstruction.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
The user-defined object, with the same class of the current object.
|
| 114 |
+
|
| 115 |
+
Implementation Note:
|
| 116 |
+
- This method should not invoke any TensorFlow ops.
|
| 117 |
+
- This method only needs to unflatten the current level. If the object has
|
| 118 |
+
an attribute that also need custom unflattening, nest functions will
|
| 119 |
+
utilize this method to do recursive unflattening.
|
| 120 |
+
"""
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/decorator_utils.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Utility functions for writing decorators (which modify docstrings)."""
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_qualified_name(function):
|
| 21 |
+
# Python 3
|
| 22 |
+
if hasattr(function, '__qualname__'):
|
| 23 |
+
return function.__qualname__
|
| 24 |
+
|
| 25 |
+
# Python 2
|
| 26 |
+
if hasattr(function, 'im_class'):
|
| 27 |
+
return function.im_class.__name__ + '.' + function.__name__
|
| 28 |
+
return function.__name__
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _normalize_docstring(docstring):
|
| 32 |
+
"""Normalizes the docstring.
|
| 33 |
+
|
| 34 |
+
Replaces tabs with spaces, removes leading and trailing blanks lines, and
|
| 35 |
+
removes any indentation.
|
| 36 |
+
|
| 37 |
+
Copied from PEP-257:
|
| 38 |
+
https://www.python.org/dev/peps/pep-0257/#handling-docstring-indentation
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
docstring: the docstring to normalize
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
The normalized docstring
|
| 45 |
+
"""
|
| 46 |
+
if not docstring:
|
| 47 |
+
return ''
|
| 48 |
+
# Convert tabs to spaces (following the normal Python rules)
|
| 49 |
+
# and split into a list of lines:
|
| 50 |
+
lines = docstring.expandtabs().splitlines()
|
| 51 |
+
# Determine minimum indentation (first line doesn't count):
|
| 52 |
+
# (we use sys.maxsize because sys.maxint doesn't exist in Python 3)
|
| 53 |
+
indent = sys.maxsize
|
| 54 |
+
for line in lines[1:]:
|
| 55 |
+
stripped = line.lstrip()
|
| 56 |
+
if stripped:
|
| 57 |
+
indent = min(indent, len(line) - len(stripped))
|
| 58 |
+
# Remove indentation (first line is special):
|
| 59 |
+
trimmed = [lines[0].strip()]
|
| 60 |
+
if indent < sys.maxsize:
|
| 61 |
+
for line in lines[1:]:
|
| 62 |
+
trimmed.append(line[indent:].rstrip())
|
| 63 |
+
# Strip off trailing and leading blank lines:
|
| 64 |
+
while trimmed and not trimmed[-1]:
|
| 65 |
+
trimmed.pop()
|
| 66 |
+
while trimmed and not trimmed[0]:
|
| 67 |
+
trimmed.pop(0)
|
| 68 |
+
# Return a single string:
|
| 69 |
+
return '\n'.join(trimmed)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def add_notice_to_docstring(doc,
|
| 73 |
+
instructions,
|
| 74 |
+
no_doc_str,
|
| 75 |
+
suffix_str,
|
| 76 |
+
notice,
|
| 77 |
+
notice_type='Warning'):
|
| 78 |
+
"""Adds a deprecation notice to a docstring.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
doc: The original docstring.
|
| 82 |
+
instructions: A string, describing how to fix the problem.
|
| 83 |
+
no_doc_str: The default value to use for `doc` if `doc` is empty.
|
| 84 |
+
suffix_str: Is added to the end of the first line.
|
| 85 |
+
notice: A list of strings. The main notice warning body.
|
| 86 |
+
notice_type: The type of notice to use. Should be one of `[Caution,
|
| 87 |
+
Deprecated, Important, Note, Warning]`
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
A new docstring, with the notice attached.
|
| 91 |
+
|
| 92 |
+
Raises:
|
| 93 |
+
ValueError: If `notice` is empty.
|
| 94 |
+
"""
|
| 95 |
+
allowed_notice_types = ['Deprecated', 'Warning', 'Caution', 'Important',
|
| 96 |
+
'Note']
|
| 97 |
+
if notice_type not in allowed_notice_types:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
f'Unrecognized notice type. Should be one of: {allowed_notice_types}')
|
| 100 |
+
|
| 101 |
+
if not doc:
|
| 102 |
+
lines = [no_doc_str]
|
| 103 |
+
else:
|
| 104 |
+
lines = _normalize_docstring(doc).splitlines()
|
| 105 |
+
lines[0] += ' ' + suffix_str
|
| 106 |
+
|
| 107 |
+
if not notice:
|
| 108 |
+
raise ValueError('The `notice` arg must not be empty.')
|
| 109 |
+
|
| 110 |
+
notice[0] = f'{notice_type}: {notice[0]}'
|
| 111 |
+
notice = [''] + notice + ([instructions] if instructions else [])
|
| 112 |
+
|
| 113 |
+
if len(lines) > 1:
|
| 114 |
+
# Make sure that we keep our distance from the main body
|
| 115 |
+
if lines[1].strip():
|
| 116 |
+
notice.append('')
|
| 117 |
+
|
| 118 |
+
lines[1:1] = notice
|
| 119 |
+
else:
|
| 120 |
+
lines += notice
|
| 121 |
+
|
| 122 |
+
return '\n'.join(lines)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def validate_callable(func, decorator_name):
|
| 126 |
+
if not hasattr(func, '__call__'):
|
| 127 |
+
raise ValueError(
|
| 128 |
+
'%s is not a function. If this is a property, make sure'
|
| 129 |
+
' @property appears before @%s in your source code:'
|
| 130 |
+
'\n\n@property\n@%s\ndef method(...)' % (
|
| 131 |
+
func, decorator_name, decorator_name))
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class classproperty(object): # pylint: disable=invalid-name
|
| 135 |
+
"""Class property decorator.
|
| 136 |
+
|
| 137 |
+
Example usage:
|
| 138 |
+
|
| 139 |
+
class MyClass(object):
|
| 140 |
+
|
| 141 |
+
@classproperty
|
| 142 |
+
def value(cls):
|
| 143 |
+
return '123'
|
| 144 |
+
|
| 145 |
+
> print MyClass.value
|
| 146 |
+
123
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(self, func):
|
| 150 |
+
self._func = func
|
| 151 |
+
|
| 152 |
+
def __get__(self, owner_self, owner_cls):
|
| 153 |
+
return self._func(owner_cls)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class _CachedClassProperty(object):
|
| 157 |
+
"""Cached class property decorator.
|
| 158 |
+
|
| 159 |
+
Transforms a class method into a property whose value is computed once
|
| 160 |
+
and then cached as a normal attribute for the life of the class. Example
|
| 161 |
+
usage:
|
| 162 |
+
|
| 163 |
+
>>> class MyClass(object):
|
| 164 |
+
... @cached_classproperty
|
| 165 |
+
... def value(cls):
|
| 166 |
+
... print("Computing value")
|
| 167 |
+
... return '<property of %s>' % cls.__name__
|
| 168 |
+
>>> class MySubclass(MyClass):
|
| 169 |
+
... pass
|
| 170 |
+
>>> MyClass.value
|
| 171 |
+
Computing value
|
| 172 |
+
'<property of MyClass>'
|
| 173 |
+
>>> MyClass.value # uses cached value
|
| 174 |
+
'<property of MyClass>'
|
| 175 |
+
>>> MySubclass.value
|
| 176 |
+
Computing value
|
| 177 |
+
'<property of MySubclass>'
|
| 178 |
+
|
| 179 |
+
This decorator is similar to `functools.cached_property`, but it adds a
|
| 180 |
+
property to the class, not to individual instances.
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
def __init__(self, func):
|
| 184 |
+
self._func = func
|
| 185 |
+
self._cache = {}
|
| 186 |
+
|
| 187 |
+
def __get__(self, obj, objtype):
|
| 188 |
+
if objtype not in self._cache:
|
| 189 |
+
self._cache[objtype] = self._func(objtype)
|
| 190 |
+
return self._cache[objtype]
|
| 191 |
+
|
| 192 |
+
def __set__(self, obj, value):
|
| 193 |
+
raise AttributeError('property %s is read-only' % self._func.__name__)
|
| 194 |
+
|
| 195 |
+
def __delete__(self, obj):
|
| 196 |
+
raise AttributeError('property %s is read-only' % self._func.__name__)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def cached_classproperty(func):
|
| 200 |
+
return _CachedClassProperty(func)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
cached_classproperty.__doc__ = _CachedClassProperty.__doc__
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecated_module.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""A deprecated module.
|
| 16 |
+
|
| 17 |
+
For testing `deprecation.deprecate_moved_module`.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from tensorflow.python.util import deprecated_module_new
|
| 21 |
+
from tensorflow.python.util import deprecation
|
| 22 |
+
|
| 23 |
+
__getattr__ = deprecation.deprecate_moved_module(
|
| 24 |
+
__name__, deprecated_module_new, "2.9")
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecated_module_new.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""A module to replace deprecated_module.
|
| 16 |
+
|
| 17 |
+
For testing `deprecation.deprecate_moved_module`.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def a():
|
| 22 |
+
return 1
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/deprecation.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Tensor utility functions."""
|
| 16 |
+
import collections
|
| 17 |
+
import functools
|
| 18 |
+
import inspect
|
| 19 |
+
import re
|
| 20 |
+
|
| 21 |
+
from tensorflow.python.framework import strict_mode
|
| 22 |
+
from tensorflow.python.platform import tf_logging as logging
|
| 23 |
+
from tensorflow.python.util import decorator_utils
|
| 24 |
+
from tensorflow.python.util import is_in_graph_mode
|
| 25 |
+
from tensorflow.python.util import tf_contextlib
|
| 26 |
+
from tensorflow.python.util import tf_decorator
|
| 27 |
+
from tensorflow.python.util import tf_inspect
|
| 28 |
+
from tensorflow.tools.docs import doc_controls
|
| 29 |
+
|
| 30 |
+
# Allow deprecation warnings to be silenced temporarily with a context manager.
|
| 31 |
+
_PRINT_DEPRECATION_WARNINGS = True
|
| 32 |
+
|
| 33 |
+
# Remember which deprecation warnings have been printed already.
|
| 34 |
+
_PRINTED_WARNING = {}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DeprecatedNamesAlreadySetError(Exception):
|
| 38 |
+
"""Raised when setting deprecated names multiple times for the same symbol."""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _log_deprecation(msg, *args, **kwargs):
|
| 42 |
+
"""Raises errors for deprecated methods if in strict mode, warns otherwise."""
|
| 43 |
+
if strict_mode.STRICT_MODE:
|
| 44 |
+
logging.error(msg, *args, **kwargs)
|
| 45 |
+
raise RuntimeError(
|
| 46 |
+
'This behavior has been deprecated, which raises an error in strict'
|
| 47 |
+
' mode.'
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
logging.warning(msg, *args, **kwargs)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
|
| 54 |
+
"""Adds a deprecation notice to a docstring for deprecated functions."""
|
| 55 |
+
main_text = [
|
| 56 |
+
'THIS FUNCTION IS DEPRECATED. It will be removed %s.'
|
| 57 |
+
% ('in a future version' if date is None else ('after %s' % date))
|
| 58 |
+
]
|
| 59 |
+
if instructions:
|
| 60 |
+
main_text.append('Instructions for updating:')
|
| 61 |
+
return decorator_utils.add_notice_to_docstring(
|
| 62 |
+
doc,
|
| 63 |
+
instructions,
|
| 64 |
+
'DEPRECATED FUNCTION',
|
| 65 |
+
'(deprecated)',
|
| 66 |
+
main_text,
|
| 67 |
+
notice_type='Deprecated')
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _add_deprecated_arg_notice_to_docstring(doc, date, instructions,
|
| 71 |
+
deprecated_names):
|
| 72 |
+
"""Adds a deprecation notice to a docstring for deprecated arguments."""
|
| 73 |
+
|
| 74 |
+
deprecation_string = ', '.join(sorted(deprecated_names))
|
| 75 |
+
|
| 76 |
+
return decorator_utils.add_notice_to_docstring(
|
| 77 |
+
doc,
|
| 78 |
+
instructions,
|
| 79 |
+
'DEPRECATED FUNCTION ARGUMENTS',
|
| 80 |
+
'(deprecated arguments)', [
|
| 81 |
+
'SOME ARGUMENTS ARE DEPRECATED: `(%s)`. '
|
| 82 |
+
'They will be removed %s.' %
|
| 83 |
+
(deprecation_string, 'in a future version' if date is None else
|
| 84 |
+
('after %s' % date)), 'Instructions for updating:'
|
| 85 |
+
],
|
| 86 |
+
notice_type='Deprecated')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _add_deprecated_arg_value_notice_to_docstring(doc, date, instructions,
|
| 90 |
+
deprecated_name_value_dict):
|
| 91 |
+
"""Adds a deprecation notice to a docstring for deprecated arguments."""
|
| 92 |
+
|
| 93 |
+
deprecation_string = ', '.join(
|
| 94 |
+
'%s=%r' % (key, value)
|
| 95 |
+
for key, value in sorted(deprecated_name_value_dict.items()))
|
| 96 |
+
|
| 97 |
+
when = 'in a future version' if date is None else ('after %s' % date)
|
| 98 |
+
|
| 99 |
+
return decorator_utils.add_notice_to_docstring(
|
| 100 |
+
doc,
|
| 101 |
+
instructions,
|
| 102 |
+
'DEPRECATED FUNCTION ARGUMENT VALUES',
|
| 103 |
+
'(deprecated argument values)', [
|
| 104 |
+
'SOME ARGUMENT VALUES ARE DEPRECATED: `(%s)`. '
|
| 105 |
+
'They will be removed %s.' %
|
| 106 |
+
(deprecation_string, when), 'Instructions for updating:'
|
| 107 |
+
],
|
| 108 |
+
notice_type='Deprecated')
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _validate_deprecation_args(date, instructions):
|
| 112 |
+
if date is not None and not re.match(r'20\d\d-[01]\d-[0123]\d', date):
|
| 113 |
+
raise ValueError(f'Date must be in format YYYY-MM-DD. Received: {date}')
|
| 114 |
+
if not instructions:
|
| 115 |
+
raise ValueError(
|
| 116 |
+
'Don\'t deprecate things without conversion instructions! Specify '
|
| 117 |
+
'the `instructions` argument.')
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _call_location(outer=False):
|
| 121 |
+
"""Returns call location given level up from current call."""
|
| 122 |
+
# Two up: <_call_location>, <_call_location's caller>
|
| 123 |
+
# tf_inspect is not required here. Please ignore the lint warning by adding
|
| 124 |
+
# DISABLE_IMPORT_INSPECT_CHECK=TRUE to your cl description. Using it caused
|
| 125 |
+
# test timeouts (b/189384061).
|
| 126 |
+
f = inspect.currentframe().f_back.f_back
|
| 127 |
+
parent = f and f.f_back
|
| 128 |
+
if outer and parent is not None:
|
| 129 |
+
f = parent
|
| 130 |
+
return '{}:{}'.format(f.f_code.co_filename, f.f_lineno)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _safe_eq(a, b):
|
| 134 |
+
if a is None or b is None:
|
| 135 |
+
return a is None and b is None
|
| 136 |
+
return a == b
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _wrap_decorator(wrapped_function, decorator_name):
|
| 140 |
+
"""Indicate that one function wraps another.
|
| 141 |
+
|
| 142 |
+
This decorator wraps a function using `tf_decorator.make_decorator`
|
| 143 |
+
so that doc generation scripts can pick up original function
|
| 144 |
+
signature.
|
| 145 |
+
It would be better to use @functools.wrap decorator, but it would
|
| 146 |
+
not update function signature to match wrapped function in Python 2.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
wrapped_function: The function that decorated function wraps.
|
| 150 |
+
decorator_name: The name of the decorator.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
Function that accepts wrapper function as an argument and returns
|
| 154 |
+
`TFDecorator` instance.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def wrapper(wrapper_func):
|
| 158 |
+
return tf_decorator.make_decorator(wrapped_function, wrapper_func,
|
| 159 |
+
decorator_name)
|
| 160 |
+
|
| 161 |
+
return wrapper
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def deprecated_alias(deprecated_name, name, func_or_class, warn_once=True):
|
| 165 |
+
"""Deprecate a symbol in favor of a new name with identical semantics.
|
| 166 |
+
|
| 167 |
+
This function is meant to be used when defining a backwards-compatibility
|
| 168 |
+
alias for a symbol which has been moved. For example:
|
| 169 |
+
|
| 170 |
+
module1.py:
|
| 171 |
+
```python
|
| 172 |
+
class NewNameForClass: pass
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
module2.py:
|
| 176 |
+
```python
|
| 177 |
+
import module1
|
| 178 |
+
|
| 179 |
+
DeprecatedNameForClass = deprecated_alias(
|
| 180 |
+
deprecated_name='module2.DeprecatedNameForClass',
|
| 181 |
+
name='module1.NewNameForClass',
|
| 182 |
+
func_or_class=module1.NewNameForClass)
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
This function works for classes and functions.
|
| 186 |
+
|
| 187 |
+
For classes, it creates a new class which is functionally identical (it
|
| 188 |
+
inherits from the original, and overrides its constructor), but which prints
|
| 189 |
+
a deprecation warning when an instance is created. It also adds a deprecation
|
| 190 |
+
notice to the class' docstring.
|
| 191 |
+
|
| 192 |
+
For functions, it returns a function wrapped by `tf_decorator.make_decorator`.
|
| 193 |
+
That function prints a warning when used, and has a deprecation notice in its
|
| 194 |
+
docstring. This is more or less equivalent (the deprecation warning has
|
| 195 |
+
slightly different text) to writing:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
@deprecated
|
| 199 |
+
def deprecated_alias(original_args):
|
| 200 |
+
real_function(original_args)
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
deprecated_name: The name of the symbol that is being deprecated, to be used
|
| 205 |
+
in the warning message. This should be its fully qualified name to avoid
|
| 206 |
+
confusion.
|
| 207 |
+
name: The name of the symbol that is to be used instead of the deprecated
|
| 208 |
+
name. This should be a fully qualified name to avoid confusion.
|
| 209 |
+
func_or_class: The (non-deprecated) class or function for which a deprecated
|
| 210 |
+
alias should be created.
|
| 211 |
+
warn_once: If True (the default), only print a deprecation warning the first
|
| 212 |
+
time this function is used, or the class is instantiated.
|
| 213 |
+
|
| 214 |
+
Returns:
|
| 215 |
+
A wrapped version of `func_or_class` which prints a deprecation warning on
|
| 216 |
+
use and has a modified docstring.
|
| 217 |
+
"""
|
| 218 |
+
if tf_inspect.isclass(func_or_class):
|
| 219 |
+
|
| 220 |
+
# Make a new class with __init__ wrapped in a warning.
|
| 221 |
+
class _NewClass(func_or_class): # pylint: disable=missing-docstring
|
| 222 |
+
__doc__ = decorator_utils.add_notice_to_docstring(
|
| 223 |
+
func_or_class.__doc__,
|
| 224 |
+
'Please use %s instead.' % name,
|
| 225 |
+
'DEPRECATED CLASS',
|
| 226 |
+
'(deprecated)', [('THIS CLASS IS DEPRECATED. '
|
| 227 |
+
'It will be removed in a future version. ')],
|
| 228 |
+
notice_type='Deprecated')
|
| 229 |
+
__name__ = func_or_class.__name__
|
| 230 |
+
__module__ = _call_location(outer=True)
|
| 231 |
+
|
| 232 |
+
@_wrap_decorator(func_or_class.__init__, 'deprecated_alias')
|
| 233 |
+
def __init__(self, *args, **kwargs):
|
| 234 |
+
if hasattr(_NewClass.__init__, '__func__'):
|
| 235 |
+
# Python 2
|
| 236 |
+
_NewClass.__init__.__func__.__doc__ = func_or_class.__init__.__doc__
|
| 237 |
+
else:
|
| 238 |
+
# Python 3
|
| 239 |
+
_NewClass.__init__.__doc__ = func_or_class.__init__.__doc__
|
| 240 |
+
|
| 241 |
+
if _PRINT_DEPRECATION_WARNINGS:
|
| 242 |
+
# We're making the alias as we speak. The original may have other
|
| 243 |
+
# aliases, so we cannot use it to check for whether it's already been
|
| 244 |
+
# warned about.
|
| 245 |
+
if _NewClass.__init__ not in _PRINTED_WARNING:
|
| 246 |
+
if warn_once:
|
| 247 |
+
_PRINTED_WARNING[_NewClass.__init__] = True
|
| 248 |
+
_log_deprecation(
|
| 249 |
+
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
| 250 |
+
_call_location(), deprecated_name, name)
|
| 251 |
+
super(_NewClass, self).__init__(*args, **kwargs)
|
| 252 |
+
|
| 253 |
+
return _NewClass
|
| 254 |
+
else:
|
| 255 |
+
decorator_utils.validate_callable(func_or_class, 'deprecated')
|
| 256 |
+
|
| 257 |
+
# Make a wrapper for the original
|
| 258 |
+
@functools.wraps(func_or_class)
|
| 259 |
+
def new_func(*args, **kwargs): # pylint: disable=missing-docstring
|
| 260 |
+
if _PRINT_DEPRECATION_WARNINGS:
|
| 261 |
+
# We're making the alias as we speak. The original may have other
|
| 262 |
+
# aliases, so we cannot use it to check for whether it's already been
|
| 263 |
+
# warned about.
|
| 264 |
+
if new_func not in _PRINTED_WARNING:
|
| 265 |
+
if warn_once:
|
| 266 |
+
_PRINTED_WARNING[new_func] = True
|
| 267 |
+
_log_deprecation(
|
| 268 |
+
'From %s: The name %s is deprecated. Please use %s instead.\n',
|
| 269 |
+
_call_location(), deprecated_name, name)
|
| 270 |
+
return func_or_class(*args, **kwargs)
|
| 271 |
+
|
| 272 |
+
return tf_decorator.make_decorator(
|
| 273 |
+
func_or_class, new_func, 'deprecated',
|
| 274 |
+
_add_deprecated_function_notice_to_docstring(
|
| 275 |
+
func_or_class.__doc__, None, 'Please use %s instead.' % name))
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def deprecated_endpoints(*args):
|
| 279 |
+
"""Decorator for marking endpoints deprecated.
|
| 280 |
+
|
| 281 |
+
This decorator does not print deprecation messages.
|
| 282 |
+
TODO(annarev): eventually start printing deprecation warnings when
|
| 283 |
+
@deprecation_endpoints decorator is added.
|
| 284 |
+
|
| 285 |
+
Args:
|
| 286 |
+
*args: Deprecated endpoint names.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
A function that takes symbol as an argument and adds
|
| 290 |
+
_tf_deprecated_api_names to that symbol.
|
| 291 |
+
_tf_deprecated_api_names would be set to a list of deprecated
|
| 292 |
+
endpoint names for the symbol.
|
| 293 |
+
"""
|
| 294 |
+
|
| 295 |
+
def deprecated_wrapper(func):
|
| 296 |
+
# pylint: disable=protected-access
|
| 297 |
+
if '_tf_deprecated_api_names' in func.__dict__:
|
| 298 |
+
raise DeprecatedNamesAlreadySetError(
|
| 299 |
+
f'Cannot set deprecated names for {func.__name__} to {args}. '
|
| 300 |
+
'Deprecated names are already set to '
|
| 301 |
+
f'{func._tf_deprecated_api_names}.')
|
| 302 |
+
func._tf_deprecated_api_names = args
|
| 303 |
+
# pylint: disable=protected-access
|
| 304 |
+
return func
|
| 305 |
+
|
| 306 |
+
return deprecated_wrapper
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def deprecated(date, instructions, warn_once=True):
|
| 310 |
+
"""Decorator for marking functions or methods deprecated.
|
| 311 |
+
|
| 312 |
+
This decorator logs a deprecation warning whenever the decorated function is
|
| 313 |
+
called. It has the following format:
|
| 314 |
+
|
| 315 |
+
<function> (from <module>) is deprecated and will be removed after <date>.
|
| 316 |
+
Instructions for updating:
|
| 317 |
+
<instructions>
|
| 318 |
+
|
| 319 |
+
If `date` is None, 'after <date>' is replaced with 'in a future version'.
|
| 320 |
+
<function> will include the class name if it is a method.
|
| 321 |
+
|
| 322 |
+
It also edits the docstring of the function: ' (deprecated)' is appended
|
| 323 |
+
to the first line of the docstring and a deprecation notice is prepended
|
| 324 |
+
to the rest of the docstring.
|
| 325 |
+
|
| 326 |
+
Args:
|
| 327 |
+
date: String or None. The date the function is scheduled to be removed. Must
|
| 328 |
+
be ISO 8601 (YYYY-MM-DD), or None.
|
| 329 |
+
instructions: String. Instructions on how to update code using the
|
| 330 |
+
deprecated function.
|
| 331 |
+
warn_once: Boolean. Set to `True` to warn only the first time the decorated
|
| 332 |
+
function is called. Otherwise, every call will log a warning.
|
| 333 |
+
|
| 334 |
+
Returns:
|
| 335 |
+
Decorated function or method.
|
| 336 |
+
|
| 337 |
+
Raises:
|
| 338 |
+
ValueError: If date is not None or in ISO 8601 format, or instructions are
|
| 339 |
+
empty.
|
| 340 |
+
"""
|
| 341 |
+
_validate_deprecation_args(date, instructions)
|
| 342 |
+
|
| 343 |
+
def deprecated_wrapper(func_or_class):
|
| 344 |
+
"""Deprecation wrapper."""
|
| 345 |
+
if isinstance(func_or_class, type):
|
| 346 |
+
# If a class is deprecated, you actually want to wrap the constructor.
|
| 347 |
+
cls = func_or_class
|
| 348 |
+
if cls.__new__ is object.__new__:
|
| 349 |
+
# If a class defaults to its parent's constructor, wrap that instead.
|
| 350 |
+
func = cls.__init__
|
| 351 |
+
constructor_name = '__init__'
|
| 352 |
+
decorators, _ = tf_decorator.unwrap(func)
|
| 353 |
+
for decorator in decorators:
|
| 354 |
+
if decorator.decorator_name == 'deprecated':
|
| 355 |
+
# If the parent is already deprecated, there's nothing to do.
|
| 356 |
+
return cls
|
| 357 |
+
else:
|
| 358 |
+
func = cls.__new__
|
| 359 |
+
constructor_name = '__new__'
|
| 360 |
+
|
| 361 |
+
else:
|
| 362 |
+
cls = None
|
| 363 |
+
constructor_name = None
|
| 364 |
+
func = func_or_class
|
| 365 |
+
|
| 366 |
+
decorator_utils.validate_callable(func, 'deprecated')
|
| 367 |
+
|
| 368 |
+
@_wrap_decorator(func, 'deprecated')
|
| 369 |
+
def new_func(*args, **kwargs): # pylint: disable=missing-docstring
|
| 370 |
+
if _PRINT_DEPRECATION_WARNINGS:
|
| 371 |
+
if func not in _PRINTED_WARNING and cls not in _PRINTED_WARNING:
|
| 372 |
+
if warn_once:
|
| 373 |
+
_PRINTED_WARNING[func] = True
|
| 374 |
+
if cls:
|
| 375 |
+
_PRINTED_WARNING[cls] = True
|
| 376 |
+
_log_deprecation(
|
| 377 |
+
'From %s: %s (from %s) is deprecated and will be removed %s.\n'
|
| 378 |
+
'Instructions for updating:\n%s', _call_location(),
|
| 379 |
+
decorator_utils.get_qualified_name(func),
|
| 380 |
+
func_or_class.__module__,
|
| 381 |
+
'in a future version' if date is None else ('after %s' % date),
|
| 382 |
+
instructions)
|
| 383 |
+
return func(*args, **kwargs)
|
| 384 |
+
|
| 385 |
+
doc_controls.set_deprecated(new_func)
|
| 386 |
+
new_func = tf_decorator.make_decorator(
|
| 387 |
+
func, new_func, 'deprecated',
|
| 388 |
+
_add_deprecated_function_notice_to_docstring(func.__doc__, date,
|
| 389 |
+
instructions))
|
| 390 |
+
new_func.__signature__ = inspect.signature(func)
|
| 391 |
+
|
| 392 |
+
if cls is None:
|
| 393 |
+
return new_func
|
| 394 |
+
else:
|
| 395 |
+
# Insert the wrapped function as the constructor
|
| 396 |
+
setattr(cls, constructor_name, new_func)
|
| 397 |
+
|
| 398 |
+
# And update the docstring of the class.
|
| 399 |
+
cls.__doc__ = _add_deprecated_function_notice_to_docstring(
|
| 400 |
+
cls.__doc__, date, instructions)
|
| 401 |
+
|
| 402 |
+
return cls
|
| 403 |
+
|
| 404 |
+
return deprecated_wrapper
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
DeprecatedArgSpec = collections.namedtuple(
|
| 408 |
+
'DeprecatedArgSpec', ['position', 'has_ok_value', 'ok_value'])
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples,
|
| 412 |
+
**kwargs):
|
| 413 |
+
"""Decorator for marking specific function arguments as deprecated.
|
| 414 |
+
|
| 415 |
+
This decorator logs a deprecation warning whenever the decorated function is
|
| 416 |
+
called with the deprecated argument. It has the following format:
|
| 417 |
+
|
| 418 |
+
Calling <function> (from <module>) with <arg> is deprecated and will be
|
| 419 |
+
removed after <date>. Instructions for updating:
|
| 420 |
+
<instructions>
|
| 421 |
+
|
| 422 |
+
If `date` is None, 'after <date>' is replaced with 'in a future version'.
|
| 423 |
+
<function> includes the class name if it is a method.
|
| 424 |
+
|
| 425 |
+
It also edits the docstring of the function: ' (deprecated arguments)' is
|
| 426 |
+
appended to the first line of the docstring and a deprecation notice is
|
| 427 |
+
prepended to the rest of the docstring.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
date: String or None. The date the function is scheduled to be removed. Must
|
| 431 |
+
be ISO 8601 (YYYY-MM-DD), or None.
|
| 432 |
+
instructions: String. Instructions on how to update code using the
|
| 433 |
+
deprecated function.
|
| 434 |
+
*deprecated_arg_names_or_tuples: String or 2-Tuple (String, ok_val). The
|
| 435 |
+
string is the deprecated argument name. Optionally, an ok-value may be
|
| 436 |
+
provided. If the user provided argument equals this value, the warning is
|
| 437 |
+
suppressed.
|
| 438 |
+
**kwargs: If `warn_once=False` is passed, every call with a deprecated
|
| 439 |
+
argument will log a warning. The default behavior is to only warn the
|
| 440 |
+
first time the function is called with any given deprecated argument. All
|
| 441 |
+
other kwargs raise `ValueError`.
|
| 442 |
+
|
| 443 |
+
Returns:
|
| 444 |
+
Decorated function or method.
|
| 445 |
+
|
| 446 |
+
Raises:
|
| 447 |
+
ValueError: If date is not None or in ISO 8601 format, instructions are
|
| 448 |
+
empty, the deprecated arguments are not present in the function
|
| 449 |
+
signature, the second element of a deprecated_tuple is not a
|
| 450 |
+
list, or if a kwarg other than `warn_once` is passed.
|
| 451 |
+
"""
|
| 452 |
+
_validate_deprecation_args(date, instructions)
|
| 453 |
+
if not deprecated_arg_names_or_tuples:
|
| 454 |
+
raise ValueError('Specify which argument is deprecated.')
|
| 455 |
+
if kwargs and list(kwargs.keys()) != ['warn_once']:
|
| 456 |
+
kwargs.pop('warn_once', None)
|
| 457 |
+
raise ValueError(f'Illegal argument passed to deprecated_args: {kwargs}')
|
| 458 |
+
warn_once = kwargs.get('warn_once', True)
|
| 459 |
+
|
| 460 |
+
def _get_arg_names_to_ok_vals():
|
| 461 |
+
"""Returns a dict mapping arg_name to DeprecatedArgSpec w/o position."""
|
| 462 |
+
d = {}
|
| 463 |
+
for name_or_tuple in deprecated_arg_names_or_tuples:
|
| 464 |
+
if isinstance(name_or_tuple, tuple):
|
| 465 |
+
d[name_or_tuple[0]] = DeprecatedArgSpec(-1, True, name_or_tuple[1])
|
| 466 |
+
else:
|
| 467 |
+
d[name_or_tuple] = DeprecatedArgSpec(-1, False, None)
|
| 468 |
+
return d
|
| 469 |
+
|
| 470 |
+
def _get_deprecated_positional_arguments(names_to_ok_vals, arg_spec):
|
| 471 |
+
"""Builds a dictionary from deprecated arguments to their spec.
|
| 472 |
+
|
| 473 |
+
Returned dict is keyed by argument name.
|
| 474 |
+
Each value is a DeprecatedArgSpec with the following fields:
|
| 475 |
+
position: The zero-based argument position of the argument
|
| 476 |
+
within the signature. None if the argument isn't found in
|
| 477 |
+
the signature.
|
| 478 |
+
ok_values: Values of this argument for which warning will be
|
| 479 |
+
suppressed.
|
| 480 |
+
|
| 481 |
+
Args:
|
| 482 |
+
names_to_ok_vals: dict from string arg_name to a list of values, possibly
|
| 483 |
+
empty, which should not elicit a warning.
|
| 484 |
+
arg_spec: Output from tf_inspect.getfullargspec on the called function.
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
Dictionary from arg_name to DeprecatedArgSpec.
|
| 488 |
+
"""
|
| 489 |
+
# Extract argument list
|
| 490 |
+
arg_space = arg_spec.args + arg_spec.kwonlyargs
|
| 491 |
+
arg_name_to_pos = {name: pos for pos, name in enumerate(arg_space)}
|
| 492 |
+
deprecated_positional_args = {}
|
| 493 |
+
for arg_name, spec in iter(names_to_ok_vals.items()):
|
| 494 |
+
if arg_name in arg_name_to_pos:
|
| 495 |
+
pos = arg_name_to_pos[arg_name]
|
| 496 |
+
deprecated_positional_args[arg_name] = DeprecatedArgSpec(
|
| 497 |
+
pos, spec.has_ok_value, spec.ok_value)
|
| 498 |
+
return deprecated_positional_args
|
| 499 |
+
|
| 500 |
+
deprecated_arg_names = _get_arg_names_to_ok_vals()
|
| 501 |
+
|
| 502 |
+
def deprecated_wrapper(func):
|
| 503 |
+
"""Deprecation decorator."""
|
| 504 |
+
decorator_utils.validate_callable(func, 'deprecated_args')
|
| 505 |
+
|
| 506 |
+
arg_spec = tf_inspect.getfullargspec(func)
|
| 507 |
+
deprecated_positions = _get_deprecated_positional_arguments(
|
| 508 |
+
deprecated_arg_names, arg_spec)
|
| 509 |
+
|
| 510 |
+
is_varargs_deprecated = arg_spec.varargs in deprecated_arg_names
|
| 511 |
+
is_kwargs_deprecated = arg_spec.varkw in deprecated_arg_names
|
| 512 |
+
|
| 513 |
+
if (len(deprecated_positions) + is_varargs_deprecated + is_kwargs_deprecated
|
| 514 |
+
!= len(deprecated_arg_names_or_tuples)):
|
| 515 |
+
known_args = (
|
| 516 |
+
arg_spec.args + arg_spec.kwonlyargs +
|
| 517 |
+
[arg_spec.varargs, arg_spec.varkw])
|
| 518 |
+
missing_args = [
|
| 519 |
+
arg_name for arg_name in deprecated_arg_names
|
| 520 |
+
if arg_name not in known_args
|
| 521 |
+
]
|
| 522 |
+
raise ValueError('The following deprecated arguments are not present '
|
| 523 |
+
f'in the function signature: {missing_args}. '
|
| 524 |
+
'Expected arguments from the following list: '
|
| 525 |
+
f'{known_args}.')
|
| 526 |
+
|
| 527 |
+
def _same_value(a, b):
|
| 528 |
+
"""A comparison operation that works for multiple object types.
|
| 529 |
+
|
| 530 |
+
Returns True for two empty lists, two numeric values with the
|
| 531 |
+
same value, etc.
|
| 532 |
+
|
| 533 |
+
Returns False for (pd.DataFrame, None), and other pairs which
|
| 534 |
+
should not be considered equivalent.
|
| 535 |
+
|
| 536 |
+
Args:
|
| 537 |
+
a: value one of the comparison.
|
| 538 |
+
b: value two of the comparison.
|
| 539 |
+
|
| 540 |
+
Returns:
|
| 541 |
+
A boolean indicating whether the two inputs are the same value
|
| 542 |
+
for the purposes of deprecation.
|
| 543 |
+
"""
|
| 544 |
+
if a is b:
|
| 545 |
+
return True
|
| 546 |
+
try:
|
| 547 |
+
equality = a == b
|
| 548 |
+
if isinstance(equality, bool):
|
| 549 |
+
return equality
|
| 550 |
+
except TypeError:
|
| 551 |
+
return False
|
| 552 |
+
return False
|
| 553 |
+
|
| 554 |
+
@functools.wraps(func)
|
| 555 |
+
def new_func(*args, **kwargs):
|
| 556 |
+
"""Deprecation wrapper."""
|
| 557 |
+
# TODO(apassos) figure out a way to have reasonable performance with
|
| 558 |
+
# deprecation warnings and eager mode.
|
| 559 |
+
if is_in_graph_mode.IS_IN_GRAPH_MODE() and _PRINT_DEPRECATION_WARNINGS:
|
| 560 |
+
invalid_args = []
|
| 561 |
+
named_args = tf_inspect.getcallargs(func, *args, **kwargs)
|
| 562 |
+
for arg_name, spec in iter(deprecated_positions.items()):
|
| 563 |
+
if (spec.position < len(args) and
|
| 564 |
+
not (spec.has_ok_value and
|
| 565 |
+
_same_value(named_args[arg_name], spec.ok_value))):
|
| 566 |
+
invalid_args.append(arg_name)
|
| 567 |
+
if is_varargs_deprecated and len(args) > len(arg_spec.args):
|
| 568 |
+
invalid_args.append(arg_spec.varargs)
|
| 569 |
+
if is_kwargs_deprecated and kwargs:
|
| 570 |
+
invalid_args.append(arg_spec.varkw)
|
| 571 |
+
for arg_name in deprecated_arg_names:
|
| 572 |
+
if (arg_name in kwargs and
|
| 573 |
+
not (deprecated_positions[arg_name].has_ok_value and
|
| 574 |
+
_same_value(named_args[arg_name],
|
| 575 |
+
deprecated_positions[arg_name].ok_value))):
|
| 576 |
+
invalid_args.append(arg_name)
|
| 577 |
+
for arg_name in invalid_args:
|
| 578 |
+
if (func, arg_name) not in _PRINTED_WARNING:
|
| 579 |
+
if warn_once:
|
| 580 |
+
_PRINTED_WARNING[(func, arg_name)] = True
|
| 581 |
+
_log_deprecation(
|
| 582 |
+
'From %s: calling %s (from %s) with %s is deprecated and will '
|
| 583 |
+
'be removed %s.\nInstructions for updating:\n%s',
|
| 584 |
+
_call_location(), decorator_utils.get_qualified_name(func),
|
| 585 |
+
func.__module__, arg_name,
|
| 586 |
+
'in a future version' if date is None else ('after %s' % date),
|
| 587 |
+
instructions)
|
| 588 |
+
return func(*args, **kwargs)
|
| 589 |
+
|
| 590 |
+
doc = _add_deprecated_arg_notice_to_docstring(
|
| 591 |
+
func.__doc__, date, instructions, sorted(deprecated_arg_names.keys()))
|
| 592 |
+
return tf_decorator.make_decorator(func, new_func, 'deprecated', doc)
|
| 593 |
+
|
| 594 |
+
return deprecated_wrapper
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def deprecated_arg_values(date,
|
| 598 |
+
instructions,
|
| 599 |
+
warn_once=True,
|
| 600 |
+
**deprecated_kwargs):
|
| 601 |
+
"""Decorator for marking specific function argument values as deprecated.
|
| 602 |
+
|
| 603 |
+
This decorator logs a deprecation warning whenever the decorated function is
|
| 604 |
+
called with the deprecated argument values. It has the following format:
|
| 605 |
+
|
| 606 |
+
Calling <function> (from <module>) with <arg>=<value> is deprecated and
|
| 607 |
+
will be removed after <date>. Instructions for updating:
|
| 608 |
+
<instructions>
|
| 609 |
+
|
| 610 |
+
If `date` is None, 'after <date>' is replaced with 'in a future version'.
|
| 611 |
+
<function> will include the class name if it is a method.
|
| 612 |
+
|
| 613 |
+
It also edits the docstring of the function: ' (deprecated arguments)' is
|
| 614 |
+
appended to the first line of the docstring and a deprecation notice is
|
| 615 |
+
prepended to the rest of the docstring.
|
| 616 |
+
|
| 617 |
+
Args:
|
| 618 |
+
date: String or None. The date the function is scheduled to be removed. Must
|
| 619 |
+
be ISO 8601 (YYYY-MM-DD), or None
|
| 620 |
+
instructions: String. Instructions on how to update code using the
|
| 621 |
+
deprecated function.
|
| 622 |
+
warn_once: If `True`, warn only the first time this function is called with
|
| 623 |
+
deprecated argument values. Otherwise, every call (with a deprecated
|
| 624 |
+
argument value) will log a warning.
|
| 625 |
+
**deprecated_kwargs: The deprecated argument values.
|
| 626 |
+
|
| 627 |
+
Returns:
|
| 628 |
+
Decorated function or method.
|
| 629 |
+
|
| 630 |
+
Raises:
|
| 631 |
+
ValueError: If date is not None or in ISO 8601 format, or instructions are
|
| 632 |
+
empty.
|
| 633 |
+
"""
|
| 634 |
+
_validate_deprecation_args(date, instructions)
|
| 635 |
+
if not deprecated_kwargs:
|
| 636 |
+
raise ValueError('Specify which argument values are deprecated.')
|
| 637 |
+
|
| 638 |
+
def deprecated_wrapper(func):
|
| 639 |
+
"""Deprecation decorator."""
|
| 640 |
+
decorator_utils.validate_callable(func, 'deprecated_arg_values')
|
| 641 |
+
|
| 642 |
+
@functools.wraps(func)
|
| 643 |
+
def new_func(*args, **kwargs):
|
| 644 |
+
"""Deprecation wrapper."""
|
| 645 |
+
if _PRINT_DEPRECATION_WARNINGS:
|
| 646 |
+
named_args = tf_inspect.getcallargs(func, *args, **kwargs)
|
| 647 |
+
for arg_name, arg_value in deprecated_kwargs.items():
|
| 648 |
+
if arg_name in named_args and _safe_eq(named_args[arg_name],
|
| 649 |
+
arg_value):
|
| 650 |
+
if (func, arg_name) not in _PRINTED_WARNING:
|
| 651 |
+
if warn_once:
|
| 652 |
+
_PRINTED_WARNING[(func, arg_name)] = True
|
| 653 |
+
_log_deprecation(
|
| 654 |
+
'From %s: calling %s (from %s) with %s=%s is deprecated and '
|
| 655 |
+
'will be removed %s.\nInstructions for updating:\n%s',
|
| 656 |
+
_call_location(), decorator_utils.get_qualified_name(func),
|
| 657 |
+
func.__module__, arg_name, arg_value,
|
| 658 |
+
'in a future version' if date is None else
|
| 659 |
+
('after %s' % date), instructions)
|
| 660 |
+
return func(*args, **kwargs)
|
| 661 |
+
|
| 662 |
+
doc = _add_deprecated_arg_value_notice_to_docstring(func.__doc__, date,
|
| 663 |
+
instructions,
|
| 664 |
+
deprecated_kwargs)
|
| 665 |
+
return tf_decorator.make_decorator(func, new_func, 'deprecated', doc)
|
| 666 |
+
|
| 667 |
+
return deprecated_wrapper
|
| 668 |
+
|
| 669 |
+
|
| 670 |
+
def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
|
| 671 |
+
"""Looks up deprecated argument name and ensures both are not used.
|
| 672 |
+
|
| 673 |
+
Args:
|
| 674 |
+
new_name: new name of argument
|
| 675 |
+
new_value: value of new argument (or None if not used)
|
| 676 |
+
old_name: old name of argument
|
| 677 |
+
old_value: value of old argument (or None if not used)
|
| 678 |
+
|
| 679 |
+
Returns:
|
| 680 |
+
The effective argument that should be used.
|
| 681 |
+
Raises:
|
| 682 |
+
ValueError: if new_value and old_value are both non-null
|
| 683 |
+
"""
|
| 684 |
+
if old_value is not None:
|
| 685 |
+
if new_value is not None:
|
| 686 |
+
raise ValueError(f"Cannot specify both '{old_name}' and '{new_name}'.")
|
| 687 |
+
return old_value
|
| 688 |
+
return new_value
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def rewrite_argument_docstring(old_doc, old_argument, new_argument):
|
| 692 |
+
return old_doc.replace('`%s`' % old_argument,
|
| 693 |
+
'`%s`' % new_argument).replace('%s:' % old_argument,
|
| 694 |
+
'%s:' % new_argument)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
@tf_contextlib.contextmanager
|
| 698 |
+
def silence():
|
| 699 |
+
"""Temporarily silence deprecation warnings."""
|
| 700 |
+
global _PRINT_DEPRECATION_WARNINGS
|
| 701 |
+
print_deprecation_warnings = _PRINT_DEPRECATION_WARNINGS
|
| 702 |
+
_PRINT_DEPRECATION_WARNINGS = False
|
| 703 |
+
yield
|
| 704 |
+
_PRINT_DEPRECATION_WARNINGS = print_deprecation_warnings
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def deprecate_moved_module(deprecated_name, new_module, deletion_version):
|
| 708 |
+
"""Logs a warning when a module that has been moved to a new location is used.
|
| 709 |
+
|
| 710 |
+
Copy the following code into the old module:
|
| 711 |
+
|
| 712 |
+
```
|
| 713 |
+
import deprecation
|
| 714 |
+
import new_module
|
| 715 |
+
|
| 716 |
+
__getattr__ = deprecation.deprecate_moved_module(
|
| 717 |
+
__name__, new_module, "2.9") # adjust version number.
|
| 718 |
+
```
|
| 719 |
+
|
| 720 |
+
Args:
|
| 721 |
+
deprecated_name: Name of old module.
|
| 722 |
+
new_module: Module to replace the old module.
|
| 723 |
+
deletion_version: Version of TensorFlow in which the old module will be
|
| 724 |
+
removed.
|
| 725 |
+
|
| 726 |
+
Returns:
|
| 727 |
+
A function that logs a warning and returns the symbol from the new module.
|
| 728 |
+
Set this function as the module's `__getattr__`.
|
| 729 |
+
"""
|
| 730 |
+
|
| 731 |
+
def getter(name):
|
| 732 |
+
if getter not in _PRINTED_WARNING and _PRINT_DEPRECATION_WARNINGS:
|
| 733 |
+
_PRINTED_WARNING[getter] = True
|
| 734 |
+
_log_deprecation(
|
| 735 |
+
'Please fix your imports. Module %s has been moved to %s. The old '
|
| 736 |
+
'module will be deleted in version %s.', deprecated_name,
|
| 737 |
+
new_module.__name__, deletion_version)
|
| 738 |
+
return getattr(new_module, name)
|
| 739 |
+
|
| 740 |
+
return getter
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
class HiddenTfApiAttribute(property):
|
| 744 |
+
"""Hides a class attribute from the public API.
|
| 745 |
+
|
| 746 |
+
Attributes in public classes can be hidden from the API by having an '_' in
|
| 747 |
+
front of the name (e.g. ClassName._variables). This doesn't work when
|
| 748 |
+
attributes or methods are inherited from a parent class. To hide inherited
|
| 749 |
+
attributes, set their values to be `deprecation.hide_attribute_from_api`.
|
| 750 |
+
"""
|
| 751 |
+
|
| 752 |
+
def __init__(self, deprecation_message):
|
| 753 |
+
|
| 754 |
+
def raise_error(unused_self):
|
| 755 |
+
raise AttributeError(deprecation_message)
|
| 756 |
+
|
| 757 |
+
super(HiddenTfApiAttribute, self).__init__(raise_error)
|
| 758 |
+
|
| 759 |
+
|
| 760 |
+
hide_attribute_from_api = HiddenTfApiAttribute # pylint: disable=invalid-name
|
| 761 |
+
|
| 762 |
+
# TODO(kathywu): Remove once cl/246395236 is submitted.
|
| 763 |
+
HIDDEN_ATTRIBUTE = HiddenTfApiAttribute('This attribute has been deprecated.')
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/dispatch.py
ADDED
|
@@ -0,0 +1,1302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Type-based dispatch for TensorFlow's Python APIs.
|
| 16 |
+
|
| 17 |
+
"Python APIs" refers to Python functions that have been exported with
|
| 18 |
+
`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also
|
| 19 |
+
referred to as "ops".
|
| 20 |
+
|
| 21 |
+
There are currently two dispatch systems for TensorFlow:
|
| 22 |
+
|
| 23 |
+
* The "fallback dispatch" system calls an API's standard implementation first,
|
| 24 |
+
and only tries to perform dispatch if that standard implementation raises a
|
| 25 |
+
TypeError (or ValueError) exception.
|
| 26 |
+
|
| 27 |
+
* The "type-based dispatch" system checks the types of the parameters passed
|
| 28 |
+
to an API, and performs dispatch if those types match any signatures that
|
| 29 |
+
have been registered for dispatch.
|
| 30 |
+
|
| 31 |
+
The fallback dispatch system was the original dispatch system, but it was
|
| 32 |
+
somewhat brittle and had limitations, such as an inability to support dispatch
|
| 33 |
+
for some operations (like convert_to_tensor). We plan to remove the fallback
|
| 34 |
+
dispatch system in favor of the type-based dispatch system, once all users have
|
| 35 |
+
been switched over to use it.
|
| 36 |
+
|
| 37 |
+
### Fallback Dispatch
|
| 38 |
+
|
| 39 |
+
The fallback dispatch system is based on "operation dispatchers", which can be
|
| 40 |
+
used to override the behavior for TensorFlow ops when they are called with
|
| 41 |
+
otherwise unsupported argument types. In particular, when an operation is
|
| 42 |
+
called with arguments that would cause it to raise a TypeError, it falls back on
|
| 43 |
+
its registered operation dispatchers. If any registered dispatchers can handle
|
| 44 |
+
the arguments, then its result is returned. Otherwise, the original TypeError is
|
| 45 |
+
raised.
|
| 46 |
+
|
| 47 |
+
### Type-based Dispatch
|
| 48 |
+
|
| 49 |
+
The main interface for the type-based dispatch system is the `dispatch_for_api`
|
| 50 |
+
decorator, which overrides the default implementation for a TensorFlow API.
|
| 51 |
+
The decorated function (known as the "dispatch target") will override the
|
| 52 |
+
default implementation for the API when the API is called with parameters that
|
| 53 |
+
match a specified type signature.
|
| 54 |
+
|
| 55 |
+
### Dispatch Support
|
| 56 |
+
|
| 57 |
+
By default, dispatch support is added to the generated op wrappers for any
|
| 58 |
+
visible ops by default. APIs/ops that are implemented in Python can opt in to
|
| 59 |
+
dispatch support using the `add_dispatch_support` decorator.
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
import collections
|
| 63 |
+
import itertools
|
| 64 |
+
import typing # pylint: disable=unused-import (used in doctests)
|
| 65 |
+
|
| 66 |
+
from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher
|
| 67 |
+
from tensorflow.python.framework import ops
|
| 68 |
+
from tensorflow.python.util import tf_decorator
|
| 69 |
+
from tensorflow.python.util import tf_export as tf_export_lib
|
| 70 |
+
from tensorflow.python.util import tf_inspect
|
| 71 |
+
from tensorflow.python.util import traceback_utils
|
| 72 |
+
from tensorflow.python.util import type_annotations
|
| 73 |
+
from tensorflow.python.util.tf_export import tf_export
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Private function attributes used to store dispatchers on TensorFlow APIs.
|
| 77 |
+
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
|
| 78 |
+
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"
|
| 79 |
+
|
| 80 |
+
# OpDispatchers which should be used for all operations.
|
| 81 |
+
_GLOBAL_DISPATCHERS = []
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
################################################################################
|
| 85 |
+
# Fallback Dispatch
|
| 86 |
+
################################################################################
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
|
| 90 |
+
class OpDispatcher(object):
|
| 91 |
+
"""Abstract base class for TensorFlow operator dispatchers.
|
| 92 |
+
|
| 93 |
+
Each operation dispatcher acts as an override handler for a single
|
| 94 |
+
TensorFlow operation, and its results are used when the handler indicates
|
| 95 |
+
that it can handle the operation's arguments (by returning any value other
|
| 96 |
+
than `OpDispatcher.NOT_SUPPORTED`).
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
# Sentinel value that can be returned to indicate that an operation
|
| 100 |
+
# dispatcher does not support a given set of arguments.
|
| 101 |
+
NOT_SUPPORTED = object()
|
| 102 |
+
|
| 103 |
+
def handle(self, args, kwargs): # pylint: disable=unused-argument
|
| 104 |
+
"""Handle this dispatcher's operation with the specified arguments.
|
| 105 |
+
|
| 106 |
+
If this operation dispatcher can handle the given arguments, then
|
| 107 |
+
return an appropriate value (or raise an appropriate exception).
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
args: The arguments to the operation.
|
| 111 |
+
kwargs: They keyword arguments to the operation.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
|
| 115 |
+
dispatcher can not handle the given arguments.
|
| 116 |
+
"""
|
| 117 |
+
return self.NOT_SUPPORTED
|
| 118 |
+
|
| 119 |
+
def register(self, op):
|
| 120 |
+
"""Register this dispatcher as a handler for `op`.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
op: Python function: the TensorFlow operation that should be handled. Must
|
| 124 |
+
have a dispatch list (which is added automatically for generated ops,
|
| 125 |
+
and can be added to Python ops using the `add_dispatch_support`
|
| 126 |
+
decorator).
|
| 127 |
+
"""
|
| 128 |
+
if not hasattr(op, FALLBACK_DISPATCH_ATTR):
|
| 129 |
+
raise AssertionError("Dispatching not enabled for %s" % op)
|
| 130 |
+
getattr(op, FALLBACK_DISPATCH_ATTR).append(self)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
|
| 134 |
+
class GlobalOpDispatcher(object):
|
| 135 |
+
"""Abstract base class for TensorFlow global operator dispatchers."""
|
| 136 |
+
|
| 137 |
+
NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED
|
| 138 |
+
|
| 139 |
+
def handle(self, op, args, kwargs):
|
| 140 |
+
"""Handle the specified operation with the specified arguments."""
|
| 141 |
+
|
| 142 |
+
def register(self):
|
| 143 |
+
"""Register this dispatcher as a handler for all ops."""
|
| 144 |
+
_GLOBAL_DISPATCHERS.append(self)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def dispatch(op, args, kwargs):
|
| 148 |
+
"""Returns the result from the first successful dispatcher for a given op.
|
| 149 |
+
|
| 150 |
+
Calls the `handle` method of each `OpDispatcher` that has been registered
|
| 151 |
+
to handle `op`, and returns the value from the first successful handler.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
op: Python function: the operation to dispatch for.
|
| 155 |
+
args: The arguments to the operation.
|
| 156 |
+
kwargs: They keyword arguments to the operation.
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
The result of the operation, or `NOT_SUPPORTED` if no registered
|
| 160 |
+
dispatcher can handle the given arguments.
|
| 161 |
+
"""
|
| 162 |
+
for dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
|
| 163 |
+
result = dispatcher.handle(args, kwargs)
|
| 164 |
+
if result is not OpDispatcher.NOT_SUPPORTED:
|
| 165 |
+
return result
|
| 166 |
+
for dispatcher in _GLOBAL_DISPATCHERS:
|
| 167 |
+
result = dispatcher.handle(op, args, kwargs)
|
| 168 |
+
if result is not OpDispatcher.NOT_SUPPORTED:
|
| 169 |
+
return result
|
| 170 |
+
return OpDispatcher.NOT_SUPPORTED
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class _TypeBasedDispatcher(OpDispatcher):
|
| 174 |
+
"""Dispatcher that handles op if any arguments have a specified type.
|
| 175 |
+
|
| 176 |
+
Checks the types of the arguments and keyword arguments (including elements
|
| 177 |
+
of lists or tuples), and if any argument values have the indicated type(s),
|
| 178 |
+
then delegates to an override function.
|
| 179 |
+
"""
|
| 180 |
+
|
| 181 |
+
def __init__(self, override_func, types):
|
| 182 |
+
self._types = types
|
| 183 |
+
self._override_func = override_func
|
| 184 |
+
|
| 185 |
+
def _handles(self, args, kwargs):
|
| 186 |
+
for arg in itertools.chain(args, kwargs.values()):
|
| 187 |
+
if (isinstance(arg, self._types) or
|
| 188 |
+
(isinstance(arg, (list, tuple)) and
|
| 189 |
+
any(isinstance(elt, self._types) for elt in arg))):
|
| 190 |
+
return True
|
| 191 |
+
return False
|
| 192 |
+
|
| 193 |
+
def handle(self, args, kwargs):
|
| 194 |
+
if self._handles(args, kwargs):
|
| 195 |
+
return self._override_func(*args, **kwargs)
|
| 196 |
+
else:
|
| 197 |
+
return self.NOT_SUPPORTED
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _remove_annotation(sig):
|
| 201 |
+
"""Removes annotation from a python Signature."""
|
| 202 |
+
parameters = [p.replace(annotation=p.empty) for p in sig.parameters.values()]
|
| 203 |
+
return sig.replace(parameters=parameters, return_annotation=sig.empty)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _get_required_param_names(sig):
|
| 207 |
+
"""Returns a list of required parameter names from a python Signature."""
|
| 208 |
+
params = []
|
| 209 |
+
for p in sig.parameters.values():
|
| 210 |
+
if p.kind == p.VAR_POSITIONAL:
|
| 211 |
+
continue
|
| 212 |
+
if p.kind == p.VAR_KEYWORD:
|
| 213 |
+
continue
|
| 214 |
+
if p.default is not p.empty:
|
| 215 |
+
continue
|
| 216 |
+
params.append(p.name)
|
| 217 |
+
return params
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_compatible_func(op, func):
|
| 221 |
+
"""Returns a compatible function.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
op: a callable with whose signature the returned function is compatible.
|
| 225 |
+
func: a callable which is called by the returned function.
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
a compatible function, which conducts the actions of `func` but can
|
| 229 |
+
be called like `op`, given that:
|
| 230 |
+
- the list of required arguments in `func` and `op` are the same.
|
| 231 |
+
- there is no override of the default arguments of `op` that are not
|
| 232 |
+
supported by `func`.
|
| 233 |
+
"""
|
| 234 |
+
op_signature = _remove_annotation(tf_inspect.signature(op))
|
| 235 |
+
func_signature = _remove_annotation(tf_inspect.signature(func))
|
| 236 |
+
|
| 237 |
+
# Identitical signatures, no need to apply compatibility fixes.
|
| 238 |
+
if op_signature == func_signature:
|
| 239 |
+
return func
|
| 240 |
+
|
| 241 |
+
# When calling func:
|
| 242 |
+
# - Positional args without default must be in the same order.
|
| 243 |
+
# - Ignore missing optional arguments from op
|
| 244 |
+
|
| 245 |
+
op_pos_names = _get_required_param_names(op_signature)
|
| 246 |
+
func_pos_names = _get_required_param_names(func_signature)
|
| 247 |
+
|
| 248 |
+
if op_pos_names != func_pos_names:
|
| 249 |
+
raise AssertionError(
|
| 250 |
+
"The decorated function's non-default arguments must be identical"
|
| 251 |
+
" to that of the overridden op."
|
| 252 |
+
f" func has {func_pos_names}. op has {op_pos_names}."
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
func_missing_params = {}
|
| 256 |
+
|
| 257 |
+
for name in set(op_signature.parameters.keys()) - set(
|
| 258 |
+
func_signature.parameters.keys()
|
| 259 |
+
):
|
| 260 |
+
p = op_signature.parameters[name]
|
| 261 |
+
if p.default is p.empty:
|
| 262 |
+
raise AssertionError(
|
| 263 |
+
"The decorated function's signature must implement all of the"
|
| 264 |
+
f" non-default arguments of the overridden op. Argument `{name}` is"
|
| 265 |
+
" unimplemented."
|
| 266 |
+
)
|
| 267 |
+
func_missing_params[name] = p
|
| 268 |
+
|
| 269 |
+
def compatible_func(*args, **kwargs):
|
| 270 |
+
bound = op_signature.bind(*args, **kwargs)
|
| 271 |
+
for name, param in func_missing_params.items():
|
| 272 |
+
if name not in bound.arguments:
|
| 273 |
+
continue
|
| 274 |
+
value = bound.arguments.pop(name)
|
| 275 |
+
if value is not param.default:
|
| 276 |
+
raise AssertionError(
|
| 277 |
+
f"Dispatched op is called with argument `{name}` set to a"
|
| 278 |
+
" non-default value, which is not supported by the decorated"
|
| 279 |
+
" function"
|
| 280 |
+
)
|
| 281 |
+
return func(*bound.args, **bound.kwargs)
|
| 282 |
+
|
| 283 |
+
return compatible_func
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# pylint: disable=g-doc-return-or-yield
|
| 287 |
+
def dispatch_for_types(op, *types):
|
| 288 |
+
"""Decorator to declare that a Python function overrides an op for a type.
|
| 289 |
+
|
| 290 |
+
The decorated function is used to override `op` if any of the arguments or
|
| 291 |
+
keyword arguments (including elements of lists or tuples) have one of the
|
| 292 |
+
specified types.
|
| 293 |
+
|
| 294 |
+
Example:
|
| 295 |
+
|
| 296 |
+
```python
|
| 297 |
+
@dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
|
| 298 |
+
def ragged_add(x, y, name=None): ...
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
op: Python function: the operation that should be overridden.
|
| 303 |
+
*types: The argument types for which this function should be used.
|
| 304 |
+
"""
|
| 305 |
+
|
| 306 |
+
def decorator(func):
|
| 307 |
+
|
| 308 |
+
_TypeBasedDispatcher(get_compatible_func(op, func), types).register(op)
|
| 309 |
+
return func
|
| 310 |
+
|
| 311 |
+
return decorator
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# pylint: enable=g-doc-return-or-yield
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def add_fallback_dispatch_list(target):
|
| 318 |
+
"""Decorator that adds a dispatch_list attribute to an op."""
|
| 319 |
+
if hasattr(target, FALLBACK_DISPATCH_ATTR):
|
| 320 |
+
raise AssertionError("%s already has a dispatch list" % target)
|
| 321 |
+
setattr(target, FALLBACK_DISPATCH_ATTR, [])
|
| 322 |
+
return target
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# Alias for backwards-compatibility.
|
| 326 |
+
add_dispatch_list = add_fallback_dispatch_list
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
################################################################################
|
| 330 |
+
# Type-based Dispatch
|
| 331 |
+
################################################################################
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
@tf_export("experimental.dispatch_for_api")
|
| 335 |
+
def dispatch_for_api(api, *signatures):
|
| 336 |
+
"""Decorator that overrides the default implementation for a TensorFlow API.
|
| 337 |
+
|
| 338 |
+
The decorated function (known as the "dispatch target") will override the
|
| 339 |
+
default implementation for the API when the API is called with parameters that
|
| 340 |
+
match a specified type signature. Signatures are specified using dictionaries
|
| 341 |
+
that map parameter names to type annotations. E.g., in the following example,
|
| 342 |
+
`masked_add` will be called for `tf.add` if both `x` and `y` are
|
| 343 |
+
`MaskedTensor`s:
|
| 344 |
+
|
| 345 |
+
>>> class MaskedTensor(tf.experimental.ExtensionType):
|
| 346 |
+
... values: tf.Tensor
|
| 347 |
+
... mask: tf.Tensor
|
| 348 |
+
|
| 349 |
+
>>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
|
| 350 |
+
... def masked_add(x, y, name=None):
|
| 351 |
+
... return MaskedTensor(x.values + y.values, x.mask & y.mask)
|
| 352 |
+
|
| 353 |
+
>>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
|
| 354 |
+
>>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
|
| 355 |
+
values=[11 12], mask=[ True False]
|
| 356 |
+
|
| 357 |
+
If multiple type signatures are specified, then the dispatch target will be
|
| 358 |
+
called if any of the signatures match. For example, the following code
|
| 359 |
+
registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
|
| 360 |
+
a `MaskedTensor`.
|
| 361 |
+
|
| 362 |
+
>>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
|
| 363 |
+
... def masked_add(x, y):
|
| 364 |
+
... x_values = x.values if isinstance(x, MaskedTensor) else x
|
| 365 |
+
... x_mask = x.mask if isinstance(x, MaskedTensor) else True
|
| 366 |
+
... y_values = y.values if isinstance(y, MaskedTensor) else y
|
| 367 |
+
... y_mask = y.mask if isinstance(y, MaskedTensor) else True
|
| 368 |
+
... return MaskedTensor(x_values + y_values, x_mask & y_mask)
|
| 369 |
+
|
| 370 |
+
The type annotations in type signatures may be type objects (e.g.,
|
| 371 |
+
`MaskedTensor`), `typing.List` values, or `typing.Union` values. For
|
| 372 |
+
example, the following will register `masked_concat` to be called if `values`
|
| 373 |
+
is a list of `MaskedTensor` values:
|
| 374 |
+
|
| 375 |
+
>>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
|
| 376 |
+
... def masked_concat(values, axis):
|
| 377 |
+
... return MaskedTensor(tf.concat([v.values for v in values], axis),
|
| 378 |
+
... tf.concat([v.mask for v in values], axis))
|
| 379 |
+
|
| 380 |
+
Each type signature must contain at least one subclass of `tf.CompositeTensor`
|
| 381 |
+
(which includes subclasses of `tf.ExtensionType`), and dispatch will only be
|
| 382 |
+
triggered if at least one type-annotated parameter contains a
|
| 383 |
+
`CompositeTensor` value. This rule avoids invoking dispatch in degenerate
|
| 384 |
+
cases, such as the following examples:
|
| 385 |
+
|
| 386 |
+
* `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
|
| 387 |
+
dispatch to the decorated dispatch target when the user calls
|
| 388 |
+
`tf.concat([])`.
|
| 389 |
+
|
| 390 |
+
* `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
|
| 391 |
+
Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
|
| 392 |
+
target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.
|
| 393 |
+
|
| 394 |
+
The dispatch target's signature must match the signature of the API that is
|
| 395 |
+
being overridden. In particular, parameters must have the same names, and
|
| 396 |
+
must occur in the same order. The dispatch target may optionally elide the
|
| 397 |
+
"name" parameter, in which case it will be wrapped with a call to
|
| 398 |
+
`tf.name_scope` when appropraite.
|
| 399 |
+
|
| 400 |
+
Args:
|
| 401 |
+
api: The TensorFlow API to override.
|
| 402 |
+
*signatures: Dictionaries mapping parameter names or indices to type
|
| 403 |
+
annotations, specifying when the dispatch target should be called. In
|
| 404 |
+
particular, the dispatch target will be called if any signature matches;
|
| 405 |
+
and a signature matches if all of the specified parameters have types that
|
| 406 |
+
match with the indicated type annotations. If no signatures are
|
| 407 |
+
specified, then a signature will be read from the dispatch target
|
| 408 |
+
function's type annotations.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
A decorator that overrides the default implementation for `api`.
|
| 412 |
+
|
| 413 |
+
#### Registered APIs
|
| 414 |
+
|
| 415 |
+
The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:
|
| 416 |
+
|
| 417 |
+
<<API_LIST>>
|
| 418 |
+
"""
|
| 419 |
+
dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
|
| 420 |
+
if dispatcher is None:
|
| 421 |
+
raise ValueError(f"{api} does not support dispatch.")
|
| 422 |
+
|
| 423 |
+
api_signature = tf_inspect.signature(api)
|
| 424 |
+
signature_checkers = [
|
| 425 |
+
_make_signature_checker(api_signature, signature)
|
| 426 |
+
for signature in signatures
|
| 427 |
+
]
|
| 428 |
+
|
| 429 |
+
def decorator(dispatch_target):
|
| 430 |
+
"""Decorator that registers the given dispatch target."""
|
| 431 |
+
if not callable(dispatch_target):
|
| 432 |
+
raise TypeError("Expected dispatch_target to be callable; "
|
| 433 |
+
f"got {dispatch_target!r}")
|
| 434 |
+
dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
|
| 435 |
+
_check_signature(api_signature, dispatch_target)
|
| 436 |
+
|
| 437 |
+
for signature_checker in signature_checkers:
|
| 438 |
+
dispatcher.Register(signature_checker, dispatch_target)
|
| 439 |
+
_TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)
|
| 440 |
+
|
| 441 |
+
if not signature_checkers:
|
| 442 |
+
signature = _signature_from_annotations(dispatch_target)
|
| 443 |
+
checker = _make_signature_checker(api_signature, signature)
|
| 444 |
+
dispatcher.Register(checker, dispatch_target)
|
| 445 |
+
_TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)
|
| 446 |
+
|
| 447 |
+
return dispatch_target
|
| 448 |
+
|
| 449 |
+
return decorator
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
|
| 453 |
+
# which can be used for documentation generation and for improved error messages
|
| 454 |
+
# when APIs are called with unsupported types.
|
| 455 |
+
_TYPE_BASED_DISPATCH_SIGNATURES = {}
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def apis_with_type_based_dispatch():
|
| 459 |
+
"""Returns a list of TensorFlow APIs that support type-based dispatch."""
|
| 460 |
+
return sorted(
|
| 461 |
+
_TYPE_BASED_DISPATCH_SIGNATURES,
|
| 462 |
+
key=lambda api: f"{api.__module__}.{api.__name__}")
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def type_based_dispatch_signatures_for(cls):
|
| 466 |
+
"""Returns dispatch signatures that have been registered for a given class.
|
| 467 |
+
|
| 468 |
+
This function is intended for documentation-generation purposes.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
cls: The class to search for. Type signatures are searched recursively, so
|
| 472 |
+
e.g., if `cls=RaggedTensor`, then information will be returned for all
|
| 473 |
+
dispatch targets that have `RaggedTensor` anywhere in their type
|
| 474 |
+
annotations (including nested in `typing.Union` or `typing.List`.)
|
| 475 |
+
|
| 476 |
+
Returns:
|
| 477 |
+
A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
|
| 478 |
+
function; and `signatures` is a list of dispatch signatures for `api`
|
| 479 |
+
that include `cls`. (Each signature is a dict mapping argument names to
|
| 480 |
+
type annotations; see `dispatch_for_api` for more info.)
|
| 481 |
+
"""
|
| 482 |
+
|
| 483 |
+
def contains_cls(x):
|
| 484 |
+
"""Returns true if `x` contains `cls`."""
|
| 485 |
+
if isinstance(x, dict):
|
| 486 |
+
return any(contains_cls(v) for v in x.values())
|
| 487 |
+
elif x is cls:
|
| 488 |
+
return True
|
| 489 |
+
elif (type_annotations.is_generic_list(x) or
|
| 490 |
+
type_annotations.is_generic_union(x)):
|
| 491 |
+
type_args = type_annotations.get_generic_type_args(x)
|
| 492 |
+
return any(contains_cls(arg) for arg in type_args)
|
| 493 |
+
else:
|
| 494 |
+
return False
|
| 495 |
+
|
| 496 |
+
result = {}
|
| 497 |
+
for api, api_signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
|
| 498 |
+
for _, signatures in api_signatures.items():
|
| 499 |
+
filtered = list(filter(contains_cls, signatures))
|
| 500 |
+
if filtered:
|
| 501 |
+
result.setdefault(api, []).extend(filtered)
|
| 502 |
+
return result
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
# TODO(edloper): Consider using a mechanism like this to automatically add
|
| 506 |
+
# the `name` argument to all TensorFlow APIs that are implemented in Python
|
| 507 |
+
# (so each Python function doesn't need to do it manually).
|
| 508 |
+
def _add_name_scope_wrapper(func, api_signature):
|
| 509 |
+
"""Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.
|
| 510 |
+
|
| 511 |
+
If `func` already expects a "name" arg, or if `api_signature` does not
|
| 512 |
+
expect a "name" arg, then returns `func` as-is.
|
| 513 |
+
|
| 514 |
+
Args:
|
| 515 |
+
func: The function to wrap. Signature must match `api_signature` (except
|
| 516 |
+
the "name" parameter may be missing.
|
| 517 |
+
api_signature: The signature of the original API (used to find the index for
|
| 518 |
+
the "name" parameter).
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
The wrapped function (or the original function if no wrapping is needed).
|
| 522 |
+
"""
|
| 523 |
+
if "name" not in api_signature.parameters:
|
| 524 |
+
return func # no wrapping needed (API has no name parameter).
|
| 525 |
+
|
| 526 |
+
func_signature = tf_inspect.signature(func)
|
| 527 |
+
func_argspec = tf_inspect.getargspec(func)
|
| 528 |
+
if "name" in func_signature.parameters or func_argspec.keywords is not None:
|
| 529 |
+
return func # No wrapping needed (already has name parameter).
|
| 530 |
+
|
| 531 |
+
name_index = list(api_signature.parameters).index("name")
|
| 532 |
+
|
| 533 |
+
def wrapped_func(*args, **kwargs):
|
| 534 |
+
if name_index < len(args):
|
| 535 |
+
name = args[name_index]
|
| 536 |
+
args = args[:name_index] + args[name_index + 1:]
|
| 537 |
+
else:
|
| 538 |
+
name = kwargs.pop("name", None)
|
| 539 |
+
if name is None:
|
| 540 |
+
return func(*args, **kwargs)
|
| 541 |
+
else:
|
| 542 |
+
with ops.name_scope(name):
|
| 543 |
+
return func(*args, **kwargs)
|
| 544 |
+
|
| 545 |
+
wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
|
| 546 |
+
wrapped_func.__signature__ = func_signature.replace(
|
| 547 |
+
parameters=(list(func_signature.parameters.values()) +
|
| 548 |
+
[api_signature.parameters["name"]]))
|
| 549 |
+
del wrapped_func._tf_decorator
|
| 550 |
+
return wrapped_func
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@tf_export("experimental.unregister_dispatch_for")
|
| 554 |
+
def unregister_dispatch_for(dispatch_target):
|
| 555 |
+
"""Unregisters a function that was registered with `@dispatch_for_*`.
|
| 556 |
+
|
| 557 |
+
This is primarily intended for testing purposes.
|
| 558 |
+
|
| 559 |
+
Example:
|
| 560 |
+
|
| 561 |
+
>>> # Define a type and register a dispatcher to override `tf.abs`:
|
| 562 |
+
>>> class MyTensor(tf.experimental.ExtensionType):
|
| 563 |
+
... value: tf.Tensor
|
| 564 |
+
>>> @tf.experimental.dispatch_for_api(tf.abs)
|
| 565 |
+
... def my_abs(x: MyTensor):
|
| 566 |
+
... return MyTensor(tf.abs(x.value))
|
| 567 |
+
>>> tf.abs(MyTensor(5))
|
| 568 |
+
MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)
|
| 569 |
+
|
| 570 |
+
>>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
|
| 571 |
+
>>> unregister_dispatch_for(my_abs)
|
| 572 |
+
>>> tf.abs(MyTensor(5))
|
| 573 |
+
Traceback (most recent call last):
|
| 574 |
+
...
|
| 575 |
+
ValueError: Attempt to convert a value ... to a Tensor.
|
| 576 |
+
|
| 577 |
+
Args:
|
| 578 |
+
dispatch_target: The function to unregister.
|
| 579 |
+
|
| 580 |
+
Raises:
|
| 581 |
+
ValueError: If `dispatch_target` was not registered using `@dispatch_for`,
|
| 582 |
+
`@dispatch_for_unary_elementwise_apis`, or
|
| 583 |
+
`@dispatch_for_binary_elementwise_apis`.
|
| 584 |
+
"""
|
| 585 |
+
found = False
|
| 586 |
+
|
| 587 |
+
# Check if dispatch_target registered by `@dispatch_for_api`
|
| 588 |
+
for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
|
| 589 |
+
if dispatch_target in signatures:
|
| 590 |
+
dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR)
|
| 591 |
+
dispatcher.Unregister(dispatch_target)
|
| 592 |
+
del signatures[dispatch_target]
|
| 593 |
+
found = True
|
| 594 |
+
|
| 595 |
+
# Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis`
|
| 596 |
+
elementwise_keys_to_delete = [
|
| 597 |
+
key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items()
|
| 598 |
+
if handler is dispatch_target
|
| 599 |
+
]
|
| 600 |
+
for key in set(elementwise_keys_to_delete):
|
| 601 |
+
for _, target in _ELEMENTWISE_API_TARGETS[key]:
|
| 602 |
+
unregister_dispatch_for(target)
|
| 603 |
+
del _ELEMENTWISE_API_HANDLERS[key]
|
| 604 |
+
del _ELEMENTWISE_API_TARGETS[key]
|
| 605 |
+
found = True
|
| 606 |
+
|
| 607 |
+
if not found:
|
| 608 |
+
raise ValueError(f"Function {dispatch_target} was not registered using "
|
| 609 |
+
"a `@dispatch_for_*` decorator.")
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def register_dispatchable_type(cls):
|
| 613 |
+
"""Class decorator that registers a type for use with type-based dispatch.
|
| 614 |
+
|
| 615 |
+
Should *not* be used with subclasses of `CompositeTensor` or `ExtensionType`
|
| 616 |
+
(which are automatically registered).
|
| 617 |
+
|
| 618 |
+
Note: this function is intended to support internal legacy use cases (such
|
| 619 |
+
as RaggedTensorValue), and will probably not be exposed as a public API.
|
| 620 |
+
|
| 621 |
+
Args:
|
| 622 |
+
cls: The class to register.
|
| 623 |
+
|
| 624 |
+
Returns:
|
| 625 |
+
`cls`.
|
| 626 |
+
"""
|
| 627 |
+
_api_dispatcher.register_dispatchable_type(cls)
|
| 628 |
+
return cls
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def add_type_based_api_dispatcher(target):
|
| 632 |
+
"""Adds a PythonAPIDispatcher to the given TensorFlow API function."""
|
| 633 |
+
if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
|
| 634 |
+
raise ValueError(f"{target} already has a type-based API dispatcher.")
|
| 635 |
+
|
| 636 |
+
_, unwrapped = tf_decorator.unwrap(target)
|
| 637 |
+
target_argspec = tf_inspect.getargspec(unwrapped)
|
| 638 |
+
if target_argspec.varargs or target_argspec.keywords:
|
| 639 |
+
# @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
|
| 640 |
+
# and keywords. Examples of APIs that take varargs and kwargs: meshgrid,
|
| 641 |
+
# einsum, map_values, map_flat_values.
|
| 642 |
+
return target
|
| 643 |
+
|
| 644 |
+
setattr(
|
| 645 |
+
target, TYPE_BASED_DISPATCH_ATTR,
|
| 646 |
+
_api_dispatcher.PythonAPIDispatcher(unwrapped.__name__,
|
| 647 |
+
target_argspec.args,
|
| 648 |
+
target_argspec.defaults))
|
| 649 |
+
_TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
|
| 650 |
+
return target
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def _check_signature(api_signature, func):
|
| 654 |
+
"""Checks that a dispatch target's signature is compatible with an API.
|
| 655 |
+
|
| 656 |
+
Args:
|
| 657 |
+
api_signature: The signature of the TensorFlow API.
|
| 658 |
+
func: The dispatch target.
|
| 659 |
+
|
| 660 |
+
Raises:
|
| 661 |
+
ValueError: if the signatures are incompatible. Two signatures are
|
| 662 |
+
considered compatible if they have the same number of parameters, and all
|
| 663 |
+
corresponding parameters have the same `name` and `kind`. (Parameters
|
| 664 |
+
are not required to have the same default value or the same annotation.)
|
| 665 |
+
"""
|
| 666 |
+
# Special case: if func_signature is (*args, **kwargs), then assume it's ok.
|
| 667 |
+
func_argspec = tf_inspect.getargspec(func)
|
| 668 |
+
if (func_argspec.varargs is not None and func_argspec.keywords is not None
|
| 669 |
+
and not func_argspec.args):
|
| 670 |
+
return
|
| 671 |
+
|
| 672 |
+
func_signature = tf_inspect.signature(func)
|
| 673 |
+
ok = len(api_signature.parameters) == len(func_signature.parameters)
|
| 674 |
+
if ok:
|
| 675 |
+
for param_1, param_2 in zip(api_signature.parameters.values(),
|
| 676 |
+
func_signature.parameters.values()):
|
| 677 |
+
if (param_1.name != param_2.name) or (param_1.kind != param_2.kind):
|
| 678 |
+
ok = False
|
| 679 |
+
if not ok:
|
| 680 |
+
raise ValueError(f"Dispatch function's signature {func_signature} does "
|
| 681 |
+
f"not match API's signature {api_signature}.")
|
| 682 |
+
|
| 683 |
+
|
| 684 |
+
def _make_signature_checker(api_signature, signature):
|
| 685 |
+
"""Builds a PySignatureChecker for the given type signature.
|
| 686 |
+
|
| 687 |
+
Args:
|
| 688 |
+
api_signature: The `inspect.Signature` of the API whose signature is
|
| 689 |
+
being checked.
|
| 690 |
+
signature: Dictionary mapping parameter names to type annotations.
|
| 691 |
+
|
| 692 |
+
Returns:
|
| 693 |
+
A `PySignatureChecker`.
|
| 694 |
+
"""
|
| 695 |
+
if not (isinstance(signature, dict) and
|
| 696 |
+
all(isinstance(k, (str, int)) for k in signature)):
|
| 697 |
+
raise TypeError("signatures must be dictionaries mapping parameter names "
|
| 698 |
+
"to type annotations.")
|
| 699 |
+
checkers = []
|
| 700 |
+
|
| 701 |
+
param_names = list(api_signature.parameters)
|
| 702 |
+
for param_name, param_type in signature.items():
|
| 703 |
+
# Convert positional parameters to named parameters.
|
| 704 |
+
if (isinstance(param_name, int) and
|
| 705 |
+
param_name < len(api_signature.parameters)):
|
| 706 |
+
param_name = list(api_signature.parameters.values())[param_name].name
|
| 707 |
+
|
| 708 |
+
# Check that the parameter exists, and has an appropriate kind.
|
| 709 |
+
param = api_signature.parameters.get(param_name, None)
|
| 710 |
+
if param is None:
|
| 711 |
+
raise ValueError("signature includes annotation for unknown "
|
| 712 |
+
f"parameter {param_name!r}.")
|
| 713 |
+
if param.kind not in (tf_inspect.Parameter.POSITIONAL_ONLY,
|
| 714 |
+
tf_inspect.Parameter.POSITIONAL_OR_KEYWORD):
|
| 715 |
+
raise ValueError("Dispatch currently only supports type annotations "
|
| 716 |
+
"for positional parameters; can't handle annotation "
|
| 717 |
+
f"for {param.kind!r} parameter {param_name}.")
|
| 718 |
+
|
| 719 |
+
checker = make_type_checker(param_type)
|
| 720 |
+
index = param_names.index(param_name)
|
| 721 |
+
checkers.append((index, checker))
|
| 722 |
+
|
| 723 |
+
return _api_dispatcher.PySignatureChecker(checkers)
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
# Cache for InstanceTypeChecker objects (we only want to create one
|
| 727 |
+
# InstanceTypeChecker for each type, since each one uses an internal cache
|
| 728 |
+
# to avoid repeated calls back into Python's isinstance).
|
| 729 |
+
_is_instance_checker_cache = {}
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def make_type_checker(annotation):
|
| 733 |
+
"""Builds a PyTypeChecker for the given type annotation."""
|
| 734 |
+
if type_annotations.is_generic_union(annotation):
|
| 735 |
+
type_args = type_annotations.get_generic_type_args(annotation)
|
| 736 |
+
|
| 737 |
+
# If the union contains two or more simple types, then use a single
|
| 738 |
+
# InstanceChecker to check them.
|
| 739 |
+
simple_types = [t for t in type_args if isinstance(t, type)]
|
| 740 |
+
simple_types = tuple(sorted(simple_types, key=id))
|
| 741 |
+
if len(simple_types) > 1:
|
| 742 |
+
if simple_types not in _is_instance_checker_cache:
|
| 743 |
+
checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
|
| 744 |
+
_is_instance_checker_cache[simple_types] = checker
|
| 745 |
+
options = ([_is_instance_checker_cache[simple_types]] +
|
| 746 |
+
[make_type_checker(t) for t in type_args
|
| 747 |
+
if not isinstance(t, type)])
|
| 748 |
+
return _api_dispatcher.MakeUnionChecker(options)
|
| 749 |
+
|
| 750 |
+
options = [make_type_checker(t) for t in type_args]
|
| 751 |
+
return _api_dispatcher.MakeUnionChecker(options)
|
| 752 |
+
|
| 753 |
+
elif type_annotations.is_generic_list(annotation):
|
| 754 |
+
type_args = type_annotations.get_generic_type_args(annotation)
|
| 755 |
+
if len(type_args) != 1:
|
| 756 |
+
raise AssertionError("Expected List[...] to have a single type parameter")
|
| 757 |
+
elt_type = make_type_checker(type_args[0])
|
| 758 |
+
return _api_dispatcher.MakeListChecker(elt_type)
|
| 759 |
+
|
| 760 |
+
elif isinstance(annotation, type):
|
| 761 |
+
if annotation not in _is_instance_checker_cache:
|
| 762 |
+
checker = _api_dispatcher.MakeInstanceChecker(annotation)
|
| 763 |
+
_is_instance_checker_cache[annotation] = checker
|
| 764 |
+
return _is_instance_checker_cache[annotation]
|
| 765 |
+
|
| 766 |
+
elif annotation is None:
|
| 767 |
+
return make_type_checker(type(None))
|
| 768 |
+
|
| 769 |
+
else:
|
| 770 |
+
raise ValueError(f"Type annotation {annotation} is not currently supported"
|
| 771 |
+
" by dispatch. Supported annotations: type objects, "
|
| 772 |
+
" List[...], and Union[...]")
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def _signature_from_annotations(func):
|
| 776 |
+
"""Builds a dict mapping from parameter names to type annotations."""
|
| 777 |
+
func_signature = tf_inspect.signature(func)
|
| 778 |
+
|
| 779 |
+
signature = dict([(name, param.annotation)
|
| 780 |
+
for (name, param) in func_signature.parameters.items()
|
| 781 |
+
if param.annotation != tf_inspect.Parameter.empty])
|
| 782 |
+
if not signature:
|
| 783 |
+
raise ValueError("The dispatch_for_api decorator must be called with at "
|
| 784 |
+
"least one signature, or applied to a function that "
|
| 785 |
+
"has type annotations on its parameters.")
|
| 786 |
+
return signature
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
# Registries for elementwise APIs and API handlers.
|
| 790 |
+
#
|
| 791 |
+
# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered
|
| 792 |
+
# as elementwise operations using the `register_*_elementwise_api`
|
| 793 |
+
# decorators.
|
| 794 |
+
#
|
| 795 |
+
# _ELEMENTWISE_API_HANDLERS: Dicts mapping from argument type(s) to API
|
| 796 |
+
# handlers that have been registered with the `dispatch_for_*_elementwise_apis`
|
| 797 |
+
# decorators.
|
| 798 |
+
#
|
| 799 |
+
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
|
| 800 |
+
# `(api, dispatch_target)` pairs. Used to impelement
|
| 801 |
+
# `unregister_elementwise_api_handler`.
|
| 802 |
+
_UNARY_ELEMENTWISE_APIS = []
|
| 803 |
+
_BINARY_ELEMENTWISE_APIS = []
|
| 804 |
+
_BINARY_ELEMENTWISE_ASSERT_APIS = []
|
| 805 |
+
_ELEMENTWISE_API_HANDLERS = {}
|
| 806 |
+
_ELEMENTWISE_API_TARGETS = {}
|
| 807 |
+
|
| 808 |
+
_ASSERT_API_TAG = "ASSERT_API_TAG"
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
@tf_export("experimental.dispatch_for_unary_elementwise_apis")
|
| 812 |
+
def dispatch_for_unary_elementwise_apis(x_type):
|
| 813 |
+
"""Decorator to override default implementation for unary elementwise APIs.
|
| 814 |
+
|
| 815 |
+
The decorated function (known as the "elementwise api handler") overrides
|
| 816 |
+
the default implementation for any unary elementwise API whenever the value
|
| 817 |
+
for the first argument (typically named `x`) matches the type annotation
|
| 818 |
+
`x_type`. The elementwise api handler is called with two arguments:
|
| 819 |
+
|
| 820 |
+
`elementwise_api_handler(api_func, x)`
|
| 821 |
+
|
| 822 |
+
Where `api_func` is a function that takes a single parameter and performs the
|
| 823 |
+
elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the
|
| 824 |
+
elementwise api.
|
| 825 |
+
|
| 826 |
+
The following example shows how this decorator can be used to update all
|
| 827 |
+
unary elementwise operations to handle a `MaskedTensor` type:
|
| 828 |
+
|
| 829 |
+
>>> class MaskedTensor(tf.experimental.ExtensionType):
|
| 830 |
+
... values: tf.Tensor
|
| 831 |
+
... mask: tf.Tensor
|
| 832 |
+
>>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
|
| 833 |
+
... def unary_elementwise_api_handler(api_func, x):
|
| 834 |
+
... return MaskedTensor(api_func(x.values), x.mask)
|
| 835 |
+
>>> mt = MaskedTensor([1, -2, -3], [True, False, True])
|
| 836 |
+
>>> abs_mt = tf.abs(mt)
|
| 837 |
+
>>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
|
| 838 |
+
values=[1 2 3], mask=[ True False True]
|
| 839 |
+
|
| 840 |
+
For unary elementwise operations that take extra arguments beyond `x`, those
|
| 841 |
+
arguments are *not* passed to the elementwise api handler, but are
|
| 842 |
+
automatically added when `api_func` is called. E.g., in the following
|
| 843 |
+
example, the `dtype` parameter is not passed to
|
| 844 |
+
`unary_elementwise_api_handler`, but is added by `api_func`.
|
| 845 |
+
|
| 846 |
+
>>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
|
| 847 |
+
>>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
|
| 848 |
+
values=[1.0 1.0 1.0], mask=[ True False True]
|
| 849 |
+
|
| 850 |
+
Args:
|
| 851 |
+
x_type: A type annotation indicating when the api handler should be called.
|
| 852 |
+
See `dispatch_for_api` for a list of supported annotation types.
|
| 853 |
+
|
| 854 |
+
Returns:
|
| 855 |
+
A decorator.
|
| 856 |
+
|
| 857 |
+
#### Registered APIs
|
| 858 |
+
|
| 859 |
+
The unary elementwise APIs are:
|
| 860 |
+
|
| 861 |
+
<<API_LIST>>
|
| 862 |
+
"""
|
| 863 |
+
|
| 864 |
+
def decorator(handler):
|
| 865 |
+
if (x_type,) in _ELEMENTWISE_API_HANDLERS:
|
| 866 |
+
raise ValueError("A unary elementwise dispatch handler "
|
| 867 |
+
f"({_ELEMENTWISE_API_HANDLERS[(x_type,)]}) "
|
| 868 |
+
f"has already been registered for {x_type}.")
|
| 869 |
+
_ELEMENTWISE_API_HANDLERS[(x_type,)] = handler
|
| 870 |
+
for api in _UNARY_ELEMENTWISE_APIS:
|
| 871 |
+
_add_dispatch_for_unary_elementwise_api(api, x_type, handler)
|
| 872 |
+
|
| 873 |
+
return handler
|
| 874 |
+
|
| 875 |
+
return decorator
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
@tf_export("experimental.dispatch_for_binary_elementwise_apis")
|
| 879 |
+
def dispatch_for_binary_elementwise_apis(x_type, y_type):
|
| 880 |
+
"""Decorator to override default implementation for binary elementwise APIs.
|
| 881 |
+
|
| 882 |
+
The decorated function (known as the "elementwise api handler") overrides
|
| 883 |
+
the default implementation for any binary elementwise API whenever the value
|
| 884 |
+
for the first two arguments (typically named `x` and `y`) match the specified
|
| 885 |
+
type annotations. The elementwise api handler is called with two arguments:
|
| 886 |
+
|
| 887 |
+
`elementwise_api_handler(api_func, x, y)`
|
| 888 |
+
|
| 889 |
+
Where `x` and `y` are the first two arguments to the elementwise api, and
|
| 890 |
+
`api_func` is a TensorFlow function that takes two parameters and performs the
|
| 891 |
+
elementwise operation (e.g., `tf.add`).
|
| 892 |
+
|
| 893 |
+
The following example shows how this decorator can be used to update all
|
| 894 |
+
binary elementwise operations to handle a `MaskedTensor` type:
|
| 895 |
+
|
| 896 |
+
>>> class MaskedTensor(tf.experimental.ExtensionType):
|
| 897 |
+
... values: tf.Tensor
|
| 898 |
+
... mask: tf.Tensor
|
| 899 |
+
>>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
|
| 900 |
+
... def binary_elementwise_api_handler(api_func, x, y):
|
| 901 |
+
... return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
|
| 902 |
+
>>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
|
| 903 |
+
>>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
|
| 904 |
+
>>> c = tf.add(a, b)
|
| 905 |
+
>>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
|
| 906 |
+
values=[ 3 6 9 12 5], mask=[ True True True False False]
|
| 907 |
+
|
| 908 |
+
Args:
|
| 909 |
+
x_type: A type annotation indicating when the api handler should be called.
|
| 910 |
+
y_type: A type annotation indicating when the api handler should be called.
|
| 911 |
+
|
| 912 |
+
Returns:
|
| 913 |
+
A decorator.
|
| 914 |
+
|
| 915 |
+
#### Registered APIs
|
| 916 |
+
|
| 917 |
+
The binary elementwise APIs are:
|
| 918 |
+
|
| 919 |
+
<<API_LIST>>
|
| 920 |
+
"""
|
| 921 |
+
|
| 922 |
+
def decorator(handler):
|
| 923 |
+
if (x_type, y_type) in _ELEMENTWISE_API_HANDLERS:
|
| 924 |
+
raise ValueError("A binary elementwise dispatch handler "
|
| 925 |
+
f"({_ELEMENTWISE_API_HANDLERS[x_type, y_type]}) "
|
| 926 |
+
f"has already been registered for ({x_type}, {y_type}).")
|
| 927 |
+
_ELEMENTWISE_API_HANDLERS[x_type, y_type] = handler
|
| 928 |
+
for api in _BINARY_ELEMENTWISE_APIS:
|
| 929 |
+
_add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
|
| 930 |
+
|
| 931 |
+
return handler
|
| 932 |
+
|
| 933 |
+
return decorator
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
|
| 937 |
+
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
|
| 938 |
+
"""Decorator to override default implementation for binary elementwise assert APIs.
|
| 939 |
+
|
| 940 |
+
The decorated function (known as the "elementwise assert handler")
|
| 941 |
+
overrides the default implementation for any binary elementwise assert API
|
| 942 |
+
whenever the value for the first two arguments (typically named `x` and `y`)
|
| 943 |
+
match the specified type annotations. The handler is called with two
|
| 944 |
+
arguments:
|
| 945 |
+
|
| 946 |
+
`elementwise_assert_handler(assert_func, x, y)`
|
| 947 |
+
|
| 948 |
+
Where `x` and `y` are the first two arguments to the binary elementwise assert
|
| 949 |
+
operation, and `assert_func` is a TensorFlow function that takes two
|
| 950 |
+
parameters and performs the elementwise assert operation (e.g.,
|
| 951 |
+
`tf.debugging.assert_equal`).
|
| 952 |
+
|
| 953 |
+
The following example shows how this decorator can be used to update all
|
| 954 |
+
binary elementwise assert operations to handle a `MaskedTensor` type:
|
| 955 |
+
|
| 956 |
+
>>> class MaskedTensor(tf.experimental.ExtensionType):
|
| 957 |
+
... values: tf.Tensor
|
| 958 |
+
... mask: tf.Tensor
|
| 959 |
+
>>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
|
| 960 |
+
... def binary_elementwise_assert_api_handler(assert_func, x, y):
|
| 961 |
+
... merged_mask = tf.logical_and(x.mask, y.mask)
|
| 962 |
+
... selected_x_values = tf.boolean_mask(x.values, merged_mask)
|
| 963 |
+
... selected_y_values = tf.boolean_mask(y.values, merged_mask)
|
| 964 |
+
... assert_func(selected_x_values, selected_y_values)
|
| 965 |
+
>>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
|
| 966 |
+
>>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
|
| 967 |
+
>>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown
|
| 968 |
+
|
| 969 |
+
>>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
|
| 970 |
+
>>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
|
| 971 |
+
>>> tf.debugging.assert_greater(a, b)
|
| 972 |
+
Traceback (most recent call last):
|
| 973 |
+
...
|
| 974 |
+
InvalidArgumentError: Condition x > y did not hold.
|
| 975 |
+
|
| 976 |
+
Args:
|
| 977 |
+
x_type: A type annotation indicating when the api handler should be called.
|
| 978 |
+
y_type: A type annotation indicating when the api handler should be called.
|
| 979 |
+
|
| 980 |
+
Returns:
|
| 981 |
+
A decorator.
|
| 982 |
+
|
| 983 |
+
#### Registered APIs
|
| 984 |
+
|
| 985 |
+
The binary elementwise assert APIs are:
|
| 986 |
+
|
| 987 |
+
<<API_LIST>>
|
| 988 |
+
"""
|
| 989 |
+
|
| 990 |
+
def decorator(handler):
|
| 991 |
+
api_handler_key = (x_type, y_type, _ASSERT_API_TAG)
|
| 992 |
+
if api_handler_key in _ELEMENTWISE_API_HANDLERS:
|
| 993 |
+
raise ValueError("A binary elementwise assert dispatch handler "
|
| 994 |
+
f"({_ELEMENTWISE_API_HANDLERS[api_handler_key]}) "
|
| 995 |
+
f"has already been registered for ({x_type}, {y_type}).")
|
| 996 |
+
_ELEMENTWISE_API_HANDLERS[api_handler_key] = handler
|
| 997 |
+
for api in _BINARY_ELEMENTWISE_ASSERT_APIS:
|
| 998 |
+
_add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)
|
| 999 |
+
|
| 1000 |
+
return handler
|
| 1001 |
+
|
| 1002 |
+
return decorator
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
def register_unary_elementwise_api(func):
|
| 1006 |
+
"""Decorator that registers a TensorFlow op as a unary elementwise API."""
|
| 1007 |
+
_UNARY_ELEMENTWISE_APIS.append(func)
|
| 1008 |
+
for args, handler in _ELEMENTWISE_API_HANDLERS.items():
|
| 1009 |
+
if len(args) == 1:
|
| 1010 |
+
_add_dispatch_for_unary_elementwise_api(func, args[0], handler)
|
| 1011 |
+
return func
|
| 1012 |
+
|
| 1013 |
+
|
| 1014 |
+
def register_binary_elementwise_api(func):
|
| 1015 |
+
"""Decorator that registers a TensorFlow op as a binary elementwise API."""
|
| 1016 |
+
_BINARY_ELEMENTWISE_APIS.append(func)
|
| 1017 |
+
for args, handler in _ELEMENTWISE_API_HANDLERS.items():
|
| 1018 |
+
if len(args) == 2:
|
| 1019 |
+
_add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
|
| 1020 |
+
return func
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
def register_binary_elementwise_assert_api(func):
|
| 1024 |
+
"""Decorator that registers a TensorFlow op as a binary elementwise assert API.
|
| 1025 |
+
|
| 1026 |
+
Different from `dispatch_for_binary_elementwise_apis`, this decorator is used
|
| 1027 |
+
for assert apis, such as assert_equal, assert_none_equal, etc, which return
|
| 1028 |
+
None in eager mode and an op in graph mode.
|
| 1029 |
+
|
| 1030 |
+
Args:
|
| 1031 |
+
func: The function that implements the binary elementwise assert API.
|
| 1032 |
+
|
| 1033 |
+
Returns:
|
| 1034 |
+
`func`
|
| 1035 |
+
"""
|
| 1036 |
+
_BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
|
| 1037 |
+
for args, handler in _ELEMENTWISE_API_HANDLERS.items():
|
| 1038 |
+
if len(args) == 3 and args[2] is _ASSERT_API_TAG:
|
| 1039 |
+
_add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
|
| 1040 |
+
return func
|
| 1041 |
+
|
| 1042 |
+
|
| 1043 |
+
def unary_elementwise_apis():
|
| 1044 |
+
"""Returns a list of APIs that have been registered as unary elementwise."""
|
| 1045 |
+
return tuple(_UNARY_ELEMENTWISE_APIS)
|
| 1046 |
+
|
| 1047 |
+
|
| 1048 |
+
def binary_elementwise_apis():
|
| 1049 |
+
"""Returns a list of APIs that have been registered as binary elementwise."""
|
| 1050 |
+
return tuple(_BINARY_ELEMENTWISE_APIS)
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
def _add_dispatch_for_unary_elementwise_api(api, x_type,
|
| 1054 |
+
elementwise_api_handler):
|
| 1055 |
+
"""Registers a unary elementwise handler as a dispatcher for a given API."""
|
| 1056 |
+
api_signature = tf_inspect.signature(api)
|
| 1057 |
+
x_name = list(api_signature.parameters)[0]
|
| 1058 |
+
name_index = _find_name_index(api_signature)
|
| 1059 |
+
|
| 1060 |
+
need_to_bind_api_args = (
|
| 1061 |
+
len(api_signature.parameters) > 2 or
|
| 1062 |
+
"name" not in api_signature.parameters)
|
| 1063 |
+
|
| 1064 |
+
@dispatch_for_api(api, {x_name: x_type})
|
| 1065 |
+
def dispatch_target(*args, **kwargs):
|
| 1066 |
+
args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
|
| 1067 |
+
if args:
|
| 1068 |
+
x, args = args[0], args[1:]
|
| 1069 |
+
else:
|
| 1070 |
+
x = kwargs.pop(x_name)
|
| 1071 |
+
|
| 1072 |
+
if need_to_bind_api_args:
|
| 1073 |
+
tensor_api = lambda v: api(v, *args, **kwargs)
|
| 1074 |
+
else:
|
| 1075 |
+
tensor_api = api
|
| 1076 |
+
|
| 1077 |
+
if name is None:
|
| 1078 |
+
return elementwise_api_handler(tensor_api, x)
|
| 1079 |
+
else:
|
| 1080 |
+
with ops.name_scope(name, None, [x]):
|
| 1081 |
+
return elementwise_api_handler(tensor_api, x)
|
| 1082 |
+
|
| 1083 |
+
dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
|
| 1084 |
+
dispatch_target.__qualname__ = dispatch_target.__name__
|
| 1085 |
+
# Keep track of what targets we've registered (so we can unregister them).
|
| 1086 |
+
target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), [])
|
| 1087 |
+
target_list.append((api, dispatch_target))
|
| 1088 |
+
|
| 1089 |
+
|
| 1090 |
+
def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
|
| 1091 |
+
elementwise_api_handler):
|
| 1092 |
+
"""Registers a binary elementwise handler as a dispatcher for a given API."""
|
| 1093 |
+
api_signature = tf_inspect.signature(api)
|
| 1094 |
+
x_name, y_name = list(api_signature.parameters)[:2]
|
| 1095 |
+
name_index = _find_name_index(api_signature)
|
| 1096 |
+
|
| 1097 |
+
need_to_bind_api_args = (len(api_signature.parameters) > 3 or
|
| 1098 |
+
"name" not in api_signature.parameters)
|
| 1099 |
+
|
| 1100 |
+
@dispatch_for_api(api, {x_name: x_type, y_name: y_type})
|
| 1101 |
+
def dispatch_target(*args, **kwargs):
|
| 1102 |
+
args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
|
| 1103 |
+
if len(args) > 1:
|
| 1104 |
+
x, y, args = args[0], args[1], args[2:]
|
| 1105 |
+
elif args:
|
| 1106 |
+
x, args = args[0], args[1:]
|
| 1107 |
+
y = kwargs.pop(y_name, None)
|
| 1108 |
+
else:
|
| 1109 |
+
x = kwargs.pop(x_name, None)
|
| 1110 |
+
y = kwargs.pop(y_name, None)
|
| 1111 |
+
|
| 1112 |
+
if need_to_bind_api_args:
|
| 1113 |
+
tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)
|
| 1114 |
+
else:
|
| 1115 |
+
tensor_api = api
|
| 1116 |
+
|
| 1117 |
+
if name is None:
|
| 1118 |
+
return elementwise_api_handler(tensor_api, x, y)
|
| 1119 |
+
else:
|
| 1120 |
+
with ops.name_scope(name, None, [x, y]):
|
| 1121 |
+
return elementwise_api_handler(tensor_api, x, y)
|
| 1122 |
+
|
| 1123 |
+
dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
|
| 1124 |
+
dispatch_target.__qualname__ = dispatch_target.__name__
|
| 1125 |
+
# Keep track of what targets we've registered (so we can unregister them).
|
| 1126 |
+
target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), [])
|
| 1127 |
+
target_list.append((api, dispatch_target))
|
| 1128 |
+
|
| 1129 |
+
|
| 1130 |
+
def _find_name_index(signature):
|
| 1131 |
+
"""Returns the index of the `name` parameter, or -1 if it's not present."""
|
| 1132 |
+
try:
|
| 1133 |
+
return list(signature.parameters).index("name")
|
| 1134 |
+
except ValueError:
|
| 1135 |
+
return -1
|
| 1136 |
+
|
| 1137 |
+
|
| 1138 |
+
def _extract_name_arg(args, kwargs, name_index):
|
| 1139 |
+
"""Extracts the parameter `name` and returns `(args, kwargs, name_value)`."""
|
| 1140 |
+
if name_index < 0:
|
| 1141 |
+
name_value = None
|
| 1142 |
+
elif name_index < len(args):
|
| 1143 |
+
name_value = args[name_index]
|
| 1144 |
+
args = args[:name_index] + args[name_index + 1:]
|
| 1145 |
+
else:
|
| 1146 |
+
name_value = kwargs.pop("name", None)
|
| 1147 |
+
return args, kwargs, name_value
|
| 1148 |
+
|
| 1149 |
+
|
| 1150 |
+
def update_docstrings_with_api_lists():
|
| 1151 |
+
"""Updates the docstrings of dispatch decorators with API lists.
|
| 1152 |
+
|
| 1153 |
+
Updates docstrings for `dispatch_for_api`,
|
| 1154 |
+
`dispatch_for_unary_elementwise_apis`, and
|
| 1155 |
+
`dispatch_for_binary_elementwise_apis`, by replacing the string '<<API_LIST>>'
|
| 1156 |
+
with a list of APIs that have been registered for that decorator.
|
| 1157 |
+
"""
|
| 1158 |
+
_update_docstring_with_api_list(dispatch_for_unary_elementwise_apis,
|
| 1159 |
+
_UNARY_ELEMENTWISE_APIS)
|
| 1160 |
+
_update_docstring_with_api_list(dispatch_for_binary_elementwise_apis,
|
| 1161 |
+
_BINARY_ELEMENTWISE_APIS)
|
| 1162 |
+
_update_docstring_with_api_list(dispatch_for_binary_elementwise_assert_apis,
|
| 1163 |
+
_BINARY_ELEMENTWISE_ASSERT_APIS)
|
| 1164 |
+
_update_docstring_with_api_list(dispatch_for_api,
|
| 1165 |
+
_TYPE_BASED_DISPATCH_SIGNATURES)
|
| 1166 |
+
|
| 1167 |
+
|
| 1168 |
+
def _update_docstring_with_api_list(target, api_list):
|
| 1169 |
+
"""Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs."""
|
| 1170 |
+
lines = []
|
| 1171 |
+
for func in api_list:
|
| 1172 |
+
name = tf_export_lib.get_canonical_name_for_symbol(
|
| 1173 |
+
func, add_prefix_to_v1_names=True)
|
| 1174 |
+
if name is not None:
|
| 1175 |
+
params = tf_inspect.signature(func).parameters.keys()
|
| 1176 |
+
lines.append(f" * `tf.{name}({', '.join(params)})`")
|
| 1177 |
+
lines.sort()
|
| 1178 |
+
target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines))
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
################################################################################
|
| 1182 |
+
# Dispatch Support
|
| 1183 |
+
################################################################################
|
| 1184 |
+
@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
|
| 1185 |
+
def add_dispatch_support(target=None, iterable_parameters=None):
|
| 1186 |
+
"""Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.
|
| 1187 |
+
|
| 1188 |
+
This wrapper adds the decorated function as an API that can be overridden
|
| 1189 |
+
using the `@dispatch_for_api` decorator. In the following example, we first
|
| 1190 |
+
define a new API (`double`) that supports dispatch, then define a custom type
|
| 1191 |
+
(`MaskedTensor`) and finally use `dispatch_for_api` to override the default
|
| 1192 |
+
implementation of `double` when called with `MaskedTensor` values:
|
| 1193 |
+
|
| 1194 |
+
>>> @add_dispatch_support
|
| 1195 |
+
... def double(x):
|
| 1196 |
+
... return x * 2
|
| 1197 |
+
>>> class MaskedTensor(tf.experimental.ExtensionType):
|
| 1198 |
+
... values: tf.Tensor
|
| 1199 |
+
... mask: tf.Tensor
|
| 1200 |
+
>>> @dispatch_for_api(double, {'x': MaskedTensor})
|
| 1201 |
+
... def masked_double(x):
|
| 1202 |
+
... return MaskedTensor(x.values * 2, y.mask)
|
| 1203 |
+
|
| 1204 |
+
The optional `iterable_parameter` argument can be used to mark parameters that
|
| 1205 |
+
can take arbitrary iterable values (such as generator expressions). These
|
| 1206 |
+
need to be handled specially during dispatch, since just iterating over an
|
| 1207 |
+
iterable uses up its values. In the following example, we define a new API
|
| 1208 |
+
whose second argument can be an iterable value; and then override the default
|
| 1209 |
+
implementatio of that API when the iterable contains MaskedTensors:
|
| 1210 |
+
|
| 1211 |
+
>>> @add_dispatch_support(iterable_parameters=['ys'])
|
| 1212 |
+
... def add_tensor_to_list_of_tensors(x, ys):
|
| 1213 |
+
... return [x + y for y in ys]
|
| 1214 |
+
>>> @dispatch_for_api(add_tensor_to_list_of_tensors,
|
| 1215 |
+
... {'ys': typing.List[MaskedTensor]})
|
| 1216 |
+
... def masked_add_tensor_to_list_of_tensors(x, ys):
|
| 1217 |
+
... return [MaskedTensor(x+y.values, y.mask) for y in ys]
|
| 1218 |
+
|
| 1219 |
+
(Note: the only TensorFlow API that currently supports iterables is `add_n`.)
|
| 1220 |
+
|
| 1221 |
+
Args:
|
| 1222 |
+
target: The TensorFlow API that should support dispatch.
|
| 1223 |
+
iterable_parameters: Optional list of parameter names that may be called
|
| 1224 |
+
with iterables (such as the `inputs` parameter for `tf.add_n`).
|
| 1225 |
+
|
| 1226 |
+
Returns:
|
| 1227 |
+
A decorator.
|
| 1228 |
+
"""
|
| 1229 |
+
|
| 1230 |
+
if not (iterable_parameters is None or
|
| 1231 |
+
(isinstance(iterable_parameters, (list, tuple)) and
|
| 1232 |
+
all(isinstance(p, str) for p in iterable_parameters))):
|
| 1233 |
+
raise TypeError("iterable_parameters should be a list or tuple of string.")
|
| 1234 |
+
|
| 1235 |
+
def decorator(dispatch_target):
|
| 1236 |
+
|
| 1237 |
+
# Get the name & index for each iterable parameter.
|
| 1238 |
+
if iterable_parameters is None:
|
| 1239 |
+
iterable_params = None
|
| 1240 |
+
else:
|
| 1241 |
+
arg_names = tf_inspect.getargspec(dispatch_target).args
|
| 1242 |
+
iterable_params = [
|
| 1243 |
+
(name, arg_names.index(name)) for name in iterable_parameters
|
| 1244 |
+
]
|
| 1245 |
+
|
| 1246 |
+
@traceback_utils.filter_traceback
|
| 1247 |
+
def op_dispatch_handler(*args, **kwargs):
|
| 1248 |
+
"""Call `dispatch_target`, peforming dispatch when appropriate."""
|
| 1249 |
+
|
| 1250 |
+
# Type-based dispatch system (dispatch v2):
|
| 1251 |
+
if api_dispatcher is not None:
|
| 1252 |
+
if iterable_params is not None:
|
| 1253 |
+
args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
|
| 1254 |
+
result = api_dispatcher.Dispatch(args, kwargs)
|
| 1255 |
+
if result is not NotImplemented:
|
| 1256 |
+
return result
|
| 1257 |
+
|
| 1258 |
+
# Fallback dispatch system (dispatch v1):
|
| 1259 |
+
try:
|
| 1260 |
+
return dispatch_target(*args, **kwargs)
|
| 1261 |
+
except (TypeError, ValueError):
|
| 1262 |
+
# Note: convert_to_eager_tensor currently raises a ValueError, not a
|
| 1263 |
+
# TypeError, when given unexpected types. So we need to catch both.
|
| 1264 |
+
result = dispatch(op_dispatch_handler, args, kwargs)
|
| 1265 |
+
if result is not OpDispatcher.NOT_SUPPORTED:
|
| 1266 |
+
return result
|
| 1267 |
+
else:
|
| 1268 |
+
raise
|
| 1269 |
+
|
| 1270 |
+
add_fallback_dispatch_list(op_dispatch_handler)
|
| 1271 |
+
op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
|
| 1272 |
+
op_dispatch_handler)
|
| 1273 |
+
add_type_based_api_dispatcher(op_dispatch_handler)
|
| 1274 |
+
api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
|
| 1275 |
+
None)
|
| 1276 |
+
return op_dispatch_handler
|
| 1277 |
+
|
| 1278 |
+
if target is None:
|
| 1279 |
+
return decorator
|
| 1280 |
+
else:
|
| 1281 |
+
return decorator(target)
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
def replace_iterable_params(args, kwargs, iterable_params):
|
| 1285 |
+
"""Returns (args, kwargs) with any iterable parameters converted to lists.
|
| 1286 |
+
|
| 1287 |
+
Args:
|
| 1288 |
+
args: Positional rguments to a function
|
| 1289 |
+
kwargs: Keyword arguments to a function.
|
| 1290 |
+
iterable_params: A list of (name, index) tuples for iterable parameters.
|
| 1291 |
+
|
| 1292 |
+
Returns:
|
| 1293 |
+
A tuple (args, kwargs), where any positional or keyword parameters in
|
| 1294 |
+
`iterable_params` have their value converted to a `list`.
|
| 1295 |
+
"""
|
| 1296 |
+
args = list(args)
|
| 1297 |
+
for name, index in iterable_params:
|
| 1298 |
+
if index < len(args):
|
| 1299 |
+
args[index] = list(args[index])
|
| 1300 |
+
elif name in kwargs:
|
| 1301 |
+
kwargs[name] = list(kwargs[name])
|
| 1302 |
+
return tuple(args), kwargs
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/example_parser_configuration.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Extract parse_example op configuration to a proto."""
|
| 16 |
+
|
| 17 |
+
from tensorflow.core.example import example_parser_configuration_pb2
|
| 18 |
+
from tensorflow.python.framework import tensor_shape
|
| 19 |
+
from tensorflow.python.framework import tensor_util
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def extract_example_parser_configuration(parse_example_op, sess):
|
| 23 |
+
"""Returns an ExampleParserConfig proto.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
parse_example_op: A ParseExample or ParseExampleV2 `Operation`
|
| 27 |
+
sess: A tf.compat.v1.Session needed to obtain some configuration values.
|
| 28 |
+
Returns:
|
| 29 |
+
A ExampleParserConfig proto.
|
| 30 |
+
|
| 31 |
+
Raises:
|
| 32 |
+
ValueError: If attributes are inconsistent.
|
| 33 |
+
"""
|
| 34 |
+
if parse_example_op.type == "ParseExample":
|
| 35 |
+
return _extract_from_parse_example(parse_example_op, sess)
|
| 36 |
+
elif parse_example_op.type == "ParseExampleV2":
|
| 37 |
+
return _extract_from_parse_example_v2(parse_example_op, sess)
|
| 38 |
+
else:
|
| 39 |
+
raise ValueError(
|
| 40 |
+
"Found unexpected type when parsing example. Expected `ParseExample` "
|
| 41 |
+
f"object. Received type: {parse_example_op.type}")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _extract_from_parse_example(parse_example_op, sess):
|
| 45 |
+
"""Extract ExampleParserConfig from ParseExample op."""
|
| 46 |
+
config = example_parser_configuration_pb2.ExampleParserConfiguration()
|
| 47 |
+
|
| 48 |
+
num_sparse = parse_example_op.get_attr("Nsparse")
|
| 49 |
+
num_dense = parse_example_op.get_attr("Ndense")
|
| 50 |
+
total_features = num_dense + num_sparse
|
| 51 |
+
|
| 52 |
+
sparse_types = parse_example_op.get_attr("sparse_types")
|
| 53 |
+
dense_types = parse_example_op.get_attr("Tdense")
|
| 54 |
+
dense_shapes = parse_example_op.get_attr("dense_shapes")
|
| 55 |
+
|
| 56 |
+
if len(sparse_types) != num_sparse:
|
| 57 |
+
raise ValueError("len(sparse_types) attribute does not match "
|
| 58 |
+
"Nsparse attribute (%d vs %d)" %
|
| 59 |
+
(len(sparse_types), num_sparse))
|
| 60 |
+
|
| 61 |
+
if len(dense_types) != num_dense:
|
| 62 |
+
raise ValueError("len(dense_types) attribute does not match "
|
| 63 |
+
"Ndense attribute (%d vs %d)" %
|
| 64 |
+
(len(dense_types), num_dense))
|
| 65 |
+
|
| 66 |
+
if len(dense_shapes) != num_dense:
|
| 67 |
+
raise ValueError("len(dense_shapes) attribute does not match "
|
| 68 |
+
"Ndense attribute (%d vs %d)" %
|
| 69 |
+
(len(dense_shapes), num_dense))
|
| 70 |
+
|
| 71 |
+
# Skip over the serialized input, and the names input.
|
| 72 |
+
fetch_list = parse_example_op.inputs[2:]
|
| 73 |
+
|
| 74 |
+
# Fetch total_features key names and num_dense default values.
|
| 75 |
+
if len(fetch_list) != (total_features + num_dense):
|
| 76 |
+
raise ValueError("len(fetch_list) does not match total features + "
|
| 77 |
+
"num_dense (%d vs %d)" %
|
| 78 |
+
(len(fetch_list), (total_features + num_dense)))
|
| 79 |
+
|
| 80 |
+
fetched = sess.run(fetch_list)
|
| 81 |
+
|
| 82 |
+
if len(fetched) != len(fetch_list):
|
| 83 |
+
raise ValueError("len(fetched) does not match len(fetch_list) "
|
| 84 |
+
"(%d vs %d)" % (len(fetched), len(fetch_list)))
|
| 85 |
+
|
| 86 |
+
# Fetch indices.
|
| 87 |
+
sparse_keys_start = 0
|
| 88 |
+
dense_keys_start = sparse_keys_start + num_sparse
|
| 89 |
+
dense_def_start = dense_keys_start + num_dense
|
| 90 |
+
|
| 91 |
+
# Output tensor indices.
|
| 92 |
+
sparse_indices_start = 0
|
| 93 |
+
sparse_values_start = num_sparse
|
| 94 |
+
sparse_shapes_start = sparse_values_start + num_sparse
|
| 95 |
+
dense_values_start = sparse_shapes_start + num_sparse
|
| 96 |
+
|
| 97 |
+
# Dense features.
|
| 98 |
+
for i in range(num_dense):
|
| 99 |
+
key = fetched[dense_keys_start + i]
|
| 100 |
+
feature_config = config.feature_map[key]
|
| 101 |
+
# Convert the default value numpy array fetched from the session run
|
| 102 |
+
# into a TensorProto.
|
| 103 |
+
fixed_config = feature_config.fixed_len_feature
|
| 104 |
+
|
| 105 |
+
fixed_config.default_value.CopyFrom(
|
| 106 |
+
tensor_util.make_tensor_proto(fetched[dense_def_start + i]))
|
| 107 |
+
# Convert the shape from the attributes
|
| 108 |
+
# into a TensorShapeProto.
|
| 109 |
+
fixed_config.shape.CopyFrom(
|
| 110 |
+
tensor_shape.TensorShape(dense_shapes[i]).as_proto())
|
| 111 |
+
|
| 112 |
+
fixed_config.dtype = dense_types[i].as_datatype_enum
|
| 113 |
+
# Get the output tensor name.
|
| 114 |
+
fixed_config.values_output_tensor_name = parse_example_op.outputs[
|
| 115 |
+
dense_values_start + i].name
|
| 116 |
+
|
| 117 |
+
# Sparse features.
|
| 118 |
+
for i in range(num_sparse):
|
| 119 |
+
key = fetched[sparse_keys_start + i]
|
| 120 |
+
feature_config = config.feature_map[key]
|
| 121 |
+
var_len_feature = feature_config.var_len_feature
|
| 122 |
+
var_len_feature.dtype = sparse_types[i].as_datatype_enum
|
| 123 |
+
var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
|
| 124 |
+
sparse_indices_start + i].name
|
| 125 |
+
var_len_feature.values_output_tensor_name = parse_example_op.outputs[
|
| 126 |
+
sparse_values_start + i].name
|
| 127 |
+
var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
|
| 128 |
+
sparse_shapes_start + i].name
|
| 129 |
+
|
| 130 |
+
return config
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _extract_from_parse_example_v2(parse_example_op, sess):
|
| 134 |
+
"""Extract ExampleParserConfig from ParseExampleV2 op."""
|
| 135 |
+
config = example_parser_configuration_pb2.ExampleParserConfiguration()
|
| 136 |
+
|
| 137 |
+
dense_types = parse_example_op.get_attr("Tdense")
|
| 138 |
+
num_sparse = parse_example_op.get_attr("num_sparse")
|
| 139 |
+
sparse_types = parse_example_op.get_attr("sparse_types")
|
| 140 |
+
ragged_value_types = parse_example_op.get_attr("ragged_value_types")
|
| 141 |
+
ragged_split_types = parse_example_op.get_attr("ragged_split_types")
|
| 142 |
+
dense_shapes = parse_example_op.get_attr("dense_shapes")
|
| 143 |
+
|
| 144 |
+
num_dense = len(dense_types)
|
| 145 |
+
num_ragged = len(ragged_value_types)
|
| 146 |
+
assert len(ragged_value_types) == len(ragged_split_types)
|
| 147 |
+
assert len(parse_example_op.inputs) == 5 + num_dense
|
| 148 |
+
|
| 149 |
+
# Skip over the serialized input, and the names input.
|
| 150 |
+
fetched = sess.run(parse_example_op.inputs[2:])
|
| 151 |
+
sparse_keys = fetched[0].tolist()
|
| 152 |
+
dense_keys = fetched[1].tolist()
|
| 153 |
+
ragged_keys = fetched[2].tolist()
|
| 154 |
+
dense_defaults = fetched[3:]
|
| 155 |
+
assert len(sparse_keys) == num_sparse
|
| 156 |
+
assert len(dense_keys) == num_dense
|
| 157 |
+
assert len(ragged_keys) == num_ragged
|
| 158 |
+
|
| 159 |
+
# Output tensor indices.
|
| 160 |
+
sparse_indices_start = 0
|
| 161 |
+
sparse_values_start = num_sparse
|
| 162 |
+
sparse_shapes_start = sparse_values_start + num_sparse
|
| 163 |
+
dense_values_start = sparse_shapes_start + num_sparse
|
| 164 |
+
ragged_values_start = dense_values_start + num_dense
|
| 165 |
+
ragged_row_splits_start = ragged_values_start + num_ragged
|
| 166 |
+
|
| 167 |
+
# Dense features.
|
| 168 |
+
for i in range(num_dense):
|
| 169 |
+
key = dense_keys[i]
|
| 170 |
+
feature_config = config.feature_map[key]
|
| 171 |
+
# Convert the default value numpy array fetched from the session run
|
| 172 |
+
# into a TensorProto.
|
| 173 |
+
fixed_config = feature_config.fixed_len_feature
|
| 174 |
+
|
| 175 |
+
fixed_config.default_value.CopyFrom(
|
| 176 |
+
tensor_util.make_tensor_proto(dense_defaults[i]))
|
| 177 |
+
# Convert the shape from the attributes
|
| 178 |
+
# into a TensorShapeProto.
|
| 179 |
+
fixed_config.shape.CopyFrom(
|
| 180 |
+
tensor_shape.TensorShape(dense_shapes[i]).as_proto())
|
| 181 |
+
|
| 182 |
+
fixed_config.dtype = dense_types[i].as_datatype_enum
|
| 183 |
+
# Get the output tensor name.
|
| 184 |
+
fixed_config.values_output_tensor_name = parse_example_op.outputs[
|
| 185 |
+
dense_values_start + i].name
|
| 186 |
+
|
| 187 |
+
# Sparse features.
|
| 188 |
+
for i in range(num_sparse):
|
| 189 |
+
key = sparse_keys[i]
|
| 190 |
+
feature_config = config.feature_map[key]
|
| 191 |
+
var_len_feature = feature_config.var_len_feature
|
| 192 |
+
var_len_feature.dtype = sparse_types[i].as_datatype_enum
|
| 193 |
+
var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
|
| 194 |
+
sparse_indices_start + i].name
|
| 195 |
+
var_len_feature.values_output_tensor_name = parse_example_op.outputs[
|
| 196 |
+
sparse_values_start + i].name
|
| 197 |
+
var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
|
| 198 |
+
sparse_shapes_start + i].name
|
| 199 |
+
|
| 200 |
+
if num_ragged != 0:
|
| 201 |
+
del ragged_values_start # unused
|
| 202 |
+
del ragged_row_splits_start # unused
|
| 203 |
+
raise ValueError("Ragged features are not yet supported by "
|
| 204 |
+
"example_parser_configuration.proto")
|
| 205 |
+
|
| 206 |
+
return config
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/util/fast_module_type.pyi
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
def get_fast_module_type_class() -> object: ...
|