Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/gen_rpc_ops.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/gen_rpc_ops.py +763 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2_grpc.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2.py +37 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2_grpc.py +63 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/analyzer.py +107 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/conversion_metadata_schema_py_generated.py +568 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py +219 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/schema_util.py +45 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/tflite_convert.py +696 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/util.py +1177 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/flatbuffer_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/visualize.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/flatbuffer_utils.py +455 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/debugger.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/debugger.py +549 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/visualize.py +549 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.pyi +18 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so +3 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__init__.py +63 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/conditional_expressions.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/control_flow.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/data_structures.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/exceptions.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/logical.cpython-310.pyc +0 -0
.gitattributes
CHANGED
|
@@ -197,3 +197,4 @@ SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/gr
|
|
| 197 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/grappler/_pywrap_tf_item.so filter=lfs diff=lfs merge=lfs -text
|
| 198 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/profiler/internal/_pywrap_profiler.so filter=lfs diff=lfs merge=lfs -text
|
| 199 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/client/_pywrap_tf_session.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 197 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/grappler/_pywrap_tf_item.so filter=lfs diff=lfs merge=lfs -text
|
| 198 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/profiler/internal/_pywrap_profiler.so filter=lfs diff=lfs merge=lfs -text
|
| 199 |
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/client/_pywrap_tf_session.so filter=lfs diff=lfs merge=lfs -text
|
| 200 |
+
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so filter=lfs diff=lfs merge=lfs -text
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (210 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (214 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (222 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/gen_rpc_ops.cpython-310.pyc
ADDED
|
Binary file (16.6 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/gen_rpc_ops.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Python wrappers around TensorFlow ops.
|
| 2 |
+
|
| 3 |
+
This file is MACHINE GENERATED! Do not edit.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import collections
|
| 7 |
+
|
| 8 |
+
from tensorflow.python import pywrap_tfe as pywrap_tfe
|
| 9 |
+
from tensorflow.python.eager import context as _context
|
| 10 |
+
from tensorflow.python.eager import core as _core
|
| 11 |
+
from tensorflow.python.eager import execute as _execute
|
| 12 |
+
from tensorflow.python.framework import dtypes as _dtypes
|
| 13 |
+
from tensorflow.security.fuzzing.py import annotation_types as _atypes
|
| 14 |
+
|
| 15 |
+
from tensorflow.python.framework import op_def_registry as _op_def_registry
|
| 16 |
+
from tensorflow.python.framework import ops as _ops
|
| 17 |
+
from tensorflow.python.framework import op_def_library as _op_def_library
|
| 18 |
+
from tensorflow.python.util.deprecation import deprecated_endpoints
|
| 19 |
+
from tensorflow.python.util import dispatch as _dispatch
|
| 20 |
+
from tensorflow.python.util.tf_export import tf_export
|
| 21 |
+
|
| 22 |
+
from typing import TypeVar, List, Any
|
| 23 |
+
from typing_extensions import Annotated
|
| 24 |
+
|
| 25 |
+
@_dispatch.add_fallback_dispatch_list
|
| 26 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 27 |
+
@tf_export('delete_rpc_future_resource')
|
| 28 |
+
def delete_rpc_future_resource(handle: Annotated[Any, _atypes.Resource], deleter: Annotated[Any, _atypes.Variant], name=None):
|
| 29 |
+
r"""TODO: add doc.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
handle: A `Tensor` of type `resource`.
|
| 33 |
+
deleter: A `Tensor` of type `variant`.
|
| 34 |
+
name: A name for the operation (optional).
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
The created Operation.
|
| 38 |
+
"""
|
| 39 |
+
_ctx = _context._context or _context.context()
|
| 40 |
+
tld = _ctx._thread_local_data
|
| 41 |
+
if tld.is_eager:
|
| 42 |
+
try:
|
| 43 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 44 |
+
_ctx, "DeleteRpcFutureResource", name, handle, deleter)
|
| 45 |
+
return _result
|
| 46 |
+
except _core._NotOkStatusException as e:
|
| 47 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 48 |
+
except _core._FallbackException:
|
| 49 |
+
pass
|
| 50 |
+
try:
|
| 51 |
+
_result = _dispatcher_for_delete_rpc_future_resource(
|
| 52 |
+
(handle, deleter, name,), None)
|
| 53 |
+
if _result is not NotImplemented:
|
| 54 |
+
return _result
|
| 55 |
+
return delete_rpc_future_resource_eager_fallback(
|
| 56 |
+
handle, deleter, name=name, ctx=_ctx)
|
| 57 |
+
except _core._SymbolicException:
|
| 58 |
+
pass # Add nodes to the TensorFlow graph.
|
| 59 |
+
except (TypeError, ValueError):
|
| 60 |
+
_result = _dispatch.dispatch(
|
| 61 |
+
delete_rpc_future_resource, (), dict(handle=handle,
|
| 62 |
+
deleter=deleter, name=name)
|
| 63 |
+
)
|
| 64 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 65 |
+
return _result
|
| 66 |
+
raise
|
| 67 |
+
else:
|
| 68 |
+
_result = _dispatcher_for_delete_rpc_future_resource(
|
| 69 |
+
(handle, deleter, name,), None)
|
| 70 |
+
if _result is not NotImplemented:
|
| 71 |
+
return _result
|
| 72 |
+
# Add nodes to the TensorFlow graph.
|
| 73 |
+
try:
|
| 74 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 75 |
+
"DeleteRpcFutureResource", handle=handle, deleter=deleter, name=name)
|
| 76 |
+
except (TypeError, ValueError):
|
| 77 |
+
_result = _dispatch.dispatch(
|
| 78 |
+
delete_rpc_future_resource, (), dict(handle=handle, deleter=deleter,
|
| 79 |
+
name=name)
|
| 80 |
+
)
|
| 81 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 82 |
+
return _result
|
| 83 |
+
raise
|
| 84 |
+
return _op
|
| 85 |
+
DeleteRpcFutureResource = tf_export("raw_ops.DeleteRpcFutureResource")(_ops.to_raw_op(delete_rpc_future_resource))
|
| 86 |
+
_dispatcher_for_delete_rpc_future_resource = delete_rpc_future_resource._tf_type_based_dispatcher.Dispatch
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def delete_rpc_future_resource_eager_fallback(handle: Annotated[Any, _atypes.Resource], deleter: Annotated[Any, _atypes.Variant], name, ctx):
|
| 90 |
+
handle = _ops.convert_to_tensor(handle, _dtypes.resource)
|
| 91 |
+
deleter = _ops.convert_to_tensor(deleter, _dtypes.variant)
|
| 92 |
+
_inputs_flat = [handle, deleter]
|
| 93 |
+
_attrs = None
|
| 94 |
+
_result = _execute.execute(b"DeleteRpcFutureResource", 0,
|
| 95 |
+
inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
|
| 96 |
+
name=name)
|
| 97 |
+
_result = None
|
| 98 |
+
return _result
|
| 99 |
+
|
| 100 |
+
_RpcCallOutput = collections.namedtuple(
|
| 101 |
+
"RpcCall",
|
| 102 |
+
["future", "deleter"])
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@_dispatch.add_fallback_dispatch_list
|
| 106 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 107 |
+
@tf_export('rpc_call')
|
| 108 |
+
def rpc_call(client: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], args, timeout_in_ms: Annotated[Any, _atypes.Int64], name=None):
|
| 109 |
+
r"""TODO: add doc.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
client: A `Tensor` of type `resource`.
|
| 113 |
+
method_name: A `Tensor` of type `string`.
|
| 114 |
+
args: A list of `Tensor` objects.
|
| 115 |
+
timeout_in_ms: A `Tensor` of type `int64`.
|
| 116 |
+
name: A name for the operation (optional).
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
A tuple of `Tensor` objects (future, deleter).
|
| 120 |
+
|
| 121 |
+
future: A `Tensor` of type `resource`.
|
| 122 |
+
deleter: A `Tensor` of type `variant`.
|
| 123 |
+
"""
|
| 124 |
+
_ctx = _context._context or _context.context()
|
| 125 |
+
tld = _ctx._thread_local_data
|
| 126 |
+
if tld.is_eager:
|
| 127 |
+
try:
|
| 128 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 129 |
+
_ctx, "RpcCall", name, client, method_name, args, timeout_in_ms)
|
| 130 |
+
_result = _RpcCallOutput._make(_result)
|
| 131 |
+
return _result
|
| 132 |
+
except _core._NotOkStatusException as e:
|
| 133 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 134 |
+
except _core._FallbackException:
|
| 135 |
+
pass
|
| 136 |
+
try:
|
| 137 |
+
_result = _dispatcher_for_rpc_call(
|
| 138 |
+
(client, method_name, args, timeout_in_ms, name,), None)
|
| 139 |
+
if _result is not NotImplemented:
|
| 140 |
+
return _result
|
| 141 |
+
return rpc_call_eager_fallback(
|
| 142 |
+
client, method_name, args, timeout_in_ms, name=name, ctx=_ctx)
|
| 143 |
+
except _core._SymbolicException:
|
| 144 |
+
pass # Add nodes to the TensorFlow graph.
|
| 145 |
+
except (TypeError, ValueError):
|
| 146 |
+
_result = _dispatch.dispatch(
|
| 147 |
+
rpc_call, (), dict(client=client, method_name=method_name,
|
| 148 |
+
args=args, timeout_in_ms=timeout_in_ms,
|
| 149 |
+
name=name)
|
| 150 |
+
)
|
| 151 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 152 |
+
return _result
|
| 153 |
+
raise
|
| 154 |
+
else:
|
| 155 |
+
_result = _dispatcher_for_rpc_call(
|
| 156 |
+
(client, method_name, args, timeout_in_ms, name,), None)
|
| 157 |
+
if _result is not NotImplemented:
|
| 158 |
+
return _result
|
| 159 |
+
# Add nodes to the TensorFlow graph.
|
| 160 |
+
try:
|
| 161 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 162 |
+
"RpcCall", client=client, method_name=method_name, args=args,
|
| 163 |
+
timeout_in_ms=timeout_in_ms, name=name)
|
| 164 |
+
except (TypeError, ValueError):
|
| 165 |
+
_result = _dispatch.dispatch(
|
| 166 |
+
rpc_call, (), dict(client=client, method_name=method_name,
|
| 167 |
+
args=args, timeout_in_ms=timeout_in_ms,
|
| 168 |
+
name=name)
|
| 169 |
+
)
|
| 170 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 171 |
+
return _result
|
| 172 |
+
raise
|
| 173 |
+
_result = _outputs[:]
|
| 174 |
+
if _execute.must_record_gradient():
|
| 175 |
+
_attrs = ("Tin", _op.get_attr("Tin"))
|
| 176 |
+
_inputs_flat = _op.inputs
|
| 177 |
+
_execute.record_gradient(
|
| 178 |
+
"RpcCall", _inputs_flat, _attrs, _result)
|
| 179 |
+
_result = _RpcCallOutput._make(_result)
|
| 180 |
+
return _result
|
| 181 |
+
|
| 182 |
+
RpcCall = tf_export("raw_ops.RpcCall")(_ops.to_raw_op(rpc_call))
|
| 183 |
+
_dispatcher_for_rpc_call = rpc_call._tf_type_based_dispatcher.Dispatch
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def rpc_call_eager_fallback(client: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], args, timeout_in_ms: Annotated[Any, _atypes.Int64], name, ctx):
|
| 187 |
+
_attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
|
| 188 |
+
client = _ops.convert_to_tensor(client, _dtypes.resource)
|
| 189 |
+
method_name = _ops.convert_to_tensor(method_name, _dtypes.string)
|
| 190 |
+
timeout_in_ms = _ops.convert_to_tensor(timeout_in_ms, _dtypes.int64)
|
| 191 |
+
_inputs_flat = [client, method_name] + list(args) + [timeout_in_ms]
|
| 192 |
+
_attrs = ("Tin", _attr_Tin)
|
| 193 |
+
_result = _execute.execute(b"RpcCall", 2, inputs=_inputs_flat, attrs=_attrs,
|
| 194 |
+
ctx=ctx, name=name)
|
| 195 |
+
if _execute.must_record_gradient():
|
| 196 |
+
_execute.record_gradient(
|
| 197 |
+
"RpcCall", _inputs_flat, _attrs, _result)
|
| 198 |
+
_result = _RpcCallOutput._make(_result)
|
| 199 |
+
return _result
|
| 200 |
+
|
| 201 |
+
_RpcCheckStatusOutput = collections.namedtuple(
|
| 202 |
+
"RpcCheckStatus",
|
| 203 |
+
["error_code", "error"])
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@_dispatch.add_fallback_dispatch_list
|
| 207 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 208 |
+
@tf_export('rpc_check_status')
|
| 209 |
+
def rpc_check_status(status_or: Annotated[Any, _atypes.Resource], name=None):
|
| 210 |
+
r"""TODO: add doc.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
status_or: A `Tensor` of type `resource`.
|
| 214 |
+
name: A name for the operation (optional).
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
A tuple of `Tensor` objects (error_code, error).
|
| 218 |
+
|
| 219 |
+
error_code: A `Tensor` of type `int64`.
|
| 220 |
+
error: A `Tensor` of type `string`.
|
| 221 |
+
"""
|
| 222 |
+
_ctx = _context._context or _context.context()
|
| 223 |
+
tld = _ctx._thread_local_data
|
| 224 |
+
if tld.is_eager:
|
| 225 |
+
try:
|
| 226 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 227 |
+
_ctx, "RpcCheckStatus", name, status_or)
|
| 228 |
+
_result = _RpcCheckStatusOutput._make(_result)
|
| 229 |
+
return _result
|
| 230 |
+
except _core._NotOkStatusException as e:
|
| 231 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 232 |
+
except _core._FallbackException:
|
| 233 |
+
pass
|
| 234 |
+
try:
|
| 235 |
+
_result = _dispatcher_for_rpc_check_status(
|
| 236 |
+
(status_or, name,), None)
|
| 237 |
+
if _result is not NotImplemented:
|
| 238 |
+
return _result
|
| 239 |
+
return rpc_check_status_eager_fallback(
|
| 240 |
+
status_or, name=name, ctx=_ctx)
|
| 241 |
+
except _core._SymbolicException:
|
| 242 |
+
pass # Add nodes to the TensorFlow graph.
|
| 243 |
+
except (TypeError, ValueError):
|
| 244 |
+
_result = _dispatch.dispatch(
|
| 245 |
+
rpc_check_status, (), dict(status_or=status_or, name=name)
|
| 246 |
+
)
|
| 247 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 248 |
+
return _result
|
| 249 |
+
raise
|
| 250 |
+
else:
|
| 251 |
+
_result = _dispatcher_for_rpc_check_status(
|
| 252 |
+
(status_or, name,), None)
|
| 253 |
+
if _result is not NotImplemented:
|
| 254 |
+
return _result
|
| 255 |
+
# Add nodes to the TensorFlow graph.
|
| 256 |
+
try:
|
| 257 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 258 |
+
"RpcCheckStatus", status_or=status_or, name=name)
|
| 259 |
+
except (TypeError, ValueError):
|
| 260 |
+
_result = _dispatch.dispatch(
|
| 261 |
+
rpc_check_status, (), dict(status_or=status_or, name=name)
|
| 262 |
+
)
|
| 263 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 264 |
+
return _result
|
| 265 |
+
raise
|
| 266 |
+
_result = _outputs[:]
|
| 267 |
+
if _execute.must_record_gradient():
|
| 268 |
+
_attrs = ()
|
| 269 |
+
_inputs_flat = _op.inputs
|
| 270 |
+
_execute.record_gradient(
|
| 271 |
+
"RpcCheckStatus", _inputs_flat, _attrs, _result)
|
| 272 |
+
_result = _RpcCheckStatusOutput._make(_result)
|
| 273 |
+
return _result
|
| 274 |
+
|
| 275 |
+
RpcCheckStatus = tf_export("raw_ops.RpcCheckStatus")(_ops.to_raw_op(rpc_check_status))
|
| 276 |
+
_dispatcher_for_rpc_check_status = rpc_check_status._tf_type_based_dispatcher.Dispatch
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def rpc_check_status_eager_fallback(status_or: Annotated[Any, _atypes.Resource], name, ctx):
|
| 280 |
+
status_or = _ops.convert_to_tensor(status_or, _dtypes.resource)
|
| 281 |
+
_inputs_flat = [status_or]
|
| 282 |
+
_attrs = None
|
| 283 |
+
_result = _execute.execute(b"RpcCheckStatus", 2, inputs=_inputs_flat,
|
| 284 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 285 |
+
if _execute.must_record_gradient():
|
| 286 |
+
_execute.record_gradient(
|
| 287 |
+
"RpcCheckStatus", _inputs_flat, _attrs, _result)
|
| 288 |
+
_result = _RpcCheckStatusOutput._make(_result)
|
| 289 |
+
return _result
|
| 290 |
+
|
| 291 |
+
_RpcClientOutput = collections.namedtuple(
|
| 292 |
+
"RpcClient",
|
| 293 |
+
["client", "method_specs"])
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
@_dispatch.add_fallback_dispatch_list
|
| 297 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 298 |
+
@tf_export('rpc_client')
|
| 299 |
+
def rpc_client(server_address: Annotated[Any, _atypes.String], timeout_in_ms: Annotated[Any, _atypes.Int64], shared_name:str="", list_registered_methods:bool=False, name=None):
|
| 300 |
+
r"""TODO: add doc.
|
| 301 |
+
|
| 302 |
+
Args:
|
| 303 |
+
server_address: A `Tensor` of type `string`.
|
| 304 |
+
timeout_in_ms: A `Tensor` of type `int64`.
|
| 305 |
+
shared_name: An optional `string`. Defaults to `""`.
|
| 306 |
+
list_registered_methods: An optional `bool`. Defaults to `False`.
|
| 307 |
+
name: A name for the operation (optional).
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
A tuple of `Tensor` objects (client, method_specs).
|
| 311 |
+
|
| 312 |
+
client: A `Tensor` of type `resource`.
|
| 313 |
+
method_specs: A `Tensor` of type `string`.
|
| 314 |
+
"""
|
| 315 |
+
_ctx = _context._context or _context.context()
|
| 316 |
+
tld = _ctx._thread_local_data
|
| 317 |
+
if tld.is_eager:
|
| 318 |
+
try:
|
| 319 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 320 |
+
_ctx, "RpcClient", name, server_address, timeout_in_ms, "shared_name",
|
| 321 |
+
shared_name, "list_registered_methods", list_registered_methods)
|
| 322 |
+
_result = _RpcClientOutput._make(_result)
|
| 323 |
+
return _result
|
| 324 |
+
except _core._NotOkStatusException as e:
|
| 325 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 326 |
+
except _core._FallbackException:
|
| 327 |
+
pass
|
| 328 |
+
try:
|
| 329 |
+
_result = _dispatcher_for_rpc_client(
|
| 330 |
+
(server_address, timeout_in_ms, shared_name,
|
| 331 |
+
list_registered_methods, name,), None)
|
| 332 |
+
if _result is not NotImplemented:
|
| 333 |
+
return _result
|
| 334 |
+
return rpc_client_eager_fallback(
|
| 335 |
+
server_address, timeout_in_ms, shared_name=shared_name,
|
| 336 |
+
list_registered_methods=list_registered_methods, name=name,
|
| 337 |
+
ctx=_ctx)
|
| 338 |
+
except _core._SymbolicException:
|
| 339 |
+
pass # Add nodes to the TensorFlow graph.
|
| 340 |
+
except (TypeError, ValueError):
|
| 341 |
+
_result = _dispatch.dispatch(
|
| 342 |
+
rpc_client, (), dict(server_address=server_address,
|
| 343 |
+
timeout_in_ms=timeout_in_ms,
|
| 344 |
+
shared_name=shared_name,
|
| 345 |
+
list_registered_methods=list_registered_methods,
|
| 346 |
+
name=name)
|
| 347 |
+
)
|
| 348 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 349 |
+
return _result
|
| 350 |
+
raise
|
| 351 |
+
else:
|
| 352 |
+
_result = _dispatcher_for_rpc_client(
|
| 353 |
+
(server_address, timeout_in_ms, shared_name, list_registered_methods,
|
| 354 |
+
name,), None)
|
| 355 |
+
if _result is not NotImplemented:
|
| 356 |
+
return _result
|
| 357 |
+
# Add nodes to the TensorFlow graph.
|
| 358 |
+
if shared_name is None:
|
| 359 |
+
shared_name = ""
|
| 360 |
+
shared_name = _execute.make_str(shared_name, "shared_name")
|
| 361 |
+
if list_registered_methods is None:
|
| 362 |
+
list_registered_methods = False
|
| 363 |
+
list_registered_methods = _execute.make_bool(list_registered_methods, "list_registered_methods")
|
| 364 |
+
try:
|
| 365 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 366 |
+
"RpcClient", server_address=server_address,
|
| 367 |
+
timeout_in_ms=timeout_in_ms, shared_name=shared_name,
|
| 368 |
+
list_registered_methods=list_registered_methods,
|
| 369 |
+
name=name)
|
| 370 |
+
except (TypeError, ValueError):
|
| 371 |
+
_result = _dispatch.dispatch(
|
| 372 |
+
rpc_client, (), dict(server_address=server_address,
|
| 373 |
+
timeout_in_ms=timeout_in_ms,
|
| 374 |
+
shared_name=shared_name,
|
| 375 |
+
list_registered_methods=list_registered_methods,
|
| 376 |
+
name=name)
|
| 377 |
+
)
|
| 378 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 379 |
+
return _result
|
| 380 |
+
raise
|
| 381 |
+
_result = _outputs[:]
|
| 382 |
+
if _execute.must_record_gradient():
|
| 383 |
+
_attrs = ("shared_name", _op.get_attr("shared_name"),
|
| 384 |
+
"list_registered_methods",
|
| 385 |
+
_op._get_attr_bool("list_registered_methods"))
|
| 386 |
+
_inputs_flat = _op.inputs
|
| 387 |
+
_execute.record_gradient(
|
| 388 |
+
"RpcClient", _inputs_flat, _attrs, _result)
|
| 389 |
+
_result = _RpcClientOutput._make(_result)
|
| 390 |
+
return _result
|
| 391 |
+
|
| 392 |
+
RpcClient = tf_export("raw_ops.RpcClient")(_ops.to_raw_op(rpc_client))
|
| 393 |
+
_dispatcher_for_rpc_client = rpc_client._tf_type_based_dispatcher.Dispatch
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def rpc_client_eager_fallback(server_address: Annotated[Any, _atypes.String], timeout_in_ms: Annotated[Any, _atypes.Int64], shared_name: str, list_registered_methods: bool, name, ctx):
|
| 397 |
+
if shared_name is None:
|
| 398 |
+
shared_name = ""
|
| 399 |
+
shared_name = _execute.make_str(shared_name, "shared_name")
|
| 400 |
+
if list_registered_methods is None:
|
| 401 |
+
list_registered_methods = False
|
| 402 |
+
list_registered_methods = _execute.make_bool(list_registered_methods, "list_registered_methods")
|
| 403 |
+
server_address = _ops.convert_to_tensor(server_address, _dtypes.string)
|
| 404 |
+
timeout_in_ms = _ops.convert_to_tensor(timeout_in_ms, _dtypes.int64)
|
| 405 |
+
_inputs_flat = [server_address, timeout_in_ms]
|
| 406 |
+
_attrs = ("shared_name", shared_name, "list_registered_methods",
|
| 407 |
+
list_registered_methods)
|
| 408 |
+
_result = _execute.execute(b"RpcClient", 2, inputs=_inputs_flat,
|
| 409 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 410 |
+
if _execute.must_record_gradient():
|
| 411 |
+
_execute.record_gradient(
|
| 412 |
+
"RpcClient", _inputs_flat, _attrs, _result)
|
| 413 |
+
_result = _RpcClientOutput._make(_result)
|
| 414 |
+
return _result
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@_dispatch.add_fallback_dispatch_list
|
| 418 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 419 |
+
@tf_export('rpc_get_value')
|
| 420 |
+
def rpc_get_value(status_or: Annotated[Any, _atypes.Resource], Tout, name=None):
|
| 421 |
+
r"""TODO: add doc.
|
| 422 |
+
|
| 423 |
+
Args:
|
| 424 |
+
status_or: A `Tensor` of type `resource`.
|
| 425 |
+
Tout: A list of `tf.DTypes`.
|
| 426 |
+
name: A name for the operation (optional).
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
A list of `Tensor` objects of type `Tout`.
|
| 430 |
+
"""
|
| 431 |
+
_ctx = _context._context or _context.context()
|
| 432 |
+
tld = _ctx._thread_local_data
|
| 433 |
+
if tld.is_eager:
|
| 434 |
+
try:
|
| 435 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 436 |
+
_ctx, "RpcGetValue", name, status_or, "Tout", Tout)
|
| 437 |
+
return _result
|
| 438 |
+
except _core._NotOkStatusException as e:
|
| 439 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 440 |
+
except _core._FallbackException:
|
| 441 |
+
pass
|
| 442 |
+
try:
|
| 443 |
+
_result = _dispatcher_for_rpc_get_value(
|
| 444 |
+
(status_or, Tout, name,), None)
|
| 445 |
+
if _result is not NotImplemented:
|
| 446 |
+
return _result
|
| 447 |
+
return rpc_get_value_eager_fallback(
|
| 448 |
+
status_or, Tout=Tout, name=name, ctx=_ctx)
|
| 449 |
+
except _core._SymbolicException:
|
| 450 |
+
pass # Add nodes to the TensorFlow graph.
|
| 451 |
+
except (TypeError, ValueError):
|
| 452 |
+
_result = _dispatch.dispatch(
|
| 453 |
+
rpc_get_value, (), dict(status_or=status_or, Tout=Tout, name=name)
|
| 454 |
+
)
|
| 455 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 456 |
+
return _result
|
| 457 |
+
raise
|
| 458 |
+
else:
|
| 459 |
+
_result = _dispatcher_for_rpc_get_value(
|
| 460 |
+
(status_or, Tout, name,), None)
|
| 461 |
+
if _result is not NotImplemented:
|
| 462 |
+
return _result
|
| 463 |
+
# Add nodes to the TensorFlow graph.
|
| 464 |
+
if not isinstance(Tout, (list, tuple)):
|
| 465 |
+
raise TypeError(
|
| 466 |
+
"Expected list for 'Tout' argument to "
|
| 467 |
+
"'rpc_get_value' Op, not %r." % Tout)
|
| 468 |
+
Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
|
| 469 |
+
try:
|
| 470 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 471 |
+
"RpcGetValue", status_or=status_or, Tout=Tout, name=name)
|
| 472 |
+
except (TypeError, ValueError):
|
| 473 |
+
_result = _dispatch.dispatch(
|
| 474 |
+
rpc_get_value, (), dict(status_or=status_or, Tout=Tout, name=name)
|
| 475 |
+
)
|
| 476 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 477 |
+
return _result
|
| 478 |
+
raise
|
| 479 |
+
_result = _outputs[:]
|
| 480 |
+
if not _result:
|
| 481 |
+
return _op
|
| 482 |
+
if _execute.must_record_gradient():
|
| 483 |
+
_attrs = ("Tout", _op.get_attr("Tout"))
|
| 484 |
+
_inputs_flat = _op.inputs
|
| 485 |
+
_execute.record_gradient(
|
| 486 |
+
"RpcGetValue", _inputs_flat, _attrs, _result)
|
| 487 |
+
return _result
|
| 488 |
+
|
| 489 |
+
RpcGetValue = tf_export("raw_ops.RpcGetValue")(_ops.to_raw_op(rpc_get_value))
|
| 490 |
+
_dispatcher_for_rpc_get_value = rpc_get_value._tf_type_based_dispatcher.Dispatch
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def rpc_get_value_eager_fallback(status_or: Annotated[Any, _atypes.Resource], Tout, name, ctx):
|
| 494 |
+
if not isinstance(Tout, (list, tuple)):
|
| 495 |
+
raise TypeError(
|
| 496 |
+
"Expected list for 'Tout' argument to "
|
| 497 |
+
"'rpc_get_value' Op, not %r." % Tout)
|
| 498 |
+
Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
|
| 499 |
+
status_or = _ops.convert_to_tensor(status_or, _dtypes.resource)
|
| 500 |
+
_inputs_flat = [status_or]
|
| 501 |
+
_attrs = ("Tout", Tout)
|
| 502 |
+
_result = _execute.execute(b"RpcGetValue", len(Tout), inputs=_inputs_flat,
|
| 503 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 504 |
+
if _execute.must_record_gradient():
|
| 505 |
+
_execute.record_gradient(
|
| 506 |
+
"RpcGetValue", _inputs_flat, _attrs, _result)
|
| 507 |
+
return _result
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@_dispatch.add_fallback_dispatch_list
|
| 511 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 512 |
+
@tf_export('rpc_server')
|
| 513 |
+
def rpc_server(server_address: Annotated[Any, _atypes.String], name=None) -> Annotated[Any, _atypes.Resource]:
|
| 514 |
+
r"""TODO: add doc.
|
| 515 |
+
|
| 516 |
+
Args:
|
| 517 |
+
server_address: A `Tensor` of type `string`.
|
| 518 |
+
name: A name for the operation (optional).
|
| 519 |
+
|
| 520 |
+
Returns:
|
| 521 |
+
A `Tensor` of type `resource`.
|
| 522 |
+
"""
|
| 523 |
+
_ctx = _context._context or _context.context()
|
| 524 |
+
tld = _ctx._thread_local_data
|
| 525 |
+
if tld.is_eager:
|
| 526 |
+
try:
|
| 527 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 528 |
+
_ctx, "RpcServer", name, server_address)
|
| 529 |
+
return _result
|
| 530 |
+
except _core._NotOkStatusException as e:
|
| 531 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 532 |
+
except _core._FallbackException:
|
| 533 |
+
pass
|
| 534 |
+
try:
|
| 535 |
+
_result = _dispatcher_for_rpc_server(
|
| 536 |
+
(server_address, name,), None)
|
| 537 |
+
if _result is not NotImplemented:
|
| 538 |
+
return _result
|
| 539 |
+
return rpc_server_eager_fallback(
|
| 540 |
+
server_address, name=name, ctx=_ctx)
|
| 541 |
+
except _core._SymbolicException:
|
| 542 |
+
pass # Add nodes to the TensorFlow graph.
|
| 543 |
+
except (TypeError, ValueError):
|
| 544 |
+
_result = _dispatch.dispatch(
|
| 545 |
+
rpc_server, (), dict(server_address=server_address, name=name)
|
| 546 |
+
)
|
| 547 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 548 |
+
return _result
|
| 549 |
+
raise
|
| 550 |
+
else:
|
| 551 |
+
_result = _dispatcher_for_rpc_server(
|
| 552 |
+
(server_address, name,), None)
|
| 553 |
+
if _result is not NotImplemented:
|
| 554 |
+
return _result
|
| 555 |
+
# Add nodes to the TensorFlow graph.
|
| 556 |
+
try:
|
| 557 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 558 |
+
"RpcServer", server_address=server_address, name=name)
|
| 559 |
+
except (TypeError, ValueError):
|
| 560 |
+
_result = _dispatch.dispatch(
|
| 561 |
+
rpc_server, (), dict(server_address=server_address, name=name)
|
| 562 |
+
)
|
| 563 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 564 |
+
return _result
|
| 565 |
+
raise
|
| 566 |
+
_result = _outputs[:]
|
| 567 |
+
if _execute.must_record_gradient():
|
| 568 |
+
_attrs = ()
|
| 569 |
+
_inputs_flat = _op.inputs
|
| 570 |
+
_execute.record_gradient(
|
| 571 |
+
"RpcServer", _inputs_flat, _attrs, _result)
|
| 572 |
+
_result, = _result
|
| 573 |
+
return _result
|
| 574 |
+
|
| 575 |
+
RpcServer = tf_export("raw_ops.RpcServer")(_ops.to_raw_op(rpc_server))
|
| 576 |
+
_dispatcher_for_rpc_server = rpc_server._tf_type_based_dispatcher.Dispatch
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
def rpc_server_eager_fallback(server_address: Annotated[Any, _atypes.String], name, ctx) -> Annotated[Any, _atypes.Resource]:
|
| 580 |
+
server_address = _ops.convert_to_tensor(server_address, _dtypes.string)
|
| 581 |
+
_inputs_flat = [server_address]
|
| 582 |
+
_attrs = None
|
| 583 |
+
_result = _execute.execute(b"RpcServer", 1, inputs=_inputs_flat,
|
| 584 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 585 |
+
if _execute.must_record_gradient():
|
| 586 |
+
_execute.record_gradient(
|
| 587 |
+
"RpcServer", _inputs_flat, _attrs, _result)
|
| 588 |
+
_result, = _result
|
| 589 |
+
return _result
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
@_dispatch.add_fallback_dispatch_list
|
| 593 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 594 |
+
@tf_export('rpc_server_register')
|
| 595 |
+
def rpc_server_register(server: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], captured_inputs, f, output_specs: str, input_specs:str="", name=None):
|
| 596 |
+
r"""TODO: add doc.
|
| 597 |
+
|
| 598 |
+
Args:
|
| 599 |
+
server: A `Tensor` of type `resource`.
|
| 600 |
+
method_name: A `Tensor` of type `string`.
|
| 601 |
+
captured_inputs: A list of `Tensor` objects.
|
| 602 |
+
f: A function decorated with @Defun.
|
| 603 |
+
output_specs: A `string`.
|
| 604 |
+
input_specs: An optional `string`. Defaults to `""`.
|
| 605 |
+
name: A name for the operation (optional).
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
The created Operation.
|
| 609 |
+
"""
|
| 610 |
+
_ctx = _context._context or _context.context()
|
| 611 |
+
tld = _ctx._thread_local_data
|
| 612 |
+
if tld.is_eager:
|
| 613 |
+
try:
|
| 614 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 615 |
+
_ctx, "RpcServerRegister", name, server, method_name, captured_inputs,
|
| 616 |
+
"f", f, "input_specs", input_specs, "output_specs", output_specs)
|
| 617 |
+
return _result
|
| 618 |
+
except _core._NotOkStatusException as e:
|
| 619 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 620 |
+
except _core._FallbackException:
|
| 621 |
+
pass
|
| 622 |
+
try:
|
| 623 |
+
_result = _dispatcher_for_rpc_server_register(
|
| 624 |
+
(server, method_name, captured_inputs, f, output_specs, input_specs,
|
| 625 |
+
name,), None)
|
| 626 |
+
if _result is not NotImplemented:
|
| 627 |
+
return _result
|
| 628 |
+
return rpc_server_register_eager_fallback(
|
| 629 |
+
server, method_name, captured_inputs, f=f, input_specs=input_specs,
|
| 630 |
+
output_specs=output_specs, name=name, ctx=_ctx)
|
| 631 |
+
except _core._SymbolicException:
|
| 632 |
+
pass # Add nodes to the TensorFlow graph.
|
| 633 |
+
except (TypeError, ValueError):
|
| 634 |
+
_result = _dispatch.dispatch(
|
| 635 |
+
rpc_server_register, (), dict(server=server,
|
| 636 |
+
method_name=method_name,
|
| 637 |
+
captured_inputs=captured_inputs,
|
| 638 |
+
f=f, output_specs=output_specs,
|
| 639 |
+
input_specs=input_specs, name=name)
|
| 640 |
+
)
|
| 641 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 642 |
+
return _result
|
| 643 |
+
raise
|
| 644 |
+
else:
|
| 645 |
+
_result = _dispatcher_for_rpc_server_register(
|
| 646 |
+
(server, method_name, captured_inputs, f, output_specs, input_specs,
|
| 647 |
+
name,), None)
|
| 648 |
+
if _result is not NotImplemented:
|
| 649 |
+
return _result
|
| 650 |
+
# Add nodes to the TensorFlow graph.
|
| 651 |
+
output_specs = _execute.make_str(output_specs, "output_specs")
|
| 652 |
+
if input_specs is None:
|
| 653 |
+
input_specs = ""
|
| 654 |
+
input_specs = _execute.make_str(input_specs, "input_specs")
|
| 655 |
+
try:
|
| 656 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 657 |
+
"RpcServerRegister", server=server, method_name=method_name,
|
| 658 |
+
captured_inputs=captured_inputs, f=f,
|
| 659 |
+
output_specs=output_specs,
|
| 660 |
+
input_specs=input_specs, name=name)
|
| 661 |
+
except (TypeError, ValueError):
|
| 662 |
+
_result = _dispatch.dispatch(
|
| 663 |
+
rpc_server_register, (), dict(server=server,
|
| 664 |
+
method_name=method_name,
|
| 665 |
+
captured_inputs=captured_inputs, f=f,
|
| 666 |
+
output_specs=output_specs,
|
| 667 |
+
input_specs=input_specs, name=name)
|
| 668 |
+
)
|
| 669 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 670 |
+
return _result
|
| 671 |
+
raise
|
| 672 |
+
return _op
|
| 673 |
+
RpcServerRegister = tf_export("raw_ops.RpcServerRegister")(_ops.to_raw_op(rpc_server_register))
|
| 674 |
+
_dispatcher_for_rpc_server_register = rpc_server_register._tf_type_based_dispatcher.Dispatch
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
def rpc_server_register_eager_fallback(server: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], captured_inputs, f, output_specs: str, input_specs: str, name, ctx):
|
| 678 |
+
output_specs = _execute.make_str(output_specs, "output_specs")
|
| 679 |
+
if input_specs is None:
|
| 680 |
+
input_specs = ""
|
| 681 |
+
input_specs = _execute.make_str(input_specs, "input_specs")
|
| 682 |
+
_attr_Tin, captured_inputs = _execute.convert_to_mixed_eager_tensors(captured_inputs, ctx)
|
| 683 |
+
server = _ops.convert_to_tensor(server, _dtypes.resource)
|
| 684 |
+
method_name = _ops.convert_to_tensor(method_name, _dtypes.string)
|
| 685 |
+
_inputs_flat = [server, method_name] + list(captured_inputs)
|
| 686 |
+
_attrs = ("Tin", _attr_Tin, "f", f, "input_specs", input_specs,
|
| 687 |
+
"output_specs", output_specs)
|
| 688 |
+
_result = _execute.execute(b"RpcServerRegister", 0, inputs=_inputs_flat,
|
| 689 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 690 |
+
_result = None
|
| 691 |
+
return _result
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
@_dispatch.add_fallback_dispatch_list
|
| 695 |
+
@_dispatch.add_type_based_api_dispatcher
|
| 696 |
+
@tf_export('rpc_server_start')
|
| 697 |
+
def rpc_server_start(server: Annotated[Any, _atypes.Resource], name=None):
|
| 698 |
+
r"""TODO: add doc.
|
| 699 |
+
|
| 700 |
+
Args:
|
| 701 |
+
server: A `Tensor` of type `resource`.
|
| 702 |
+
name: A name for the operation (optional).
|
| 703 |
+
|
| 704 |
+
Returns:
|
| 705 |
+
The created Operation.
|
| 706 |
+
"""
|
| 707 |
+
_ctx = _context._context or _context.context()
|
| 708 |
+
tld = _ctx._thread_local_data
|
| 709 |
+
if tld.is_eager:
|
| 710 |
+
try:
|
| 711 |
+
_result = pywrap_tfe.TFE_Py_FastPathExecute(
|
| 712 |
+
_ctx, "RpcServerStart", name, server)
|
| 713 |
+
return _result
|
| 714 |
+
except _core._NotOkStatusException as e:
|
| 715 |
+
_ops.raise_from_not_ok_status(e, name)
|
| 716 |
+
except _core._FallbackException:
|
| 717 |
+
pass
|
| 718 |
+
try:
|
| 719 |
+
_result = _dispatcher_for_rpc_server_start(
|
| 720 |
+
(server, name,), None)
|
| 721 |
+
if _result is not NotImplemented:
|
| 722 |
+
return _result
|
| 723 |
+
return rpc_server_start_eager_fallback(
|
| 724 |
+
server, name=name, ctx=_ctx)
|
| 725 |
+
except _core._SymbolicException:
|
| 726 |
+
pass # Add nodes to the TensorFlow graph.
|
| 727 |
+
except (TypeError, ValueError):
|
| 728 |
+
_result = _dispatch.dispatch(
|
| 729 |
+
rpc_server_start, (), dict(server=server, name=name)
|
| 730 |
+
)
|
| 731 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 732 |
+
return _result
|
| 733 |
+
raise
|
| 734 |
+
else:
|
| 735 |
+
_result = _dispatcher_for_rpc_server_start(
|
| 736 |
+
(server, name,), None)
|
| 737 |
+
if _result is not NotImplemented:
|
| 738 |
+
return _result
|
| 739 |
+
# Add nodes to the TensorFlow graph.
|
| 740 |
+
try:
|
| 741 |
+
_, _, _op, _outputs = _op_def_library._apply_op_helper(
|
| 742 |
+
"RpcServerStart", server=server, name=name)
|
| 743 |
+
except (TypeError, ValueError):
|
| 744 |
+
_result = _dispatch.dispatch(
|
| 745 |
+
rpc_server_start, (), dict(server=server, name=name)
|
| 746 |
+
)
|
| 747 |
+
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
|
| 748 |
+
return _result
|
| 749 |
+
raise
|
| 750 |
+
return _op
|
| 751 |
+
RpcServerStart = tf_export("raw_ops.RpcServerStart")(_ops.to_raw_op(rpc_server_start))
|
| 752 |
+
_dispatcher_for_rpc_server_start = rpc_server_start._tf_type_based_dispatcher.Dispatch
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def rpc_server_start_eager_fallback(server: Annotated[Any, _atypes.Resource], name, ctx):
|
| 756 |
+
server = _ops.convert_to_tensor(server, _dtypes.resource)
|
| 757 |
+
_inputs_flat = [server]
|
| 758 |
+
_attrs = None
|
| 759 |
+
_result = _execute.execute(b"RpcServerStart", 0, inputs=_inputs_flat,
|
| 760 |
+
attrs=_attrs, ctx=ctx, name=name)
|
| 761 |
+
_result = None
|
| 762 |
+
return _result
|
| 763 |
+
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (220 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2.cpython-310.pyc
ADDED
|
Binary file (2.03 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2_grpc.cpython-310.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/distribute/experimental/rpc/proto/tf_rpc_service.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
|
| 15 |
+
from tensorflow.core.protobuf import struct_pb2 as tensorflow_dot_core_dot_protobuf_dot_struct__pb2
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nAtensorflow/distribute/experimental/rpc/proto/tf_rpc_service.proto\x12\x0etensorflow.rpc\x1a&tensorflow/core/framework/tensor.proto\x1a%tensorflow/core/protobuf/struct.proto\"M\n\x0b\x43\x61llRequest\x12\x0e\n\x06method\x18\x01 \x01(\t\x12.\n\rinput_tensors\x18\x02 \x03(\x0b\x32\x17.tensorflow.TensorProto\"?\n\x0c\x43\x61llResponse\x12/\n\x0eoutput_tensors\x18\x01 \x03(\x0b\x32\x17.tensorflow.TensorProto\"\r\n\x0bListRequest\"\x87\x01\n\x10RegisteredMethod\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x30\n\x0binput_specs\x18\x02 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x31\n\x0coutput_specs\x18\x03 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\"L\n\x0cListResponse\x12<\n\x12registered_methods\x18\x01 \x03(\x0b\x32 .tensorflow.rpc.RegisteredMethod2\x96\x01\n\nRpcService\x12\x43\n\x04\x43\x61ll\x12\x1b.tensorflow.rpc.CallRequest\x1a\x1c.tensorflow.rpc.CallResponse\"\x00\x12\x43\n\x04List\x12\x1b.tensorflow.rpc.ListRequest\x1a\x1c.tensorflow.rpc.ListResponse\"\x00\x62\x06proto3')
|
| 19 |
+
|
| 20 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 21 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.distribute.experimental.rpc.proto.tf_rpc_service_pb2', globals())
|
| 22 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 23 |
+
|
| 24 |
+
DESCRIPTOR._options = None
|
| 25 |
+
_CALLREQUEST._serialized_start=164
|
| 26 |
+
_CALLREQUEST._serialized_end=241
|
| 27 |
+
_CALLRESPONSE._serialized_start=243
|
| 28 |
+
_CALLRESPONSE._serialized_end=306
|
| 29 |
+
_LISTREQUEST._serialized_start=308
|
| 30 |
+
_LISTREQUEST._serialized_end=321
|
| 31 |
+
_REGISTEREDMETHOD._serialized_start=324
|
| 32 |
+
_REGISTEREDMETHOD._serialized_end=459
|
| 33 |
+
_LISTRESPONSE._serialized_start=461
|
| 34 |
+
_LISTRESPONSE._serialized_end=537
|
| 35 |
+
_RPCSERVICE._serialized_start=540
|
| 36 |
+
_RPCSERVICE._serialized_end=690
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2_grpc.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
| 2 |
+
import grpc
|
| 3 |
+
|
| 4 |
+
from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RpcServiceStub(object):
|
| 8 |
+
# missing associated documentation comment in .proto file
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
def __init__(self, channel):
|
| 12 |
+
"""Constructor.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
channel: A grpc.Channel.
|
| 16 |
+
"""
|
| 17 |
+
self.Call = channel.unary_unary(
|
| 18 |
+
'/tensorflow.rpc.RpcService/Call',
|
| 19 |
+
request_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallRequest.SerializeToString,
|
| 20 |
+
response_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallResponse.FromString,
|
| 21 |
+
)
|
| 22 |
+
self.List = channel.unary_unary(
|
| 23 |
+
'/tensorflow.rpc.RpcService/List',
|
| 24 |
+
request_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListRequest.SerializeToString,
|
| 25 |
+
response_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListResponse.FromString,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class RpcServiceServicer(object):
|
| 30 |
+
# missing associated documentation comment in .proto file
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
def Call(self, request, context):
|
| 34 |
+
"""RPC for invoking a registered function on remote server.
|
| 35 |
+
"""
|
| 36 |
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 37 |
+
context.set_details('Method not implemented!')
|
| 38 |
+
raise NotImplementedError('Method not implemented!')
|
| 39 |
+
|
| 40 |
+
def List(self, request, context):
|
| 41 |
+
"""RPC for listing available methods in a server.
|
| 42 |
+
"""
|
| 43 |
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 44 |
+
context.set_details('Method not implemented!')
|
| 45 |
+
raise NotImplementedError('Method not implemented!')
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def add_RpcServiceServicer_to_server(servicer, server):
|
| 49 |
+
rpc_method_handlers = {
|
| 50 |
+
'Call': grpc.unary_unary_rpc_method_handler(
|
| 51 |
+
servicer.Call,
|
| 52 |
+
request_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallRequest.FromString,
|
| 53 |
+
response_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallResponse.SerializeToString,
|
| 54 |
+
),
|
| 55 |
+
'List': grpc.unary_unary_rpc_method_handler(
|
| 56 |
+
servicer.List,
|
| 57 |
+
request_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListRequest.FromString,
|
| 58 |
+
response_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListResponse.SerializeToString,
|
| 59 |
+
),
|
| 60 |
+
}
|
| 61 |
+
generic_handler = grpc.method_handlers_generic_handler(
|
| 62 |
+
'tensorflow.rpc.RpcService', rpc_method_handlers)
|
| 63 |
+
server.add_generic_rpc_handlers((generic_handler,))
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/analyzer.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""This tool analyzes a TensorFlow Lite graph."""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
# pylint: disable=g-import-not-at-top
|
| 20 |
+
if not os.path.splitext(__file__)[0].endswith(
|
| 21 |
+
os.path.join("tflite_runtime", "analyzer")):
|
| 22 |
+
# This file is part of tensorflow package.
|
| 23 |
+
from tensorflow.compiler.mlir.lite.python import wrap_converter
|
| 24 |
+
from tensorflow.lite.python.analyzer_wrapper import _pywrap_analyzer_wrapper as _analyzer_wrapper
|
| 25 |
+
from tensorflow.python.util.tf_export import tf_export as _tf_export
|
| 26 |
+
else:
|
| 27 |
+
# This file is part of tflite_runtime package.
|
| 28 |
+
from tflite_runtime import _pywrap_analyzer_wrapper as _analyzer_wrapper
|
| 29 |
+
|
| 30 |
+
def _tf_export(*x, **kwargs):
|
| 31 |
+
del x, kwargs
|
| 32 |
+
return lambda x: x
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@_tf_export("lite.experimental.Analyzer")
|
| 36 |
+
class ModelAnalyzer():
|
| 37 |
+
"""Provides a collection of TFLite model analyzer tools.
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
|
| 41 |
+
```python
|
| 42 |
+
model = tf.keras.applications.MobileNetV3Large()
|
| 43 |
+
fb_model = tf.lite.TFLiteConverterV2.from_keras_model(model).convert()
|
| 44 |
+
tf.lite.experimental.Analyzer.analyze(model_content=fb_model)
|
| 45 |
+
# === TFLite ModelAnalyzer ===
|
| 46 |
+
#
|
| 47 |
+
# Your TFLite model has ‘1’ subgraph(s). In the subgraph description below,
|
| 48 |
+
# T# represents the Tensor numbers. For example, in Subgraph#0, the MUL op
|
| 49 |
+
# takes tensor #0 and tensor #19 as input and produces tensor #136 as output.
|
| 50 |
+
#
|
| 51 |
+
# Subgraph#0 main(T#0) -> [T#263]
|
| 52 |
+
# Op#0 MUL(T#0, T#19) -> [T#136]
|
| 53 |
+
# Op#1 ADD(T#136, T#18) -> [T#137]
|
| 54 |
+
# Op#2 CONV_2D(T#137, T#44, T#93) -> [T#138]
|
| 55 |
+
# Op#3 HARD_SWISH(T#138) -> [T#139]
|
| 56 |
+
# Op#4 DEPTHWISE_CONV_2D(T#139, T#94, T#24) -> [T#140]
|
| 57 |
+
# ...
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
WARNING: Experimental interface, subject to change.
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
@staticmethod
|
| 64 |
+
def analyze(model_path=None,
|
| 65 |
+
model_content=None,
|
| 66 |
+
gpu_compatibility=False,
|
| 67 |
+
**kwargs):
|
| 68 |
+
"""Analyzes the given tflite_model with dumping model structure.
|
| 69 |
+
|
| 70 |
+
This tool provides a way to understand users' TFLite flatbuffer model by
|
| 71 |
+
dumping internal graph structure. It also provides additional features
|
| 72 |
+
like checking GPU delegate compatibility.
|
| 73 |
+
|
| 74 |
+
WARNING: Experimental interface, subject to change.
|
| 75 |
+
The output format is not guaranteed to stay stable, so don't
|
| 76 |
+
write scripts to this.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
model_path: TFLite flatbuffer model path.
|
| 80 |
+
model_content: TFLite flatbuffer model object.
|
| 81 |
+
gpu_compatibility: Whether to check GPU delegate compatibility.
|
| 82 |
+
**kwargs: Experimental keyword arguments to analyze API.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Print analyzed report via console output.
|
| 86 |
+
"""
|
| 87 |
+
if not model_path and not model_content:
|
| 88 |
+
raise ValueError("neither `model_path` nor `model_content` is provided")
|
| 89 |
+
if model_path:
|
| 90 |
+
print(f"=== {model_path} ===\n")
|
| 91 |
+
tflite_model = model_path
|
| 92 |
+
input_is_filepath = True
|
| 93 |
+
else:
|
| 94 |
+
print("=== TFLite ModelAnalyzer ===\n")
|
| 95 |
+
tflite_model = model_content
|
| 96 |
+
input_is_filepath = False
|
| 97 |
+
|
| 98 |
+
if kwargs.get("experimental_use_mlir", False):
|
| 99 |
+
print(
|
| 100 |
+
wrap_converter.wrapped_flat_buffer_file_to_mlir(
|
| 101 |
+
tflite_model, input_is_filepath
|
| 102 |
+
)
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
print(
|
| 106 |
+
_analyzer_wrapper.ModelAnalyzer(tflite_model, input_is_filepath,
|
| 107 |
+
gpu_compatibility))
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/conversion_metadata_schema_py_generated.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import flatbuffers
|
| 2 |
+
|
| 3 |
+
# automatically generated by the FlatBuffers compiler, do not modify
|
| 4 |
+
|
| 5 |
+
# namespace: tflite
|
| 6 |
+
|
| 7 |
+
from flatbuffers.compat import import_numpy
|
| 8 |
+
np = import_numpy()
|
| 9 |
+
|
| 10 |
+
class ModelType(object):
|
| 11 |
+
NONE = 0
|
| 12 |
+
TF_SAVED_MODEL = 1
|
| 13 |
+
KERAS_MODEL = 2
|
| 14 |
+
TF_CONCRETE_FUNCTIONS = 3
|
| 15 |
+
TF_GRAPH_DEF = 4
|
| 16 |
+
TF_SESSION = 5
|
| 17 |
+
JAX = 6
|
| 18 |
+
PYTORCH = 7
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ModelOptimizationMode(object):
|
| 22 |
+
PTQ_FLOAT16 = 1001
|
| 23 |
+
PTQ_DYNAMIC_RANGE = 1002
|
| 24 |
+
PTQ_FULL_INTEGER = 1003
|
| 25 |
+
PTQ_INT16 = 1004
|
| 26 |
+
QUANTIZATION_AWARE_TRAINING = 2000
|
| 27 |
+
RANDOM_SPARSITY = 3001
|
| 28 |
+
BLOCK_SPARSITY = 3002
|
| 29 |
+
STRUCTURED_SPARSITY = 3003
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class Environment(object):
|
| 33 |
+
__slots__ = ['_tab']
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def GetRootAs(cls, buf, offset=0):
|
| 37 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
| 38 |
+
x = Environment()
|
| 39 |
+
x.Init(buf, n + offset)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
@classmethod
|
| 43 |
+
def GetRootAsEnvironment(cls, buf, offset=0):
|
| 44 |
+
"""This method is deprecated. Please switch to GetRootAs."""
|
| 45 |
+
return cls.GetRootAs(buf, offset)
|
| 46 |
+
# Environment
|
| 47 |
+
def Init(self, buf, pos):
|
| 48 |
+
self._tab = flatbuffers.table.Table(buf, pos)
|
| 49 |
+
|
| 50 |
+
# Environment
|
| 51 |
+
def TensorflowVersion(self):
|
| 52 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 53 |
+
if o != 0:
|
| 54 |
+
return self._tab.String(o + self._tab.Pos)
|
| 55 |
+
return None
|
| 56 |
+
|
| 57 |
+
# Environment
|
| 58 |
+
def ApiVersion(self):
|
| 59 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
| 60 |
+
if o != 0:
|
| 61 |
+
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
|
| 62 |
+
return 0
|
| 63 |
+
|
| 64 |
+
# Environment
|
| 65 |
+
def ModelType(self):
|
| 66 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
| 67 |
+
if o != 0:
|
| 68 |
+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
|
| 69 |
+
return 0
|
| 70 |
+
|
| 71 |
+
# Environment
|
| 72 |
+
def ModelHash(self):
|
| 73 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
| 74 |
+
if o != 0:
|
| 75 |
+
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
|
| 76 |
+
return 0
|
| 77 |
+
|
| 78 |
+
def EnvironmentStart(builder):
|
| 79 |
+
builder.StartObject(4)
|
| 80 |
+
|
| 81 |
+
def EnvironmentAddTensorflowVersion(builder, tensorflowVersion):
|
| 82 |
+
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(tensorflowVersion), 0)
|
| 83 |
+
|
| 84 |
+
def EnvironmentAddApiVersion(builder, apiVersion):
|
| 85 |
+
builder.PrependUint32Slot(1, apiVersion, 0)
|
| 86 |
+
|
| 87 |
+
def EnvironmentAddModelType(builder, modelType):
|
| 88 |
+
builder.PrependInt32Slot(2, modelType, 0)
|
| 89 |
+
|
| 90 |
+
def EnvironmentAddModelHash(builder, modelHash):
|
| 91 |
+
builder.PrependUint64Slot(3, modelHash, 0)
|
| 92 |
+
|
| 93 |
+
def EnvironmentEnd(builder):
|
| 94 |
+
return builder.EndObject()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class EnvironmentT(object):
|
| 99 |
+
|
| 100 |
+
# EnvironmentT
|
| 101 |
+
def __init__(self):
|
| 102 |
+
self.tensorflowVersion = None # type: str
|
| 103 |
+
self.apiVersion = 0 # type: int
|
| 104 |
+
self.modelType = 0 # type: int
|
| 105 |
+
self.modelHash = 0 # type: int
|
| 106 |
+
|
| 107 |
+
@classmethod
|
| 108 |
+
def InitFromBuf(cls, buf, pos):
|
| 109 |
+
environment = Environment()
|
| 110 |
+
environment.Init(buf, pos)
|
| 111 |
+
return cls.InitFromObj(environment)
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def InitFromPackedBuf(cls, buf, pos=0):
|
| 115 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
|
| 116 |
+
return cls.InitFromBuf(buf, pos+n)
|
| 117 |
+
|
| 118 |
+
@classmethod
|
| 119 |
+
def InitFromObj(cls, environment):
|
| 120 |
+
x = EnvironmentT()
|
| 121 |
+
x._UnPack(environment)
|
| 122 |
+
return x
|
| 123 |
+
|
| 124 |
+
# EnvironmentT
|
| 125 |
+
def _UnPack(self, environment):
|
| 126 |
+
if environment is None:
|
| 127 |
+
return
|
| 128 |
+
self.tensorflowVersion = environment.TensorflowVersion()
|
| 129 |
+
self.apiVersion = environment.ApiVersion()
|
| 130 |
+
self.modelType = environment.ModelType()
|
| 131 |
+
self.modelHash = environment.ModelHash()
|
| 132 |
+
|
| 133 |
+
# EnvironmentT
|
| 134 |
+
def Pack(self, builder):
|
| 135 |
+
if self.tensorflowVersion is not None:
|
| 136 |
+
tensorflowVersion = builder.CreateString(self.tensorflowVersion)
|
| 137 |
+
EnvironmentStart(builder)
|
| 138 |
+
if self.tensorflowVersion is not None:
|
| 139 |
+
EnvironmentAddTensorflowVersion(builder, tensorflowVersion)
|
| 140 |
+
EnvironmentAddApiVersion(builder, self.apiVersion)
|
| 141 |
+
EnvironmentAddModelType(builder, self.modelType)
|
| 142 |
+
EnvironmentAddModelHash(builder, self.modelHash)
|
| 143 |
+
environment = EnvironmentEnd(builder)
|
| 144 |
+
return environment
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class SparsityBlockSize(object):
|
| 148 |
+
__slots__ = ['_tab']
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def GetRootAs(cls, buf, offset=0):
|
| 152 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
| 153 |
+
x = SparsityBlockSize()
|
| 154 |
+
x.Init(buf, n + offset)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
@classmethod
|
| 158 |
+
def GetRootAsSparsityBlockSize(cls, buf, offset=0):
|
| 159 |
+
"""This method is deprecated. Please switch to GetRootAs."""
|
| 160 |
+
return cls.GetRootAs(buf, offset)
|
| 161 |
+
# SparsityBlockSize
|
| 162 |
+
def Init(self, buf, pos):
|
| 163 |
+
self._tab = flatbuffers.table.Table(buf, pos)
|
| 164 |
+
|
| 165 |
+
# SparsityBlockSize
|
| 166 |
+
def Values(self, j):
|
| 167 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 168 |
+
if o != 0:
|
| 169 |
+
a = self._tab.Vector(o)
|
| 170 |
+
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
| 171 |
+
return 0
|
| 172 |
+
|
| 173 |
+
# SparsityBlockSize
|
| 174 |
+
def ValuesAsNumpy(self):
|
| 175 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 176 |
+
if o != 0:
|
| 177 |
+
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
|
| 178 |
+
return 0
|
| 179 |
+
|
| 180 |
+
# SparsityBlockSize
|
| 181 |
+
def ValuesLength(self):
|
| 182 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 183 |
+
if o != 0:
|
| 184 |
+
return self._tab.VectorLen(o)
|
| 185 |
+
return 0
|
| 186 |
+
|
| 187 |
+
# SparsityBlockSize
|
| 188 |
+
def ValuesIsNone(self):
|
| 189 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 190 |
+
return o == 0
|
| 191 |
+
|
| 192 |
+
def SparsityBlockSizeStart(builder):
|
| 193 |
+
builder.StartObject(1)
|
| 194 |
+
|
| 195 |
+
def SparsityBlockSizeAddValues(builder, values):
|
| 196 |
+
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0)
|
| 197 |
+
|
| 198 |
+
def SparsityBlockSizeStartValuesVector(builder, numElems):
|
| 199 |
+
return builder.StartVector(4, numElems, 4)
|
| 200 |
+
|
| 201 |
+
def SparsityBlockSizeEnd(builder):
|
| 202 |
+
return builder.EndObject()
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
from typing import List
|
| 207 |
+
except:
|
| 208 |
+
pass
|
| 209 |
+
|
| 210 |
+
class SparsityBlockSizeT(object):
|
| 211 |
+
|
| 212 |
+
# SparsityBlockSizeT
|
| 213 |
+
def __init__(self):
|
| 214 |
+
self.values = None # type: List[int]
|
| 215 |
+
|
| 216 |
+
@classmethod
|
| 217 |
+
def InitFromBuf(cls, buf, pos):
|
| 218 |
+
sparsityBlockSize = SparsityBlockSize()
|
| 219 |
+
sparsityBlockSize.Init(buf, pos)
|
| 220 |
+
return cls.InitFromObj(sparsityBlockSize)
|
| 221 |
+
|
| 222 |
+
@classmethod
|
| 223 |
+
def InitFromPackedBuf(cls, buf, pos=0):
|
| 224 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
|
| 225 |
+
return cls.InitFromBuf(buf, pos+n)
|
| 226 |
+
|
| 227 |
+
@classmethod
|
| 228 |
+
def InitFromObj(cls, sparsityBlockSize):
|
| 229 |
+
x = SparsityBlockSizeT()
|
| 230 |
+
x._UnPack(sparsityBlockSize)
|
| 231 |
+
return x
|
| 232 |
+
|
| 233 |
+
# SparsityBlockSizeT
|
| 234 |
+
def _UnPack(self, sparsityBlockSize):
|
| 235 |
+
if sparsityBlockSize is None:
|
| 236 |
+
return
|
| 237 |
+
if not sparsityBlockSize.ValuesIsNone():
|
| 238 |
+
if np is None:
|
| 239 |
+
self.values = []
|
| 240 |
+
for i in range(sparsityBlockSize.ValuesLength()):
|
| 241 |
+
self.values.append(sparsityBlockSize.Values(i))
|
| 242 |
+
else:
|
| 243 |
+
self.values = sparsityBlockSize.ValuesAsNumpy()
|
| 244 |
+
|
| 245 |
+
# SparsityBlockSizeT
|
| 246 |
+
def Pack(self, builder):
|
| 247 |
+
if self.values is not None:
|
| 248 |
+
if np is not None and type(self.values) is np.ndarray:
|
| 249 |
+
values = builder.CreateNumpyVector(self.values)
|
| 250 |
+
else:
|
| 251 |
+
SparsityBlockSizeStartValuesVector(builder, len(self.values))
|
| 252 |
+
for i in reversed(range(len(self.values))):
|
| 253 |
+
builder.PrependUint32(self.values[i])
|
| 254 |
+
values = builder.EndVector()
|
| 255 |
+
SparsityBlockSizeStart(builder)
|
| 256 |
+
if self.values is not None:
|
| 257 |
+
SparsityBlockSizeAddValues(builder, values)
|
| 258 |
+
sparsityBlockSize = SparsityBlockSizeEnd(builder)
|
| 259 |
+
return sparsityBlockSize
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class ConversionOptions(object):
|
| 263 |
+
__slots__ = ['_tab']
|
| 264 |
+
|
| 265 |
+
@classmethod
|
| 266 |
+
def GetRootAs(cls, buf, offset=0):
|
| 267 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
| 268 |
+
x = ConversionOptions()
|
| 269 |
+
x.Init(buf, n + offset)
|
| 270 |
+
return x
|
| 271 |
+
|
| 272 |
+
@classmethod
|
| 273 |
+
def GetRootAsConversionOptions(cls, buf, offset=0):
|
| 274 |
+
"""This method is deprecated. Please switch to GetRootAs."""
|
| 275 |
+
return cls.GetRootAs(buf, offset)
|
| 276 |
+
# ConversionOptions
|
| 277 |
+
def Init(self, buf, pos):
|
| 278 |
+
self._tab = flatbuffers.table.Table(buf, pos)
|
| 279 |
+
|
| 280 |
+
# ConversionOptions
|
| 281 |
+
def ModelOptimizationModes(self, j):
|
| 282 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 283 |
+
if o != 0:
|
| 284 |
+
a = self._tab.Vector(o)
|
| 285 |
+
return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
|
| 286 |
+
return 0
|
| 287 |
+
|
| 288 |
+
# ConversionOptions
|
| 289 |
+
def ModelOptimizationModesAsNumpy(self):
|
| 290 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 291 |
+
if o != 0:
|
| 292 |
+
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
|
| 293 |
+
return 0
|
| 294 |
+
|
| 295 |
+
# ConversionOptions
|
| 296 |
+
def ModelOptimizationModesLength(self):
|
| 297 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 298 |
+
if o != 0:
|
| 299 |
+
return self._tab.VectorLen(o)
|
| 300 |
+
return 0
|
| 301 |
+
|
| 302 |
+
# ConversionOptions
|
| 303 |
+
def ModelOptimizationModesIsNone(self):
|
| 304 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 305 |
+
return o == 0
|
| 306 |
+
|
| 307 |
+
# ConversionOptions
|
| 308 |
+
def AllowCustomOps(self):
|
| 309 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
| 310 |
+
if o != 0:
|
| 311 |
+
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
| 312 |
+
return False
|
| 313 |
+
|
| 314 |
+
# ConversionOptions
|
| 315 |
+
def EnableSelectTfOps(self):
|
| 316 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
|
| 317 |
+
if o != 0:
|
| 318 |
+
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
| 319 |
+
return False
|
| 320 |
+
|
| 321 |
+
# ConversionOptions
|
| 322 |
+
def ForceSelectTfOps(self):
|
| 323 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
|
| 324 |
+
if o != 0:
|
| 325 |
+
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
|
| 326 |
+
return False
|
| 327 |
+
|
| 328 |
+
# ConversionOptions
|
| 329 |
+
def SparsityBlockSizes(self, j):
|
| 330 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
| 331 |
+
if o != 0:
|
| 332 |
+
x = self._tab.Vector(o)
|
| 333 |
+
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
|
| 334 |
+
x = self._tab.Indirect(x)
|
| 335 |
+
obj = SparsityBlockSize()
|
| 336 |
+
obj.Init(self._tab.Bytes, x)
|
| 337 |
+
return obj
|
| 338 |
+
return None
|
| 339 |
+
|
| 340 |
+
# ConversionOptions
|
| 341 |
+
def SparsityBlockSizesLength(self):
|
| 342 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
| 343 |
+
if o != 0:
|
| 344 |
+
return self._tab.VectorLen(o)
|
| 345 |
+
return 0
|
| 346 |
+
|
| 347 |
+
# ConversionOptions
|
| 348 |
+
def SparsityBlockSizesIsNone(self):
|
| 349 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
|
| 350 |
+
return o == 0
|
| 351 |
+
|
| 352 |
+
def ConversionOptionsStart(builder):
|
| 353 |
+
builder.StartObject(5)
|
| 354 |
+
|
| 355 |
+
def ConversionOptionsAddModelOptimizationModes(builder, modelOptimizationModes):
|
| 356 |
+
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(modelOptimizationModes), 0)
|
| 357 |
+
|
| 358 |
+
def ConversionOptionsStartModelOptimizationModesVector(builder, numElems):
|
| 359 |
+
return builder.StartVector(4, numElems, 4)
|
| 360 |
+
|
| 361 |
+
def ConversionOptionsAddAllowCustomOps(builder, allowCustomOps):
|
| 362 |
+
builder.PrependBoolSlot(1, allowCustomOps, 0)
|
| 363 |
+
|
| 364 |
+
def ConversionOptionsAddEnableSelectTfOps(builder, enableSelectTfOps):
|
| 365 |
+
builder.PrependBoolSlot(2, enableSelectTfOps, 0)
|
| 366 |
+
|
| 367 |
+
def ConversionOptionsAddForceSelectTfOps(builder, forceSelectTfOps):
|
| 368 |
+
builder.PrependBoolSlot(3, forceSelectTfOps, 0)
|
| 369 |
+
|
| 370 |
+
def ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes):
|
| 371 |
+
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(sparsityBlockSizes), 0)
|
| 372 |
+
|
| 373 |
+
def ConversionOptionsStartSparsityBlockSizesVector(builder, numElems):
|
| 374 |
+
return builder.StartVector(4, numElems, 4)
|
| 375 |
+
|
| 376 |
+
def ConversionOptionsEnd(builder):
|
| 377 |
+
return builder.EndObject()
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
try:
|
| 381 |
+
from typing import List
|
| 382 |
+
except:
|
| 383 |
+
pass
|
| 384 |
+
|
| 385 |
+
class ConversionOptionsT(object):
|
| 386 |
+
|
| 387 |
+
# ConversionOptionsT
|
| 388 |
+
def __init__(self):
|
| 389 |
+
self.modelOptimizationModes = None # type: List[int]
|
| 390 |
+
self.allowCustomOps = False # type: bool
|
| 391 |
+
self.enableSelectTfOps = False # type: bool
|
| 392 |
+
self.forceSelectTfOps = False # type: bool
|
| 393 |
+
self.sparsityBlockSizes = None # type: List[SparsityBlockSizeT]
|
| 394 |
+
|
| 395 |
+
@classmethod
|
| 396 |
+
def InitFromBuf(cls, buf, pos):
|
| 397 |
+
conversionOptions = ConversionOptions()
|
| 398 |
+
conversionOptions.Init(buf, pos)
|
| 399 |
+
return cls.InitFromObj(conversionOptions)
|
| 400 |
+
|
| 401 |
+
@classmethod
|
| 402 |
+
def InitFromPackedBuf(cls, buf, pos=0):
|
| 403 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
|
| 404 |
+
return cls.InitFromBuf(buf, pos+n)
|
| 405 |
+
|
| 406 |
+
@classmethod
|
| 407 |
+
def InitFromObj(cls, conversionOptions):
|
| 408 |
+
x = ConversionOptionsT()
|
| 409 |
+
x._UnPack(conversionOptions)
|
| 410 |
+
return x
|
| 411 |
+
|
| 412 |
+
# ConversionOptionsT
|
| 413 |
+
def _UnPack(self, conversionOptions):
|
| 414 |
+
if conversionOptions is None:
|
| 415 |
+
return
|
| 416 |
+
if not conversionOptions.ModelOptimizationModesIsNone():
|
| 417 |
+
if np is None:
|
| 418 |
+
self.modelOptimizationModes = []
|
| 419 |
+
for i in range(conversionOptions.ModelOptimizationModesLength()):
|
| 420 |
+
self.modelOptimizationModes.append(conversionOptions.ModelOptimizationModes(i))
|
| 421 |
+
else:
|
| 422 |
+
self.modelOptimizationModes = conversionOptions.ModelOptimizationModesAsNumpy()
|
| 423 |
+
self.allowCustomOps = conversionOptions.AllowCustomOps()
|
| 424 |
+
self.enableSelectTfOps = conversionOptions.EnableSelectTfOps()
|
| 425 |
+
self.forceSelectTfOps = conversionOptions.ForceSelectTfOps()
|
| 426 |
+
if not conversionOptions.SparsityBlockSizesIsNone():
|
| 427 |
+
self.sparsityBlockSizes = []
|
| 428 |
+
for i in range(conversionOptions.SparsityBlockSizesLength()):
|
| 429 |
+
if conversionOptions.SparsityBlockSizes(i) is None:
|
| 430 |
+
self.sparsityBlockSizes.append(None)
|
| 431 |
+
else:
|
| 432 |
+
sparsityBlockSize_ = SparsityBlockSizeT.InitFromObj(conversionOptions.SparsityBlockSizes(i))
|
| 433 |
+
self.sparsityBlockSizes.append(sparsityBlockSize_)
|
| 434 |
+
|
| 435 |
+
# ConversionOptionsT
|
| 436 |
+
def Pack(self, builder):
|
| 437 |
+
if self.modelOptimizationModes is not None:
|
| 438 |
+
if np is not None and type(self.modelOptimizationModes) is np.ndarray:
|
| 439 |
+
modelOptimizationModes = builder.CreateNumpyVector(self.modelOptimizationModes)
|
| 440 |
+
else:
|
| 441 |
+
ConversionOptionsStartModelOptimizationModesVector(builder, len(self.modelOptimizationModes))
|
| 442 |
+
for i in reversed(range(len(self.modelOptimizationModes))):
|
| 443 |
+
builder.PrependInt32(self.modelOptimizationModes[i])
|
| 444 |
+
modelOptimizationModes = builder.EndVector()
|
| 445 |
+
if self.sparsityBlockSizes is not None:
|
| 446 |
+
sparsityBlockSizeslist = []
|
| 447 |
+
for i in range(len(self.sparsityBlockSizes)):
|
| 448 |
+
sparsityBlockSizeslist.append(self.sparsityBlockSizes[i].Pack(builder))
|
| 449 |
+
ConversionOptionsStartSparsityBlockSizesVector(builder, len(self.sparsityBlockSizes))
|
| 450 |
+
for i in reversed(range(len(self.sparsityBlockSizes))):
|
| 451 |
+
builder.PrependUOffsetTRelative(sparsityBlockSizeslist[i])
|
| 452 |
+
sparsityBlockSizes = builder.EndVector()
|
| 453 |
+
ConversionOptionsStart(builder)
|
| 454 |
+
if self.modelOptimizationModes is not None:
|
| 455 |
+
ConversionOptionsAddModelOptimizationModes(builder, modelOptimizationModes)
|
| 456 |
+
ConversionOptionsAddAllowCustomOps(builder, self.allowCustomOps)
|
| 457 |
+
ConversionOptionsAddEnableSelectTfOps(builder, self.enableSelectTfOps)
|
| 458 |
+
ConversionOptionsAddForceSelectTfOps(builder, self.forceSelectTfOps)
|
| 459 |
+
if self.sparsityBlockSizes is not None:
|
| 460 |
+
ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes)
|
| 461 |
+
conversionOptions = ConversionOptionsEnd(builder)
|
| 462 |
+
return conversionOptions
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
class ConversionMetadata(object):
|
| 466 |
+
__slots__ = ['_tab']
|
| 467 |
+
|
| 468 |
+
@classmethod
|
| 469 |
+
def GetRootAs(cls, buf, offset=0):
|
| 470 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
|
| 471 |
+
x = ConversionMetadata()
|
| 472 |
+
x.Init(buf, n + offset)
|
| 473 |
+
return x
|
| 474 |
+
|
| 475 |
+
@classmethod
|
| 476 |
+
def GetRootAsConversionMetadata(cls, buf, offset=0):
|
| 477 |
+
"""This method is deprecated. Please switch to GetRootAs."""
|
| 478 |
+
return cls.GetRootAs(buf, offset)
|
| 479 |
+
# ConversionMetadata
|
| 480 |
+
def Init(self, buf, pos):
|
| 481 |
+
self._tab = flatbuffers.table.Table(buf, pos)
|
| 482 |
+
|
| 483 |
+
# ConversionMetadata
|
| 484 |
+
def Environment(self):
|
| 485 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
|
| 486 |
+
if o != 0:
|
| 487 |
+
x = self._tab.Indirect(o + self._tab.Pos)
|
| 488 |
+
obj = Environment()
|
| 489 |
+
obj.Init(self._tab.Bytes, x)
|
| 490 |
+
return obj
|
| 491 |
+
return None
|
| 492 |
+
|
| 493 |
+
# ConversionMetadata
|
| 494 |
+
def Options(self):
|
| 495 |
+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
|
| 496 |
+
if o != 0:
|
| 497 |
+
x = self._tab.Indirect(o + self._tab.Pos)
|
| 498 |
+
obj = ConversionOptions()
|
| 499 |
+
obj.Init(self._tab.Bytes, x)
|
| 500 |
+
return obj
|
| 501 |
+
return None
|
| 502 |
+
|
| 503 |
+
def ConversionMetadataStart(builder):
|
| 504 |
+
builder.StartObject(2)
|
| 505 |
+
|
| 506 |
+
def ConversionMetadataAddEnvironment(builder, environment):
|
| 507 |
+
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(environment), 0)
|
| 508 |
+
|
| 509 |
+
def ConversionMetadataAddOptions(builder, options):
|
| 510 |
+
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(options), 0)
|
| 511 |
+
|
| 512 |
+
def ConversionMetadataEnd(builder):
|
| 513 |
+
return builder.EndObject()
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
try:
|
| 517 |
+
from typing import Optional
|
| 518 |
+
except:
|
| 519 |
+
pass
|
| 520 |
+
|
| 521 |
+
class ConversionMetadataT(object):
|
| 522 |
+
|
| 523 |
+
# ConversionMetadataT
|
| 524 |
+
def __init__(self):
|
| 525 |
+
self.environment = None # type: Optional[EnvironmentT]
|
| 526 |
+
self.options = None # type: Optional[ConversionOptionsT]
|
| 527 |
+
|
| 528 |
+
@classmethod
|
| 529 |
+
def InitFromBuf(cls, buf, pos):
|
| 530 |
+
conversionMetadata = ConversionMetadata()
|
| 531 |
+
conversionMetadata.Init(buf, pos)
|
| 532 |
+
return cls.InitFromObj(conversionMetadata)
|
| 533 |
+
|
| 534 |
+
@classmethod
|
| 535 |
+
def InitFromPackedBuf(cls, buf, pos=0):
|
| 536 |
+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
|
| 537 |
+
return cls.InitFromBuf(buf, pos+n)
|
| 538 |
+
|
| 539 |
+
@classmethod
|
| 540 |
+
def InitFromObj(cls, conversionMetadata):
|
| 541 |
+
x = ConversionMetadataT()
|
| 542 |
+
x._UnPack(conversionMetadata)
|
| 543 |
+
return x
|
| 544 |
+
|
| 545 |
+
# ConversionMetadataT
|
| 546 |
+
def _UnPack(self, conversionMetadata):
|
| 547 |
+
if conversionMetadata is None:
|
| 548 |
+
return
|
| 549 |
+
if conversionMetadata.Environment() is not None:
|
| 550 |
+
self.environment = EnvironmentT.InitFromObj(conversionMetadata.Environment())
|
| 551 |
+
if conversionMetadata.Options() is not None:
|
| 552 |
+
self.options = ConversionOptionsT.InitFromObj(conversionMetadata.Options())
|
| 553 |
+
|
| 554 |
+
# ConversionMetadataT
|
| 555 |
+
def Pack(self, builder):
|
| 556 |
+
if self.environment is not None:
|
| 557 |
+
environment = self.environment.Pack(builder)
|
| 558 |
+
if self.options is not None:
|
| 559 |
+
options = self.options.Pack(builder)
|
| 560 |
+
ConversionMetadataStart(builder)
|
| 561 |
+
if self.environment is not None:
|
| 562 |
+
ConversionMetadataAddEnvironment(builder, environment)
|
| 563 |
+
if self.options is not None:
|
| 564 |
+
ConversionMetadataAddOptions(builder, options)
|
| 565 |
+
conversionMetadata = ConversionMetadataEnd(builder)
|
| 566 |
+
return conversionMetadata
|
| 567 |
+
|
| 568 |
+
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utilities for collecting TFLite metrics."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import enum
|
| 19 |
+
import functools
|
| 20 |
+
from typing import Text
|
| 21 |
+
|
| 22 |
+
from tensorflow.compiler.mlir.lite.metrics import converter_error_data_pb2
|
| 23 |
+
from tensorflow.lite.python.metrics import metrics
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Component(enum.Enum):
|
| 27 |
+
"""Enum class defining name of the converter components."""
|
| 28 |
+
# Validate the given input and prepare and optimize TensorFlow Model.
|
| 29 |
+
PREPARE_TF_MODEL = "PREPARE_TF_MODEL"
|
| 30 |
+
|
| 31 |
+
# Convert to TFLite model format.
|
| 32 |
+
CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"
|
| 33 |
+
|
| 34 |
+
# RUN quantization and sparsification.
|
| 35 |
+
OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
SubComponentItem = collections.namedtuple("SubComponentItem",
|
| 39 |
+
["name", "component"])
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SubComponent(SubComponentItem, enum.Enum):
|
| 43 |
+
"""Enum class defining name of the converter subcomponents.
|
| 44 |
+
|
| 45 |
+
This enum only defines the subcomponents in Python, there might be more
|
| 46 |
+
subcomponents defined in C++.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __str__(self):
|
| 50 |
+
return self.value.name
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def name(self):
|
| 54 |
+
return self.value.name
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def component(self):
|
| 58 |
+
return self.value.component
|
| 59 |
+
|
| 60 |
+
# The subcomponent name is unspecified.
|
| 61 |
+
UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)
|
| 62 |
+
|
| 63 |
+
# Valid the given input and parameters.
|
| 64 |
+
VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
|
| 65 |
+
Component.PREPARE_TF_MODEL)
|
| 66 |
+
|
| 67 |
+
# Load GraphDef from SavedModel.
|
| 68 |
+
LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
|
| 69 |
+
Component.PREPARE_TF_MODEL)
|
| 70 |
+
|
| 71 |
+
# Convert a SavedModel to frozen graph.
|
| 72 |
+
FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
|
| 73 |
+
Component.PREPARE_TF_MODEL)
|
| 74 |
+
|
| 75 |
+
# Save a Keras model to SavedModel.
|
| 76 |
+
CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
|
| 77 |
+
"CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
|
| 78 |
+
|
| 79 |
+
# Save Concrete functions to SavedModel.
|
| 80 |
+
CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
|
| 81 |
+
"CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
|
| 82 |
+
|
| 83 |
+
# Convert a Keras model to a frozen graph.
|
| 84 |
+
FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
|
| 85 |
+
Component.PREPARE_TF_MODEL)
|
| 86 |
+
|
| 87 |
+
# Replace all the variables with constants in a ConcreteFunction.
|
| 88 |
+
FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
|
| 89 |
+
Component.PREPARE_TF_MODEL)
|
| 90 |
+
|
| 91 |
+
# Run grappler optimization.
|
| 92 |
+
OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
|
| 93 |
+
Component.PREPARE_TF_MODEL)
|
| 94 |
+
|
| 95 |
+
# Convert using the old TOCO converter.
|
| 96 |
+
CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
|
| 97 |
+
"CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
|
| 98 |
+
Component.CONVERT_TF_TO_TFLITE_MODEL)
|
| 99 |
+
|
| 100 |
+
# Convert a GraphDef to TFLite model.
|
| 101 |
+
CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
|
| 102 |
+
Component.CONVERT_TF_TO_TFLITE_MODEL)
|
| 103 |
+
|
| 104 |
+
# Convert a SavedModel to TFLite model.
|
| 105 |
+
CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
|
| 106 |
+
Component.CONVERT_TF_TO_TFLITE_MODEL)
|
| 107 |
+
|
| 108 |
+
# Convert a Jax HLO to TFLite model.
|
| 109 |
+
CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
|
| 110 |
+
Component.CONVERT_TF_TO_TFLITE_MODEL)
|
| 111 |
+
|
| 112 |
+
# Do quantization by the deprecated quantizer.
|
| 113 |
+
QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
|
| 114 |
+
"QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)
|
| 115 |
+
|
| 116 |
+
# Do calibration.
|
| 117 |
+
CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)
|
| 118 |
+
|
| 119 |
+
# Do quantization by MLIR.
|
| 120 |
+
QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)
|
| 121 |
+
|
| 122 |
+
# Do sparsification by MLIR.
|
| 123 |
+
SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ConverterError(Exception):
|
| 127 |
+
"""Raised when an error occurs during model conversion."""
|
| 128 |
+
|
| 129 |
+
def __init__(self, message):
|
| 130 |
+
super(ConverterError, self).__init__(message)
|
| 131 |
+
self.errors = []
|
| 132 |
+
self._parse_error_message(message)
|
| 133 |
+
|
| 134 |
+
def append_error(self,
|
| 135 |
+
error_data: converter_error_data_pb2.ConverterErrorData):
|
| 136 |
+
self.errors.append(error_data)
|
| 137 |
+
|
| 138 |
+
def _parse_error_message(self, message):
|
| 139 |
+
"""If the message matches a pattern, assigns the associated error code.
|
| 140 |
+
|
| 141 |
+
It is difficult to assign an error code to some errrors in MLIR side, Ex:
|
| 142 |
+
errors thrown by other components than TFLite or not using mlir::emitError.
|
| 143 |
+
This function try to detect them by the error message and assign the
|
| 144 |
+
corresponding error code.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
message: The error message of this exception.
|
| 148 |
+
"""
|
| 149 |
+
error_code_mapping = {
|
| 150 |
+
"Failed to functionalize Control Flow V1 ops. Consider using Control "
|
| 151 |
+
"Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
|
| 152 |
+
"tf/compat/v1/enable_control_flow_v2.":
|
| 153 |
+
converter_error_data_pb2.ConverterErrorData
|
| 154 |
+
.ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
|
| 155 |
+
}
|
| 156 |
+
for pattern, error_code in error_code_mapping.items():
|
| 157 |
+
if pattern in message:
|
| 158 |
+
error_data = converter_error_data_pb2.ConverterErrorData()
|
| 159 |
+
error_data.error_message = message
|
| 160 |
+
error_data.error_code = error_code
|
| 161 |
+
self.append_error(error_data)
|
| 162 |
+
return
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
|
| 166 |
+
"""The decorator to identify converter component and subcomponent.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
component: Converter component name.
|
| 170 |
+
subcomponent: Converter subcomponent name.
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
Forward the result from the wrapped function.
|
| 174 |
+
|
| 175 |
+
Raises:
|
| 176 |
+
ValueError: if component and subcomponent name is not valid.
|
| 177 |
+
"""
|
| 178 |
+
if component not in Component:
|
| 179 |
+
raise ValueError("Given component name not found")
|
| 180 |
+
if subcomponent not in SubComponent:
|
| 181 |
+
raise ValueError("Given subcomponent name not found")
|
| 182 |
+
if (subcomponent != SubComponent.UNSPECIFIED and
|
| 183 |
+
subcomponent.component != component):
|
| 184 |
+
raise ValueError("component and subcomponent name don't match")
|
| 185 |
+
|
| 186 |
+
def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
|
| 187 |
+
# Always overwrites the component information, but only overwrites the
|
| 188 |
+
# subcomponent if it is not available.
|
| 189 |
+
error_data.component = component.value
|
| 190 |
+
if not error_data.subcomponent:
|
| 191 |
+
error_data.subcomponent = subcomponent.name
|
| 192 |
+
tflite_metrics = metrics.TFLiteConverterMetrics()
|
| 193 |
+
tflite_metrics.set_converter_error(error_data)
|
| 194 |
+
|
| 195 |
+
def report_error_message(error_message: Text):
|
| 196 |
+
error_data = converter_error_data_pb2.ConverterErrorData()
|
| 197 |
+
error_data.error_message = error_message
|
| 198 |
+
report_error(error_data)
|
| 199 |
+
|
| 200 |
+
def actual_decorator(func):
|
| 201 |
+
|
| 202 |
+
@functools.wraps(func)
|
| 203 |
+
def wrapper(*args, **kwargs):
|
| 204 |
+
try:
|
| 205 |
+
return func(*args, **kwargs)
|
| 206 |
+
except ConverterError as converter_error:
|
| 207 |
+
if converter_error.errors:
|
| 208 |
+
for error_data in converter_error.errors:
|
| 209 |
+
report_error(error_data)
|
| 210 |
+
else:
|
| 211 |
+
report_error_message(str(converter_error))
|
| 212 |
+
raise converter_error from None # Re-throws the exception.
|
| 213 |
+
except Exception as error:
|
| 214 |
+
report_error_message(str(error))
|
| 215 |
+
raise error from None # Re-throws the exception.
|
| 216 |
+
|
| 217 |
+
return wrapper
|
| 218 |
+
|
| 219 |
+
return actual_decorator
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/schema_util.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Schema utilities to get builtin code from operator code."""
|
| 16 |
+
|
| 17 |
+
from tensorflow.python.util import all_util
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_builtin_code_from_operator_code(opcode):
|
| 21 |
+
"""Return the builtin code of the given operator code.
|
| 22 |
+
|
| 23 |
+
The following method is introduced to resolve op builtin code shortage
|
| 24 |
+
problem. The new builtin operator will be assigned to the extended builtin
|
| 25 |
+
code field in the flatbuffer schema. Those methods helps to hide builtin code
|
| 26 |
+
details.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
opcode: Operator code.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
The builtin code of the given operator code.
|
| 33 |
+
"""
|
| 34 |
+
# Access BuiltinCode() method first if available.
|
| 35 |
+
if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode):
|
| 36 |
+
return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode())
|
| 37 |
+
|
| 38 |
+
return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
_allowed_symbols = [
|
| 42 |
+
'get_builtin_code_from_operator_code',
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
all_util.remove_undocumented(__name__, _allowed_symbols)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/tflite_convert.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Python command line interface for converting TF models to TFLite models."""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import sys
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
from absl import app
|
| 23 |
+
import tensorflow as tf
|
| 24 |
+
|
| 25 |
+
from tensorflow.lite.python import lite
|
| 26 |
+
from tensorflow.lite.python.convert import register_custom_opdefs
|
| 27 |
+
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
|
| 28 |
+
from tensorflow.lite.toco.logging import gen_html
|
| 29 |
+
from tensorflow.python import tf2
|
| 30 |
+
from tensorflow.python.framework import dtypes
|
| 31 |
+
from tensorflow.python.platform import gfile
|
| 32 |
+
from tensorflow.python.util import keras_deps
|
| 33 |
+
|
| 34 |
+
# Needed to enable TF2 by default.
|
| 35 |
+
|
| 36 |
+
_ = tf.keras.models.save_model # ensure necessary imports are executed
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _parse_array(values, type_fn=str):
|
| 40 |
+
if values is not None:
|
| 41 |
+
return [type_fn(val) for val in values.split(",") if val]
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _parse_set(values):
|
| 46 |
+
if values is not None:
|
| 47 |
+
return set([item for item in values.split(",") if item])
|
| 48 |
+
return None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _parse_inference_type(value, flag):
|
| 52 |
+
"""Converts the inference type to the value of the constant.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
value: str representing the inference type.
|
| 56 |
+
flag: str representing the flag name.
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
tf.dtype.
|
| 60 |
+
|
| 61 |
+
Raises:
|
| 62 |
+
ValueError: Unsupported value.
|
| 63 |
+
"""
|
| 64 |
+
if value == "FLOAT":
|
| 65 |
+
return dtypes.float32
|
| 66 |
+
if value == "INT8":
|
| 67 |
+
return dtypes.int8
|
| 68 |
+
if value == "UINT8" or value == "QUANTIZED_UINT8":
|
| 69 |
+
return dtypes.uint8
|
| 70 |
+
raise ValueError(
|
| 71 |
+
"Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
|
| 72 |
+
"QUANTIZED_UINT8 instead got {}.".format(flag, value))
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class _ParseBooleanFlag(argparse.Action):
|
| 76 |
+
"""Helper class to parse boolean flag that optionally accepts truth value."""
|
| 77 |
+
|
| 78 |
+
def __init__(self, option_strings, dest, nargs=None, **kwargs):
|
| 79 |
+
if nargs != "?":
|
| 80 |
+
# This should never happen. This class is only used once below with
|
| 81 |
+
# nargs="?".
|
| 82 |
+
raise ValueError(
|
| 83 |
+
"This parser only supports nargs='?' (0 or 1 additional arguments)")
|
| 84 |
+
super(_ParseBooleanFlag, self).__init__(
|
| 85 |
+
option_strings, dest, nargs=nargs, **kwargs)
|
| 86 |
+
|
| 87 |
+
def __call__(self, parser, namespace, values, option_string=None):
|
| 88 |
+
if values is None:
|
| 89 |
+
# Handling `--boolean_flag`.
|
| 90 |
+
# Without additional arguments, it implies true.
|
| 91 |
+
flag_value = True
|
| 92 |
+
elif values.lower() == "true":
|
| 93 |
+
# Handling `--boolean_flag=true`.
|
| 94 |
+
# (Case insensitive after the equal sign)
|
| 95 |
+
flag_value = True
|
| 96 |
+
elif values.lower() == "false":
|
| 97 |
+
# Handling `--boolean_flag=false`.
|
| 98 |
+
# (Case insensitive after the equal sign)
|
| 99 |
+
flag_value = False
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError("Invalid argument to --{}. Must use flag alone,"
|
| 102 |
+
" or specify true/false.".format(self.dest))
|
| 103 |
+
setattr(namespace, self.dest, flag_value)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _get_tflite_converter(flags):
|
| 107 |
+
"""Makes a TFLiteConverter object based on the flags provided.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
flags: argparse.Namespace object containing TFLite flags.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
TFLiteConverter object.
|
| 114 |
+
|
| 115 |
+
Raises:
|
| 116 |
+
ValueError: Invalid flags.
|
| 117 |
+
"""
|
| 118 |
+
# Parse input and output arrays.
|
| 119 |
+
input_arrays = _parse_array(flags.input_arrays)
|
| 120 |
+
input_shapes = None
|
| 121 |
+
if flags.input_shapes:
|
| 122 |
+
input_shapes_list = [
|
| 123 |
+
_parse_array(shape, type_fn=int)
|
| 124 |
+
for shape in flags.input_shapes.split(":")
|
| 125 |
+
]
|
| 126 |
+
input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
|
| 127 |
+
output_arrays = _parse_array(flags.output_arrays)
|
| 128 |
+
|
| 129 |
+
converter_kwargs = {
|
| 130 |
+
"input_arrays": input_arrays,
|
| 131 |
+
"input_shapes": input_shapes,
|
| 132 |
+
"output_arrays": output_arrays
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
# Create TFLiteConverter.
|
| 136 |
+
if flags.graph_def_file:
|
| 137 |
+
converter_fn = lite.TFLiteConverter.from_frozen_graph
|
| 138 |
+
converter_kwargs["graph_def_file"] = flags.graph_def_file
|
| 139 |
+
elif flags.saved_model_dir:
|
| 140 |
+
converter_fn = lite.TFLiteConverter.from_saved_model
|
| 141 |
+
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
|
| 142 |
+
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
|
| 143 |
+
converter_kwargs["signature_key"] = flags.saved_model_signature_key
|
| 144 |
+
elif flags.keras_model_file:
|
| 145 |
+
converter_fn = lite.TFLiteConverter.from_keras_model_file
|
| 146 |
+
converter_kwargs["model_file"] = flags.keras_model_file
|
| 147 |
+
else:
|
| 148 |
+
raise ValueError("--graph_def_file, --saved_model_dir, or "
|
| 149 |
+
"--keras_model_file must be specified.")
|
| 150 |
+
|
| 151 |
+
return converter_fn(**converter_kwargs)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _convert_tf1_model(flags):
|
| 155 |
+
"""Calls function to convert the TensorFlow 1.X model into a TFLite model.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
flags: argparse.Namespace object.
|
| 159 |
+
|
| 160 |
+
Raises:
|
| 161 |
+
ValueError: Invalid flags.
|
| 162 |
+
"""
|
| 163 |
+
# Register custom opdefs before converter object creation.
|
| 164 |
+
if flags.custom_opdefs:
|
| 165 |
+
register_custom_opdefs(_parse_array(flags.custom_opdefs))
|
| 166 |
+
|
| 167 |
+
# Create converter.
|
| 168 |
+
converter = _get_tflite_converter(flags)
|
| 169 |
+
if flags.inference_type:
|
| 170 |
+
converter.inference_type = _parse_inference_type(flags.inference_type,
|
| 171 |
+
"inference_type")
|
| 172 |
+
if flags.inference_input_type:
|
| 173 |
+
converter.inference_input_type = _parse_inference_type(
|
| 174 |
+
flags.inference_input_type, "inference_input_type")
|
| 175 |
+
if flags.output_format:
|
| 176 |
+
converter.output_format = _toco_flags_pb2.FileFormat.Value(
|
| 177 |
+
flags.output_format)
|
| 178 |
+
|
| 179 |
+
if flags.mean_values and flags.std_dev_values:
|
| 180 |
+
input_arrays = converter.get_input_arrays()
|
| 181 |
+
std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
|
| 182 |
+
|
| 183 |
+
# In quantized inference, mean_value has to be integer so that the real
|
| 184 |
+
# value 0.0 is exactly representable.
|
| 185 |
+
if converter.inference_type == dtypes.float32:
|
| 186 |
+
mean_values = _parse_array(flags.mean_values, type_fn=float)
|
| 187 |
+
else:
|
| 188 |
+
mean_values = _parse_array(flags.mean_values, type_fn=int)
|
| 189 |
+
quant_stats = list(zip(mean_values, std_dev_values))
|
| 190 |
+
if ((not flags.input_arrays and len(input_arrays) > 1) or
|
| 191 |
+
(len(input_arrays) != len(quant_stats))):
|
| 192 |
+
raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
|
| 193 |
+
"--mean_values. The flags must have the same number of "
|
| 194 |
+
"items. The current input arrays are '{0}'. "
|
| 195 |
+
"--input_arrays must be present when specifying "
|
| 196 |
+
"--std_dev_values and --mean_values with multiple input "
|
| 197 |
+
"tensors in order to map between names and "
|
| 198 |
+
"values.".format(",".join(input_arrays)))
|
| 199 |
+
converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
|
| 200 |
+
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
|
| 201 |
+
not None):
|
| 202 |
+
converter.default_ranges_stats = (flags.default_ranges_min,
|
| 203 |
+
flags.default_ranges_max)
|
| 204 |
+
|
| 205 |
+
if flags.drop_control_dependency:
|
| 206 |
+
converter.drop_control_dependency = flags.drop_control_dependency
|
| 207 |
+
if flags.reorder_across_fake_quant:
|
| 208 |
+
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
|
| 209 |
+
if flags.change_concat_input_ranges:
|
| 210 |
+
converter.change_concat_input_ranges = (
|
| 211 |
+
flags.change_concat_input_ranges == "TRUE")
|
| 212 |
+
|
| 213 |
+
if flags.allow_custom_ops:
|
| 214 |
+
converter.allow_custom_ops = flags.allow_custom_ops
|
| 215 |
+
|
| 216 |
+
if flags.target_ops:
|
| 217 |
+
ops_set_options = lite.OpsSet.get_options()
|
| 218 |
+
converter.target_spec.supported_ops = set()
|
| 219 |
+
for option in flags.target_ops.split(","):
|
| 220 |
+
if option not in ops_set_options:
|
| 221 |
+
raise ValueError("Invalid value for --target_ops. Options: "
|
| 222 |
+
"{0}".format(",".join(ops_set_options)))
|
| 223 |
+
converter.target_spec.supported_ops.add(lite.OpsSet(option))
|
| 224 |
+
|
| 225 |
+
if flags.experimental_select_user_tf_ops:
|
| 226 |
+
if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
|
| 227 |
+
raise ValueError("--experimental_select_user_tf_ops can only be set if "
|
| 228 |
+
"--target_ops contains SELECT_TF_OPS.")
|
| 229 |
+
user_op_set = set()
|
| 230 |
+
for op_name in flags.experimental_select_user_tf_ops.split(","):
|
| 231 |
+
user_op_set.add(op_name)
|
| 232 |
+
converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)
|
| 233 |
+
|
| 234 |
+
if flags.post_training_quantize:
|
| 235 |
+
converter.optimizations = [lite.Optimize.DEFAULT]
|
| 236 |
+
if converter.inference_type != dtypes.float32:
|
| 237 |
+
print("--post_training_quantize quantizes a graph of inference_type "
|
| 238 |
+
"FLOAT. Overriding inference_type to FLOAT.")
|
| 239 |
+
converter.inference_type = dtypes.float32
|
| 240 |
+
|
| 241 |
+
if flags.quantize_to_float16:
|
| 242 |
+
converter.target_spec.supported_types = [dtypes.float16]
|
| 243 |
+
if not flags.post_training_quantize:
|
| 244 |
+
print("--quantize_to_float16 will only take effect with the "
|
| 245 |
+
"--post_training_quantize flag enabled.")
|
| 246 |
+
|
| 247 |
+
if flags.dump_graphviz_dir:
|
| 248 |
+
converter.dump_graphviz_dir = flags.dump_graphviz_dir
|
| 249 |
+
if flags.dump_graphviz_video:
|
| 250 |
+
converter.dump_graphviz_vode = flags.dump_graphviz_video
|
| 251 |
+
if flags.conversion_summary_dir:
|
| 252 |
+
converter.conversion_summary_dir = flags.conversion_summary_dir
|
| 253 |
+
|
| 254 |
+
converter.experimental_new_converter = flags.experimental_new_converter
|
| 255 |
+
|
| 256 |
+
if flags.experimental_new_quantizer is not None:
|
| 257 |
+
converter.experimental_new_quantizer = flags.experimental_new_quantizer
|
| 258 |
+
|
| 259 |
+
# Convert model.
|
| 260 |
+
output_data = converter.convert()
|
| 261 |
+
with gfile.GFile(flags.output_file, "wb") as f:
|
| 262 |
+
f.write(output_data)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _convert_tf2_model(flags):
|
| 266 |
+
"""Calls function to convert the TensorFlow 2.0 model into a TFLite model.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
flags: argparse.Namespace object.
|
| 270 |
+
|
| 271 |
+
Raises:
|
| 272 |
+
ValueError: Unsupported file format.
|
| 273 |
+
"""
|
| 274 |
+
# Load the model.
|
| 275 |
+
if flags.saved_model_dir:
|
| 276 |
+
converter = lite.TFLiteConverterV2.from_saved_model(
|
| 277 |
+
flags.saved_model_dir,
|
| 278 |
+
signature_keys=_parse_array(flags.saved_model_signature_key),
|
| 279 |
+
tags=_parse_set(flags.saved_model_tag_set))
|
| 280 |
+
elif flags.keras_model_file:
|
| 281 |
+
model = keras_deps.get_load_model_function()(flags.keras_model_file)
|
| 282 |
+
converter = lite.TFLiteConverterV2.from_keras_model(model)
|
| 283 |
+
|
| 284 |
+
converter.experimental_new_converter = flags.experimental_new_converter
|
| 285 |
+
|
| 286 |
+
if flags.experimental_new_quantizer is not None:
|
| 287 |
+
converter.experimental_new_quantizer = flags.experimental_new_quantizer
|
| 288 |
+
|
| 289 |
+
# Convert the model.
|
| 290 |
+
tflite_model = converter.convert()
|
| 291 |
+
with gfile.GFile(flags.output_file, "wb") as f:
|
| 292 |
+
f.write(tflite_model)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _check_tf1_flags(flags, unparsed):
|
| 296 |
+
"""Checks the parsed and unparsed flags to ensure they are valid in 1.X.
|
| 297 |
+
|
| 298 |
+
Raises an error if previously support unparsed flags are found. Raises an
|
| 299 |
+
error for parsed flags that don't meet the required conditions.
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
flags: argparse.Namespace object containing TFLite flags.
|
| 303 |
+
unparsed: List of unparsed flags.
|
| 304 |
+
|
| 305 |
+
Raises:
|
| 306 |
+
ValueError: Invalid flags.
|
| 307 |
+
"""
|
| 308 |
+
|
| 309 |
+
# Check unparsed flags for common mistakes based on previous TOCO.
|
| 310 |
+
def _get_message_unparsed(flag, orig_flag, new_flag):
|
| 311 |
+
if flag.startswith(orig_flag):
|
| 312 |
+
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
|
| 313 |
+
return ""
|
| 314 |
+
|
| 315 |
+
if unparsed:
|
| 316 |
+
output = ""
|
| 317 |
+
for flag in unparsed:
|
| 318 |
+
output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
|
| 319 |
+
output += _get_message_unparsed(flag, "--savedmodel_directory",
|
| 320 |
+
"--saved_model_dir")
|
| 321 |
+
output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
|
| 322 |
+
output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
|
| 323 |
+
output += _get_message_unparsed(flag, "--dump_graphviz",
|
| 324 |
+
"--dump_graphviz_dir")
|
| 325 |
+
if output:
|
| 326 |
+
raise ValueError(output)
|
| 327 |
+
|
| 328 |
+
# Check that flags are valid.
|
| 329 |
+
if flags.graph_def_file and (not flags.input_arrays or
|
| 330 |
+
not flags.output_arrays):
|
| 331 |
+
raise ValueError("--input_arrays and --output_arrays are required with "
|
| 332 |
+
"--graph_def_file")
|
| 333 |
+
|
| 334 |
+
if flags.input_shapes:
|
| 335 |
+
if not flags.input_arrays:
|
| 336 |
+
raise ValueError("--input_shapes must be used with --input_arrays")
|
| 337 |
+
if flags.input_shapes.count(":") != flags.input_arrays.count(","):
|
| 338 |
+
raise ValueError("--input_shapes and --input_arrays must have the same "
|
| 339 |
+
"number of items")
|
| 340 |
+
|
| 341 |
+
if flags.std_dev_values or flags.mean_values:
|
| 342 |
+
if bool(flags.std_dev_values) != bool(flags.mean_values):
|
| 343 |
+
raise ValueError("--std_dev_values and --mean_values must be used "
|
| 344 |
+
"together")
|
| 345 |
+
if flags.std_dev_values.count(",") != flags.mean_values.count(","):
|
| 346 |
+
raise ValueError("--std_dev_values, --mean_values must have the same "
|
| 347 |
+
"number of items")
|
| 348 |
+
|
| 349 |
+
if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
|
| 350 |
+
raise ValueError("--default_ranges_min and --default_ranges_max must be "
|
| 351 |
+
"used together")
|
| 352 |
+
|
| 353 |
+
if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
|
| 354 |
+
raise ValueError("--dump_graphviz_video must be used with "
|
| 355 |
+
"--dump_graphviz_dir")
|
| 356 |
+
|
| 357 |
+
if flags.custom_opdefs and not flags.experimental_new_converter:
|
| 358 |
+
raise ValueError("--custom_opdefs must be used with "
|
| 359 |
+
"--experimental_new_converter")
|
| 360 |
+
if flags.custom_opdefs and not flags.allow_custom_ops:
|
| 361 |
+
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
|
| 362 |
+
if (flags.experimental_select_user_tf_ops and
|
| 363 |
+
not flags.experimental_new_converter):
|
| 364 |
+
raise ValueError("--experimental_select_user_tf_ops must be used with "
|
| 365 |
+
"--experimental_new_converter")
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def _check_tf2_flags(flags):
|
| 369 |
+
"""Checks the parsed and unparsed flags to ensure they are valid in 2.X.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
flags: argparse.Namespace object containing TFLite flags.
|
| 373 |
+
|
| 374 |
+
Raises:
|
| 375 |
+
ValueError: Invalid flags.
|
| 376 |
+
"""
|
| 377 |
+
if not flags.keras_model_file and not flags.saved_model_dir:
|
| 378 |
+
raise ValueError("one of the arguments --saved_model_dir "
|
| 379 |
+
"--keras_model_file is required")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def _get_tf1_flags(parser):
|
| 383 |
+
"""Returns ArgumentParser for tflite_convert for TensorFlow 1.X.
|
| 384 |
+
|
| 385 |
+
Args:
|
| 386 |
+
parser: ArgumentParser
|
| 387 |
+
"""
|
| 388 |
+
# Input file flags.
|
| 389 |
+
input_file_group = parser.add_mutually_exclusive_group(required=True)
|
| 390 |
+
input_file_group.add_argument(
|
| 391 |
+
"--graph_def_file",
|
| 392 |
+
type=str,
|
| 393 |
+
help="Full filepath of file containing frozen TensorFlow GraphDef.")
|
| 394 |
+
input_file_group.add_argument(
|
| 395 |
+
"--saved_model_dir",
|
| 396 |
+
type=str,
|
| 397 |
+
help="Full filepath of directory containing the SavedModel.")
|
| 398 |
+
input_file_group.add_argument(
|
| 399 |
+
"--keras_model_file",
|
| 400 |
+
type=str,
|
| 401 |
+
help="Full filepath of HDF5 file containing tf.Keras model.")
|
| 402 |
+
|
| 403 |
+
# Model format flags.
|
| 404 |
+
parser.add_argument(
|
| 405 |
+
"--output_format",
|
| 406 |
+
type=str.upper,
|
| 407 |
+
choices=["TFLITE", "GRAPHVIZ_DOT"],
|
| 408 |
+
help="Output file format.")
|
| 409 |
+
parser.add_argument(
|
| 410 |
+
"--inference_type",
|
| 411 |
+
type=str.upper,
|
| 412 |
+
default="FLOAT",
|
| 413 |
+
help=("Target data type of real-number arrays in the output file. "
|
| 414 |
+
"Must be either FLOAT, INT8 or UINT8."))
|
| 415 |
+
parser.add_argument(
|
| 416 |
+
"--inference_input_type",
|
| 417 |
+
type=str.upper,
|
| 418 |
+
help=("Target data type of real-number input arrays. Allows for a "
|
| 419 |
+
"different type for input arrays in the case of quantization. "
|
| 420 |
+
"Must be either FLOAT, INT8 or UINT8."))
|
| 421 |
+
|
| 422 |
+
# Input and output arrays flags.
|
| 423 |
+
parser.add_argument(
|
| 424 |
+
"--input_arrays",
|
| 425 |
+
type=str,
|
| 426 |
+
help="Names of the input arrays, comma-separated.")
|
| 427 |
+
parser.add_argument(
|
| 428 |
+
"--input_shapes",
|
| 429 |
+
type=str,
|
| 430 |
+
help="Shapes corresponding to --input_arrays, colon-separated.")
|
| 431 |
+
parser.add_argument(
|
| 432 |
+
"--output_arrays",
|
| 433 |
+
type=str,
|
| 434 |
+
help="Names of the output arrays, comma-separated.")
|
| 435 |
+
|
| 436 |
+
# SavedModel related flags.
|
| 437 |
+
parser.add_argument(
|
| 438 |
+
"--saved_model_tag_set",
|
| 439 |
+
type=str,
|
| 440 |
+
help=("Comma-separated set of tags identifying the MetaGraphDef within "
|
| 441 |
+
"the SavedModel to analyze. All tags must be present. In order to "
|
| 442 |
+
"pass in an empty tag set, pass in \"\". (default \"serve\")"))
|
| 443 |
+
parser.add_argument(
|
| 444 |
+
"--saved_model_signature_key",
|
| 445 |
+
type=str,
|
| 446 |
+
help=("Key identifying the SignatureDef containing inputs and outputs. "
|
| 447 |
+
"(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
|
| 448 |
+
|
| 449 |
+
# Quantization flags.
|
| 450 |
+
parser.add_argument(
|
| 451 |
+
"--std_dev_values",
|
| 452 |
+
type=str,
|
| 453 |
+
help=("Standard deviation of training data for each input tensor, "
|
| 454 |
+
"comma-separated floats. Used for quantized input tensors. "
|
| 455 |
+
"(default None)"))
|
| 456 |
+
parser.add_argument(
|
| 457 |
+
"--mean_values",
|
| 458 |
+
type=str,
|
| 459 |
+
help=("Mean of training data for each input tensor, comma-separated "
|
| 460 |
+
"floats. Used for quantized input tensors. (default None)"))
|
| 461 |
+
parser.add_argument(
|
| 462 |
+
"--default_ranges_min",
|
| 463 |
+
type=float,
|
| 464 |
+
help=("Default value for min bound of min/max range values used for all "
|
| 465 |
+
"arrays without a specified range, Intended for experimenting with "
|
| 466 |
+
"quantization via \"dummy quantization\". (default None)"))
|
| 467 |
+
parser.add_argument(
|
| 468 |
+
"--default_ranges_max",
|
| 469 |
+
type=float,
|
| 470 |
+
help=("Default value for max bound of min/max range values used for all "
|
| 471 |
+
"arrays without a specified range, Intended for experimenting with "
|
| 472 |
+
"quantization via \"dummy quantization\". (default None)"))
|
| 473 |
+
# quantize_weights is DEPRECATED.
|
| 474 |
+
parser.add_argument(
|
| 475 |
+
"--quantize_weights",
|
| 476 |
+
dest="post_training_quantize",
|
| 477 |
+
action="store_true",
|
| 478 |
+
help=argparse.SUPPRESS)
|
| 479 |
+
parser.add_argument(
|
| 480 |
+
"--post_training_quantize",
|
| 481 |
+
dest="post_training_quantize",
|
| 482 |
+
action="store_true",
|
| 483 |
+
help=(
|
| 484 |
+
"Boolean indicating whether to quantize the weights of the "
|
| 485 |
+
"converted float model. Model size will be reduced and there will "
|
| 486 |
+
"be latency improvements (at the cost of accuracy). (default False)"))
|
| 487 |
+
parser.add_argument(
|
| 488 |
+
"--quantize_to_float16",
|
| 489 |
+
dest="quantize_to_float16",
|
| 490 |
+
action="store_true",
|
| 491 |
+
help=("Boolean indicating whether to quantize weights to fp16 instead of "
|
| 492 |
+
"the default int8 when post-training quantization "
|
| 493 |
+
"(--post_training_quantize) is enabled. (default False)"))
|
| 494 |
+
# Graph manipulation flags.
|
| 495 |
+
parser.add_argument(
|
| 496 |
+
"--drop_control_dependency",
|
| 497 |
+
action="store_true",
|
| 498 |
+
help=("Boolean indicating whether to drop control dependencies silently. "
|
| 499 |
+
"This is due to TensorFlow not supporting control dependencies. "
|
| 500 |
+
"(default True)"))
|
| 501 |
+
parser.add_argument(
|
| 502 |
+
"--reorder_across_fake_quant",
|
| 503 |
+
action="store_true",
|
| 504 |
+
help=("Boolean indicating whether to reorder FakeQuant nodes in "
|
| 505 |
+
"unexpected locations. Used when the location of the FakeQuant "
|
| 506 |
+
"nodes is preventing graph transformations necessary to convert "
|
| 507 |
+
"the graph. Results in a graph that differs from the quantized "
|
| 508 |
+
"training graph, potentially causing differing arithmetic "
|
| 509 |
+
"behavior. (default False)"))
|
| 510 |
+
# Usage for this flag is --change_concat_input_ranges=true or
|
| 511 |
+
# --change_concat_input_ranges=false in order to make it clear what the flag
|
| 512 |
+
# is set to. This keeps the usage consistent with other usages of the flag
|
| 513 |
+
# where the default is different. The default value here is False.
|
| 514 |
+
parser.add_argument(
|
| 515 |
+
"--change_concat_input_ranges",
|
| 516 |
+
type=str.upper,
|
| 517 |
+
choices=["TRUE", "FALSE"],
|
| 518 |
+
help=("Boolean to change behavior of min/max ranges for inputs and "
|
| 519 |
+
"outputs of the concat operator for quantized models. Changes the "
|
| 520 |
+
"ranges of concat operator overlap when true. (default False)"))
|
| 521 |
+
|
| 522 |
+
# Permitted ops flags.
|
| 523 |
+
parser.add_argument(
|
| 524 |
+
"--allow_custom_ops",
|
| 525 |
+
action=_ParseBooleanFlag,
|
| 526 |
+
nargs="?",
|
| 527 |
+
help=("Boolean indicating whether to allow custom operations. When false "
|
| 528 |
+
"any unknown operation is an error. When true, custom ops are "
|
| 529 |
+
"created for any op that is unknown. The developer will need to "
|
| 530 |
+
"provide these to the TensorFlow Lite runtime with a custom "
|
| 531 |
+
"resolver. (default False)"))
|
| 532 |
+
parser.add_argument(
|
| 533 |
+
"--custom_opdefs",
|
| 534 |
+
type=str,
|
| 535 |
+
help=("String representing a list of custom ops OpDefs delineated with "
|
| 536 |
+
"commas that are included in the GraphDef. Required when using "
|
| 537 |
+
"custom operations with --experimental_new_converter."))
|
| 538 |
+
parser.add_argument(
|
| 539 |
+
"--target_ops",
|
| 540 |
+
type=str,
|
| 541 |
+
help=("Experimental flag, subject to change. Set of OpsSet options "
|
| 542 |
+
"indicating which converter to use. Options: {0}. One or more "
|
| 543 |
+
"option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
|
| 544 |
+
"".format(",".join(lite.OpsSet.get_options()))))
|
| 545 |
+
parser.add_argument(
|
| 546 |
+
"--experimental_select_user_tf_ops",
|
| 547 |
+
type=str,
|
| 548 |
+
help=("Experimental flag, subject to change. Comma separated list of "
|
| 549 |
+
"user's defined TensorFlow operators required in the runtime."))
|
| 550 |
+
|
| 551 |
+
# Logging flags.
|
| 552 |
+
parser.add_argument(
|
| 553 |
+
"--dump_graphviz_dir",
|
| 554 |
+
type=str,
|
| 555 |
+
help=("Full filepath of folder to dump the graphs at various stages of "
|
| 556 |
+
"processing GraphViz .dot files. Preferred over --output_format="
|
| 557 |
+
"GRAPHVIZ_DOT in order to keep the requirements of the output "
|
| 558 |
+
"file."))
|
| 559 |
+
parser.add_argument(
|
| 560 |
+
"--dump_graphviz_video",
|
| 561 |
+
action="store_true",
|
| 562 |
+
help=("Boolean indicating whether to dump the graph after every graph "
|
| 563 |
+
"transformation"))
|
| 564 |
+
parser.add_argument(
|
| 565 |
+
"--conversion_summary_dir",
|
| 566 |
+
type=str,
|
| 567 |
+
help=("Full filepath to store the conversion logs, which includes "
|
| 568 |
+
"graphviz of the model before/after the conversion, an HTML report "
|
| 569 |
+
"and the conversion proto buffers. This will only be generated "
|
| 570 |
+
"when passing --experimental_new_converter"))
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
def _get_tf2_flags(parser):
|
| 574 |
+
"""Returns ArgumentParser for tflite_convert for TensorFlow 2.0.
|
| 575 |
+
|
| 576 |
+
Args:
|
| 577 |
+
parser: ArgumentParser
|
| 578 |
+
"""
|
| 579 |
+
# Input file flags.
|
| 580 |
+
input_file_group = parser.add_mutually_exclusive_group()
|
| 581 |
+
input_file_group.add_argument(
|
| 582 |
+
"--saved_model_dir",
|
| 583 |
+
type=str,
|
| 584 |
+
help="Full path of the directory containing the SavedModel.")
|
| 585 |
+
input_file_group.add_argument(
|
| 586 |
+
"--keras_model_file",
|
| 587 |
+
type=str,
|
| 588 |
+
help="Full filepath of HDF5 file containing tf.Keras model.")
|
| 589 |
+
# SavedModel related flags.
|
| 590 |
+
parser.add_argument(
|
| 591 |
+
"--saved_model_tag_set",
|
| 592 |
+
type=str,
|
| 593 |
+
help=("Comma-separated set of tags identifying the MetaGraphDef within "
|
| 594 |
+
"the SavedModel to analyze. All tags must be present. In order to "
|
| 595 |
+
"pass in an empty tag set, pass in \"\". (default \"serve\")"))
|
| 596 |
+
parser.add_argument(
|
| 597 |
+
"--saved_model_signature_key",
|
| 598 |
+
type=str,
|
| 599 |
+
help=("Key identifying the SignatureDef containing inputs and outputs. "
|
| 600 |
+
"(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
|
| 601 |
+
|
| 602 |
+
# Enables 1.X converter in 2.X.
|
| 603 |
+
parser.add_argument(
|
| 604 |
+
"--enable_v1_converter",
|
| 605 |
+
action="store_true",
|
| 606 |
+
help=("Enables the TensorFlow V1 converter in 2.0"))
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def _get_parser(use_v2_converter):
|
| 610 |
+
"""Returns an ArgumentParser for tflite_convert.
|
| 611 |
+
|
| 612 |
+
Args:
|
| 613 |
+
use_v2_converter: Indicates which converter to return.
|
| 614 |
+
Return: ArgumentParser.
|
| 615 |
+
"""
|
| 616 |
+
parser = argparse.ArgumentParser(
|
| 617 |
+
description=("Command line tool to run TensorFlow Lite Converter."))
|
| 618 |
+
|
| 619 |
+
# Output file flag.
|
| 620 |
+
parser.add_argument(
|
| 621 |
+
"--output_file",
|
| 622 |
+
type=str,
|
| 623 |
+
help="Full filepath of the output file.",
|
| 624 |
+
required=True)
|
| 625 |
+
|
| 626 |
+
if use_v2_converter:
|
| 627 |
+
_get_tf2_flags(parser)
|
| 628 |
+
else:
|
| 629 |
+
_get_tf1_flags(parser)
|
| 630 |
+
|
| 631 |
+
parser.add_argument(
|
| 632 |
+
"--experimental_new_converter",
|
| 633 |
+
action=_ParseBooleanFlag,
|
| 634 |
+
nargs="?",
|
| 635 |
+
default=True,
|
| 636 |
+
help=("Experimental flag, subject to change. Enables MLIR-based "
|
| 637 |
+
"conversion instead of TOCO conversion. (default True)"))
|
| 638 |
+
|
| 639 |
+
parser.add_argument(
|
| 640 |
+
"--experimental_new_quantizer",
|
| 641 |
+
action=_ParseBooleanFlag,
|
| 642 |
+
nargs="?",
|
| 643 |
+
help=("Experimental flag, subject to change. Enables MLIR-based "
|
| 644 |
+
"quantizer instead of flatbuffer conversion. (default True)"))
|
| 645 |
+
return parser
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
def run_main(_):
|
| 649 |
+
"""Main in tflite_convert.py."""
|
| 650 |
+
use_v2_converter = tf2.enabled()
|
| 651 |
+
parser = _get_parser(use_v2_converter)
|
| 652 |
+
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
|
| 653 |
+
|
| 654 |
+
# If the user is running TensorFlow 2.X but has passed in enable_v1_converter
|
| 655 |
+
# then parse the flags again with the 1.X converter flags.
|
| 656 |
+
if tf2.enabled() and tflite_flags.enable_v1_converter:
|
| 657 |
+
use_v2_converter = False
|
| 658 |
+
parser = _get_parser(use_v2_converter)
|
| 659 |
+
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
|
| 660 |
+
|
| 661 |
+
# Checks if the flags are valid.
|
| 662 |
+
try:
|
| 663 |
+
if use_v2_converter:
|
| 664 |
+
_check_tf2_flags(tflite_flags)
|
| 665 |
+
else:
|
| 666 |
+
_check_tf1_flags(tflite_flags, unparsed)
|
| 667 |
+
except ValueError as e:
|
| 668 |
+
parser.print_usage()
|
| 669 |
+
file_name = os.path.basename(sys.argv[0])
|
| 670 |
+
sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
|
| 671 |
+
sys.exit(1)
|
| 672 |
+
|
| 673 |
+
# Convert the model according to the user provided flag.
|
| 674 |
+
if use_v2_converter:
|
| 675 |
+
_convert_tf2_model(tflite_flags)
|
| 676 |
+
else:
|
| 677 |
+
try:
|
| 678 |
+
_convert_tf1_model(tflite_flags)
|
| 679 |
+
finally:
|
| 680 |
+
if tflite_flags.conversion_summary_dir:
|
| 681 |
+
if tflite_flags.experimental_new_converter:
|
| 682 |
+
gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir,
|
| 683 |
+
tflite_flags.post_training_quantize,
|
| 684 |
+
tflite_flags.output_file)
|
| 685 |
+
else:
|
| 686 |
+
warnings.warn(
|
| 687 |
+
"Conversion summary will only be generated when enabling"
|
| 688 |
+
" the new converter via --experimental_new_converter. ")
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def main():
|
| 692 |
+
app.run(main=run_main, argv=sys.argv[:1])
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
if __name__ == "__main__":
|
| 696 |
+
main()
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/util.py
ADDED
|
@@ -0,0 +1,1177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Functions used by multiple converter files."""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import datetime
|
| 19 |
+
import sys
|
| 20 |
+
|
| 21 |
+
from absl import logging
|
| 22 |
+
import flatbuffers
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
from tensorflow.core.protobuf import config_pb2 as _config_pb2
|
| 26 |
+
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
|
| 27 |
+
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
|
| 28 |
+
from tensorflow.lite.python import schema_py_generated as schema_fb
|
| 29 |
+
from tensorflow.lite.python import schema_util
|
| 30 |
+
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
|
| 31 |
+
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
|
| 32 |
+
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
|
| 33 |
+
from tensorflow.lite.tools import flatbuffer_utils
|
| 34 |
+
from tensorflow.python.eager import function
|
| 35 |
+
from tensorflow.python.framework import convert_to_constants as _convert_to_constants
|
| 36 |
+
from tensorflow.python.framework import dtypes
|
| 37 |
+
from tensorflow.python.framework import error_interpolation as _error_interpolation
|
| 38 |
+
from tensorflow.python.grappler import tf_optimizer
|
| 39 |
+
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
|
| 40 |
+
|
| 41 |
+
# The field name of conversion metadata in the flatbuffer file.
|
| 42 |
+
CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"
|
| 43 |
+
|
| 44 |
+
# Keras functions used by TFLite
|
| 45 |
+
model_input_signature = _tflite_keras_util.model_input_signature
|
| 46 |
+
trace_model_call = _tflite_keras_util.trace_model_call
|
| 47 |
+
get_save_spec = _tflite_keras_util.get_save_spec
|
| 48 |
+
|
| 49 |
+
# Jax functions used by TFLite
|
| 50 |
+
# pylint: disable=g-import-not-at-top
|
| 51 |
+
# pylint: disable=unused-import
|
| 52 |
+
try:
|
| 53 |
+
from jax import jit as _jit
|
| 54 |
+
except ImportError:
|
| 55 |
+
_jit = None
|
| 56 |
+
# pylint: enable=g-import-not-at-top
|
| 57 |
+
# pylint: enable=unused-import
|
| 58 |
+
|
| 59 |
+
# Defined as per TFLite schema
|
| 60 |
+
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
|
| 61 |
+
0: dtypes.float32,
|
| 62 |
+
1: dtypes.float16,
|
| 63 |
+
2: dtypes.int32,
|
| 64 |
+
3: dtypes.uint8,
|
| 65 |
+
4: dtypes.int64,
|
| 66 |
+
5: dtypes.string,
|
| 67 |
+
6: dtypes.bool,
|
| 68 |
+
7: dtypes.int16,
|
| 69 |
+
8: dtypes.complex64,
|
| 70 |
+
9: dtypes.int8,
|
| 71 |
+
10: dtypes.float64,
|
| 72 |
+
11: dtypes.complex128,
|
| 73 |
+
16: dtypes.uint32,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
_TFLITE_FILE_IDENTIFIER = b"TFL3"
|
| 77 |
+
|
| 78 |
+
_MAP_QUANT_TO_IO_TYPES = {
|
| 79 |
+
dtypes.int8: {dtypes.int8, dtypes.uint8},
|
| 80 |
+
dtypes.int16: {dtypes.int16},
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
|
| 85 |
+
"""Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)
|
| 89 |
+
|
| 90 |
+
Raises:
|
| 91 |
+
ValueError: If an invalid tflite enum type is provided.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
tf type (eg: tf.float32)
|
| 95 |
+
"""
|
| 96 |
+
tf_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
|
| 97 |
+
if tf_type is None:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
"Unsupported enum {}. The valid map of enum to tf types is : {}"
|
| 100 |
+
.format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
|
| 101 |
+
return tf_type
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_tf_type_name(tf_type):
|
| 105 |
+
"""Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
|
| 106 |
+
return "tf." + tf_type.name if tf_type else None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def get_tensor_name(tensor):
|
| 110 |
+
"""Returns name of the input tensor.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
tensor: tf.Tensor
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
str
|
| 117 |
+
"""
|
| 118 |
+
parts = tensor.name.split(":")
|
| 119 |
+
if len(parts) > 2:
|
| 120 |
+
raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
|
| 121 |
+
len(parts) - 1))
|
| 122 |
+
|
| 123 |
+
# To be consistent with the tensor naming scheme in tensorflow, we need
|
| 124 |
+
# drop the ':0' suffix for the first tensor.
|
| 125 |
+
if len(parts) > 1 and parts[1] != "0":
|
| 126 |
+
return tensor.name
|
| 127 |
+
return parts[0]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def get_tensors_from_tensor_names(graph, tensor_names):
|
| 131 |
+
"""Gets the Tensors associated with the `tensor_names` in the provided graph.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
graph: TensorFlow Graph.
|
| 135 |
+
tensor_names: List of strings that represent names of tensors in the graph.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
A list of Tensor objects in the same order the names are provided.
|
| 139 |
+
|
| 140 |
+
Raises:
|
| 141 |
+
ValueError:
|
| 142 |
+
tensor_names contains an invalid tensor name.
|
| 143 |
+
"""
|
| 144 |
+
# Get the list of all of the tensors.
|
| 145 |
+
tensor_name_to_tensor = {}
|
| 146 |
+
for op in graph.get_operations():
|
| 147 |
+
for tensor in op.values():
|
| 148 |
+
tensor_name_to_tensor[get_tensor_name(tensor)] = tensor
|
| 149 |
+
|
| 150 |
+
# Get the tensors associated with tensor_names.
|
| 151 |
+
tensors = []
|
| 152 |
+
invalid_tensors = []
|
| 153 |
+
for name in tensor_names:
|
| 154 |
+
if not isinstance(name, str):
|
| 155 |
+
raise ValueError("Invalid type for a tensor name in the provided graph. "
|
| 156 |
+
"Expected type for a tensor name is 'str', instead got "
|
| 157 |
+
"type '{}' for tensor name '{}'".format(
|
| 158 |
+
type(name), name))
|
| 159 |
+
|
| 160 |
+
tensor = tensor_name_to_tensor.get(name)
|
| 161 |
+
if tensor is None:
|
| 162 |
+
invalid_tensors.append(name)
|
| 163 |
+
else:
|
| 164 |
+
tensors.append(tensor)
|
| 165 |
+
|
| 166 |
+
# Throw ValueError if any user input names are not valid tensors.
|
| 167 |
+
if invalid_tensors:
|
| 168 |
+
raise ValueError("Invalid tensors '{}' were found.".format(
|
| 169 |
+
",".join(invalid_tensors)))
|
| 170 |
+
return tensors
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def set_tensor_shapes(tensors, shapes):
|
| 174 |
+
"""Sets Tensor shape for each tensor if the shape is defined.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
tensors: TensorFlow tensor.Tensor.
|
| 178 |
+
shapes: Dict of strings representing input tensor names to list of
|
| 179 |
+
integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
|
| 180 |
+
|
| 181 |
+
Raises:
|
| 182 |
+
ValueError:
|
| 183 |
+
`shapes` contains an invalid tensor.
|
| 184 |
+
`shapes` contains an invalid shape for a valid tensor.
|
| 185 |
+
"""
|
| 186 |
+
if shapes:
|
| 187 |
+
tensor_names_to_tensor = {
|
| 188 |
+
get_tensor_name(tensor): tensor for tensor in tensors
|
| 189 |
+
}
|
| 190 |
+
for name, shape in shapes.items():
|
| 191 |
+
if name not in tensor_names_to_tensor:
|
| 192 |
+
raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
|
| 193 |
+
"map.".format(name))
|
| 194 |
+
if shape is not None:
|
| 195 |
+
tensor = tensor_names_to_tensor[name]
|
| 196 |
+
try:
|
| 197 |
+
tensor.set_shape(shape)
|
| 198 |
+
except ValueError as error:
|
| 199 |
+
message = ("The shape of tensor '{0}' cannot be changed from {1} to "
|
| 200 |
+
"{2}. {3}".format(name, tensor.shape, shape, str(error)))
|
| 201 |
+
raise ValueError(message)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def get_grappler_config(optimizers_list):
|
| 205 |
+
"""Creates a tf.compat.v1.ConfigProto for configuring Grappler.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
optimizers_list: List of strings that represents the list of optimizers.
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
tf.ConfigProto.
|
| 212 |
+
"""
|
| 213 |
+
config = _config_pb2.ConfigProto()
|
| 214 |
+
rewrite_options = config.graph_options.rewrite_options
|
| 215 |
+
for optimizer in optimizers_list:
|
| 216 |
+
rewrite_options.optimizers.append(optimizer)
|
| 217 |
+
return config
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def run_graph_optimizations(graph_def,
|
| 221 |
+
input_arrays,
|
| 222 |
+
output_arrays,
|
| 223 |
+
config,
|
| 224 |
+
graph=None):
|
| 225 |
+
"""Apply standard TensorFlow optimizations to the graph_def.
|
| 226 |
+
|
| 227 |
+
Args:
|
| 228 |
+
graph_def: Frozen GraphDef to be optimized.
|
| 229 |
+
input_arrays: List of arrays that are considered inputs of the graph.
|
| 230 |
+
output_arrays: List of arrays that are considered outputs of the graph.
|
| 231 |
+
config: tf.ConfigProto.
|
| 232 |
+
graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
A new, optimized GraphDef.
|
| 236 |
+
"""
|
| 237 |
+
meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)
|
| 238 |
+
|
| 239 |
+
signature = _meta_graph_pb2.SignatureDef()
|
| 240 |
+
for array in input_arrays:
|
| 241 |
+
signature.inputs[array.name].name = array.name
|
| 242 |
+
signature.inputs[array.name].dtype = array.dtype.as_datatype_enum
|
| 243 |
+
signature.inputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())
|
| 244 |
+
|
| 245 |
+
for array in output_arrays:
|
| 246 |
+
signature.outputs[array.name].name = array.name
|
| 247 |
+
signature.outputs[array.name].dtype = array.dtype.as_datatype_enum
|
| 248 |
+
signature.outputs[array.name].tensor_shape.CopyFrom(array.shape.as_proto())
|
| 249 |
+
|
| 250 |
+
meta_graph.signature_def["not_used_key"].CopyFrom(signature)
|
| 251 |
+
|
| 252 |
+
# We need to add a collection called 'train_op' so that grappler
|
| 253 |
+
# knows what the outputs are.
|
| 254 |
+
fetch_collection = _meta_graph_pb2.CollectionDef()
|
| 255 |
+
for array in input_arrays + output_arrays:
|
| 256 |
+
fetch_collection.node_list.value.append(array.name)
|
| 257 |
+
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
|
| 258 |
+
|
| 259 |
+
return tf_optimizer.OptimizeGraph(config, meta_graph)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
|
| 263 |
+
hinted_outputs_nodes):
|
| 264 |
+
if is_frozen_graph(sess):
|
| 265 |
+
raise ValueError("Try to convert op hints, needs unfrozen graph.")
|
| 266 |
+
output_arrays = [get_tensor_name(tensor) for tensor in output_tensors]
|
| 267 |
+
graph_def = _convert_to_constants.convert_variables_to_constants(
|
| 268 |
+
sess, graph_def, output_arrays + hinted_outputs_nodes)
|
| 269 |
+
graph_def = convert_op_hints_to_stubs(graph_def=graph_def)
|
| 270 |
+
return graph_def
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def freeze_graph(sess, input_tensors, output_tensors):
|
| 274 |
+
"""Returns a frozen GraphDef.
|
| 275 |
+
|
| 276 |
+
Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
|
| 277 |
+
existing GraphDef is returned. The Grappler pass is only run on models that
|
| 278 |
+
are frozen in order to inline the functions in the graph.
|
| 279 |
+
If OpHints is present, it will try to convert the OpHint graph.
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
sess: TensorFlow Session.
|
| 283 |
+
input_tensors: List of input tensors.
|
| 284 |
+
output_tensors: List of output tensors (only .name is used from this).
|
| 285 |
+
|
| 286 |
+
Returns:
|
| 287 |
+
Frozen GraphDef.
|
| 288 |
+
"""
|
| 289 |
+
# Runs a Grappler pass in order to inline any functions in the graph.
|
| 290 |
+
# Asides from inlining any simple function, Grappler will also try to lower
|
| 291 |
+
# while loop into switch merge representation which is undesired for Ophints,
|
| 292 |
+
# so we simply remove those attributes to prevent Grappler from doing so.
|
| 293 |
+
graph_def = _convert_to_constants.disable_lower_using_switch_merge(
|
| 294 |
+
sess.graph_def)
|
| 295 |
+
config = get_grappler_config(["function"])
|
| 296 |
+
graph_def = run_graph_optimizations(
|
| 297 |
+
graph_def, input_tensors, output_tensors, config, graph=sess.graph)
|
| 298 |
+
|
| 299 |
+
# If ophints are present, just convert them.
|
| 300 |
+
hinted_outputs_nodes = find_all_hinted_output_nodes(sess)
|
| 301 |
+
if hinted_outputs_nodes:
|
| 302 |
+
return _convert_op_hints_if_present(sess, graph_def, output_tensors,
|
| 303 |
+
hinted_outputs_nodes)
|
| 304 |
+
|
| 305 |
+
if not is_frozen_graph(sess):
|
| 306 |
+
output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
|
| 307 |
+
return _convert_to_constants.convert_variables_to_constants(
|
| 308 |
+
sess, graph_def, output_node_names
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
return sess.graph_def
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def is_frozen_graph(sess):
|
| 315 |
+
"""Determines if the graph is frozen.
|
| 316 |
+
|
| 317 |
+
Determines if a graph has previously been frozen by checking for any
|
| 318 |
+
operations of type Variable*. If variables are found, the graph is not frozen.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
sess: TensorFlow Session.
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
Bool.
|
| 325 |
+
"""
|
| 326 |
+
for op in sess.graph.get_operations():
|
| 327 |
+
if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
|
| 328 |
+
return False
|
| 329 |
+
return True
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def build_debug_info_func(original_graph):
|
| 333 |
+
"""Returns a method to retrieve the `GraphDebugInfo` from the original graph.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
original_graph: The original `Graph` containing all the op stack traces.
|
| 337 |
+
|
| 338 |
+
Returns:
|
| 339 |
+
A function which retrieves the stack traces from the original graph and
|
| 340 |
+
converts them to a `GraphDebugInfo` for a given set of nodes.
|
| 341 |
+
"""
|
| 342 |
+
|
| 343 |
+
def f(original_nodes):
|
| 344 |
+
"""Function to create `GraphDebugInfo` for the given `original_nodes`."""
|
| 345 |
+
if not original_graph:
|
| 346 |
+
return None
|
| 347 |
+
# For the given nodes, gets all the op definitions in the original graph.
|
| 348 |
+
useful_ops = []
|
| 349 |
+
for func, name in original_nodes:
|
| 350 |
+
try:
|
| 351 |
+
if not func:
|
| 352 |
+
useful_ops.append((func, original_graph.get_operation_by_name(name)))
|
| 353 |
+
else:
|
| 354 |
+
sub_func = original_graph._get_function(func) # pylint: disable=protected-access
|
| 355 |
+
if isinstance(sub_func, function.AtomicFunction): # pylint: disable=protected-access
|
| 356 |
+
useful_ops.append(
|
| 357 |
+
(func, sub_func.graph.get_operation_by_name(name)))
|
| 358 |
+
else:
|
| 359 |
+
sys.stderr.write(
|
| 360 |
+
"Use '@tf.function' or '@defun' to decorate the function.\n")
|
| 361 |
+
continue
|
| 362 |
+
except KeyError:
|
| 363 |
+
# New node created by graph optimizer. No stack trace from source code.
|
| 364 |
+
continue
|
| 365 |
+
# Convert all the op definitions to stack traces in terms of GraphDebugInfo.
|
| 366 |
+
return _error_interpolation.create_graph_debug_info_def(useful_ops)
|
| 367 |
+
|
| 368 |
+
return f
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def convert_debug_info_func(saved_debug_info):
|
| 372 |
+
"""Returns a method to retrieve the `GraphDebugInfo` from the original graph.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
saved_debug_info: The `GraphDebugInfo` containing all the debug info.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
A function which retrieves the stack traces from the original graph and
|
| 379 |
+
converts them to a `GraphDebugInfo` for a given set of nodes.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
def f(original_nodes):
|
| 383 |
+
"""Function to create `GraphDebugInfo` for the given `original_nodes`."""
|
| 384 |
+
del original_nodes
|
| 385 |
+
return saved_debug_info
|
| 386 |
+
|
| 387 |
+
return f
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def get_debug_info(nodes_to_debug_info_func, converted_graph):
|
| 391 |
+
"""Returns the debug info for the original nodes in the `converted_graph`.
|
| 392 |
+
|
| 393 |
+
Args:
|
| 394 |
+
nodes_to_debug_info_func: The method to collect the op debug info for the
|
| 395 |
+
nodes.
|
| 396 |
+
converted_graph: A `GraphDef` after optimization and transformation.
|
| 397 |
+
|
| 398 |
+
Returns:
|
| 399 |
+
`GraphDebugInfo` for all the original nodes in `converted_graph`.
|
| 400 |
+
"""
|
| 401 |
+
if not nodes_to_debug_info_func:
|
| 402 |
+
return None
|
| 403 |
+
|
| 404 |
+
# Collect all the debug info nodes from the converted_graph
|
| 405 |
+
original_nodes = set()
|
| 406 |
+
for node in converted_graph.node:
|
| 407 |
+
debug_nodes = node.experimental_debug_info.original_node_names
|
| 408 |
+
debug_funcs = node.experimental_debug_info.original_func_names
|
| 409 |
+
# If the `original_node_names` are empty, uses the node name directly.
|
| 410 |
+
if not debug_nodes:
|
| 411 |
+
original_nodes.add(("", node.name))
|
| 412 |
+
else:
|
| 413 |
+
for i in range(len(debug_nodes)):
|
| 414 |
+
debug_func = "" if i >= len(debug_funcs) else debug_funcs[i]
|
| 415 |
+
original_nodes.add((debug_func, debug_nodes[i]))
|
| 416 |
+
|
| 417 |
+
# Convert the nodes to the debug info proto object.
|
| 418 |
+
return nodes_to_debug_info_func(original_nodes)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def convert_bytes_to_c_source(data,
|
| 422 |
+
array_name,
|
| 423 |
+
max_line_width=80,
|
| 424 |
+
include_guard=None,
|
| 425 |
+
include_path=None,
|
| 426 |
+
use_tensorflow_license=False):
|
| 427 |
+
"""Returns strings representing a C constant array containing `data`.
|
| 428 |
+
|
| 429 |
+
Args:
|
| 430 |
+
data: Byte array that will be converted into a C constant.
|
| 431 |
+
array_name: String to use as the variable name for the constant array.
|
| 432 |
+
max_line_width: The longest line length, for formatting purposes.
|
| 433 |
+
include_guard: Name to use for the include guard macro definition.
|
| 434 |
+
include_path: Optional path to include in the source file.
|
| 435 |
+
use_tensorflow_license: Whether to include the standard TensorFlow Apache2
|
| 436 |
+
license in the generated files.
|
| 437 |
+
|
| 438 |
+
Returns:
|
| 439 |
+
Text that can be compiled as a C source file to link in the data as a
|
| 440 |
+
literal array of values.
|
| 441 |
+
Text that can be used as a C header file to reference the literal array.
|
| 442 |
+
"""
|
| 443 |
+
|
| 444 |
+
starting_pad = " "
|
| 445 |
+
array_lines = []
|
| 446 |
+
array_line = starting_pad
|
| 447 |
+
for value in bytearray(data):
|
| 448 |
+
if (len(array_line) + 4) > max_line_width:
|
| 449 |
+
array_lines.append(array_line + "\n")
|
| 450 |
+
array_line = starting_pad
|
| 451 |
+
array_line += " 0x%02x," % (value,)
|
| 452 |
+
if len(array_line) > len(starting_pad):
|
| 453 |
+
array_lines.append(array_line + "\n")
|
| 454 |
+
array_values = "".join(array_lines)
|
| 455 |
+
|
| 456 |
+
if include_guard is None:
|
| 457 |
+
include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"
|
| 458 |
+
|
| 459 |
+
if include_path is not None:
|
| 460 |
+
include_line = "#include \"{include_path}\"\n".format(
|
| 461 |
+
include_path=include_path)
|
| 462 |
+
else:
|
| 463 |
+
include_line = ""
|
| 464 |
+
|
| 465 |
+
if use_tensorflow_license:
|
| 466 |
+
license_text = """
|
| 467 |
+
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.
|
| 468 |
+
|
| 469 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 470 |
+
you may not use this file except in compliance with the License.
|
| 471 |
+
You may obtain a copy of the License at
|
| 472 |
+
|
| 473 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 474 |
+
|
| 475 |
+
Unless required by applicable law or agreed to in writing, software
|
| 476 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 477 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 478 |
+
See the License for the specific language governing permissions and
|
| 479 |
+
limitations under the License.
|
| 480 |
+
==============================================================================*/
|
| 481 |
+
""".format(year=datetime.date.today().year)
|
| 482 |
+
else:
|
| 483 |
+
license_text = ""
|
| 484 |
+
|
| 485 |
+
source_template = """{license_text}
|
| 486 |
+
// This is a TensorFlow Lite model file that has been converted into a C data
|
| 487 |
+
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
|
| 488 |
+
// This form is useful for compiling into a binary for devices that don't have a
|
| 489 |
+
// file system.
|
| 490 |
+
|
| 491 |
+
{include_line}
|
| 492 |
+
// We need to keep the data array aligned on some architectures.
|
| 493 |
+
#ifdef __has_attribute
|
| 494 |
+
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
|
| 495 |
+
#else
|
| 496 |
+
#define HAVE_ATTRIBUTE(x) 0
|
| 497 |
+
#endif
|
| 498 |
+
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
|
| 499 |
+
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
|
| 500 |
+
#else
|
| 501 |
+
#define DATA_ALIGN_ATTRIBUTE
|
| 502 |
+
#endif
|
| 503 |
+
|
| 504 |
+
const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
|
| 505 |
+
{array_values}}};
|
| 506 |
+
const int {array_name}_len = {array_length};
|
| 507 |
+
"""
|
| 508 |
+
|
| 509 |
+
source_text = source_template.format(
|
| 510 |
+
array_name=array_name,
|
| 511 |
+
array_length=len(data),
|
| 512 |
+
array_values=array_values,
|
| 513 |
+
license_text=license_text,
|
| 514 |
+
include_line=include_line)
|
| 515 |
+
|
| 516 |
+
header_template = """
|
| 517 |
+
{license_text}
|
| 518 |
+
|
| 519 |
+
// This is a TensorFlow Lite model file that has been converted into a C data
|
| 520 |
+
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
|
| 521 |
+
// This form is useful for compiling into a binary for devices that don't have a
|
| 522 |
+
// file system.
|
| 523 |
+
|
| 524 |
+
#ifndef {include_guard}
|
| 525 |
+
#define {include_guard}
|
| 526 |
+
|
| 527 |
+
extern const unsigned char {array_name}[];
|
| 528 |
+
extern const int {array_name}_len;
|
| 529 |
+
|
| 530 |
+
#endif // {include_guard}
|
| 531 |
+
"""
|
| 532 |
+
|
| 533 |
+
header_text = header_template.format(
|
| 534 |
+
array_name=array_name,
|
| 535 |
+
include_guard=include_guard,
|
| 536 |
+
license_text=license_text)
|
| 537 |
+
|
| 538 |
+
return source_text, header_text
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def _convert_model_from_bytearray_to_object(model_bytearray):
|
| 542 |
+
"""Converts a tflite model from a bytearray into a parsable object."""
|
| 543 |
+
model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
|
| 544 |
+
model_object = schema_fb.ModelT.InitFromObj(model_object)
|
| 545 |
+
model_object = copy.deepcopy(model_object)
|
| 546 |
+
return model_object
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
def _convert_model_from_object_to_bytearray(model_object):
|
| 550 |
+
"""Converts a tflite model from a parsable object into a bytearray."""
|
| 551 |
+
# Initial size of the buffer, which will grow automatically if needed
|
| 552 |
+
builder = flatbuffers.Builder(1024)
|
| 553 |
+
model_offset = model_object.Pack(builder)
|
| 554 |
+
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
| 555 |
+
return bytes(builder.Output())
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def get_quantize_opcode_idx(model):
|
| 559 |
+
"""Returns the quantize op idx."""
|
| 560 |
+
quant_opcode_idxs = []
|
| 561 |
+
for idx, opcode in enumerate(model.operatorCodes):
|
| 562 |
+
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
| 563 |
+
if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
|
| 564 |
+
quant_opcode_idxs.append(idx)
|
| 565 |
+
return quant_opcode_idxs
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def get_dequantize_opcode_idx(model):
|
| 569 |
+
"""Returns the quantize op idx."""
|
| 570 |
+
quant_opcode_idxs = []
|
| 571 |
+
for idx, opcode in enumerate(model.operatorCodes):
|
| 572 |
+
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
| 573 |
+
if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
|
| 574 |
+
quant_opcode_idxs.append(idx)
|
| 575 |
+
return quant_opcode_idxs
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
|
| 579 |
+
"""Update the tensors in the SignatureDef's TensorMaps."""
|
| 580 |
+
for i in range(len(tensor_maps)):
|
| 581 |
+
if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
|
| 582 |
+
tensor_maps[i].tensorIndex = (
|
| 583 |
+
map_old_to_new_tensors[tensor_maps[i].tensorIndex])
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def _remove_tensors_from_model(model, remove_tensors_idxs):
|
| 587 |
+
"""Remove tensors from model."""
|
| 588 |
+
if not remove_tensors_idxs:
|
| 589 |
+
return
|
| 590 |
+
if len(model.subgraphs) > 1:
|
| 591 |
+
logging.info("Skipping the removal of dangled tensors since the model has "
|
| 592 |
+
"multiple subgraphs and tensors can be used in the different "
|
| 593 |
+
"subgraph(s)")
|
| 594 |
+
return
|
| 595 |
+
subgraph = model.subgraphs[0]
|
| 596 |
+
tensors = subgraph.tensors
|
| 597 |
+
operators = subgraph.operators
|
| 598 |
+
|
| 599 |
+
logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
|
| 600 |
+
# An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
|
| 601 |
+
# exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
|
| 602 |
+
if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
|
| 603 |
+
logging.debug("Removing tensors only at the end of the tensor list")
|
| 604 |
+
del tensors[min(remove_tensors_idxs):]
|
| 605 |
+
else:
|
| 606 |
+
logging.debug("Removing tensors requires updating the model")
|
| 607 |
+
# Map the old tensor indices to new tensor indices
|
| 608 |
+
d_old_to_new_tensors = {}
|
| 609 |
+
left_shift_by = 0
|
| 610 |
+
for idx in range(len(tensors)):
|
| 611 |
+
if idx in remove_tensors_idxs:
|
| 612 |
+
left_shift_by += 1
|
| 613 |
+
else:
|
| 614 |
+
d_old_to_new_tensors[idx] = idx - left_shift_by
|
| 615 |
+
logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())
|
| 616 |
+
# Update tensor indices referenced throughout the model
|
| 617 |
+
def update_tensors(tensor_idxs):
|
| 618 |
+
for i, ti in enumerate(tensor_idxs):
|
| 619 |
+
tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)
|
| 620 |
+
update_tensors(subgraph.inputs)
|
| 621 |
+
update_tensors(subgraph.outputs)
|
| 622 |
+
for op in operators:
|
| 623 |
+
update_tensors(op.inputs)
|
| 624 |
+
update_tensors(op.outputs)
|
| 625 |
+
if model.signatureDefs:
|
| 626 |
+
signature_def = model.signatureDefs[0]
|
| 627 |
+
_update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
|
| 628 |
+
_update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
|
| 629 |
+
# Delete the tensors
|
| 630 |
+
for idx in sorted(remove_tensors_idxs, reverse=True):
|
| 631 |
+
tensors.pop(idx)
|
| 632 |
+
logging.debug("Removed tensors marked for deletion")
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
def _modify_model_input_type(model, inference_input_type=dtypes.float32):
|
| 636 |
+
"""Modify model input type."""
|
| 637 |
+
if inference_input_type == dtypes.float32:
|
| 638 |
+
return
|
| 639 |
+
|
| 640 |
+
if not model.signatureDefs:
|
| 641 |
+
_modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
|
| 642 |
+
return
|
| 643 |
+
|
| 644 |
+
for signature_index, signature_def in enumerate(model.signatureDefs):
|
| 645 |
+
_modify_model_input_type_per_subgraph(model, signature_def.subgraphIndex,
|
| 646 |
+
signature_index, inference_input_type)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def _modify_model_input_type_per_subgraph(model, subgraph_index,
|
| 650 |
+
signature_index,
|
| 651 |
+
inference_input_type):
|
| 652 |
+
"""Modify model input type per subgraph."""
|
| 653 |
+
subgraph = model.subgraphs[subgraph_index]
|
| 654 |
+
tensors = subgraph.tensors
|
| 655 |
+
operators = subgraph.operators
|
| 656 |
+
|
| 657 |
+
# Find all quantize operators
|
| 658 |
+
quant_opcode_idxs = get_quantize_opcode_idx(model)
|
| 659 |
+
if operators and not quant_opcode_idxs:
|
| 660 |
+
for input_idx in subgraph.inputs:
|
| 661 |
+
input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
|
| 662 |
+
if input_type == dtypes.float32:
|
| 663 |
+
raise ValueError("Model input is not dequantized.")
|
| 664 |
+
# None of the inputs have float32, then they must be int16, int8, or bool
|
| 665 |
+
return
|
| 666 |
+
|
| 667 |
+
# Validate that the model input is quantized
|
| 668 |
+
input_quant_ops = []
|
| 669 |
+
for op in operators:
|
| 670 |
+
# Find operators that quantize model input
|
| 671 |
+
if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
|
| 672 |
+
float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
|
| 673 |
+
# If found, validate that the operator's input type is float
|
| 674 |
+
float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
|
| 675 |
+
if float_type != dtypes.float32:
|
| 676 |
+
if float_type == inference_input_type:
|
| 677 |
+
continue
|
| 678 |
+
else:
|
| 679 |
+
raise ValueError(
|
| 680 |
+
"Initial model input type must be tf.float32. Expected type for "
|
| 681 |
+
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
| 682 |
+
float_tensor.name, get_tf_type_name(float_type)))
|
| 683 |
+
# If found, validate that the operator output is quantized and compatible
|
| 684 |
+
# with the final model input type
|
| 685 |
+
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
| 686 |
+
if quant_type not in _MAP_QUANT_TO_IO_TYPES:
|
| 687 |
+
raise ValueError(
|
| 688 |
+
"Initial model input is not quantized. Expected type for "
|
| 689 |
+
"tensor with name '{}' should be in {}, instead type is {}".format(
|
| 690 |
+
quant_tensor.name,
|
| 691 |
+
tuple(get_tf_type_name(t) for t in
|
| 692 |
+
_MAP_QUANT_TO_IO_TYPES.keys()),
|
| 693 |
+
get_tf_type_name(quant_type)))
|
| 694 |
+
else:
|
| 695 |
+
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
| 696 |
+
if inference_input_type not in inference_io_types:
|
| 697 |
+
raise ValueError(
|
| 698 |
+
"Unsupported `inference_input_type` value. Expected to be in "
|
| 699 |
+
"{}, instead got {}.".format(
|
| 700 |
+
tuple(get_tf_type_name(t) for t in inference_io_types),
|
| 701 |
+
get_tf_type_name(inference_input_type)))
|
| 702 |
+
input_quant_ops.append(op)
|
| 703 |
+
|
| 704 |
+
if len(subgraph.inputs) != len(input_quant_ops):
|
| 705 |
+
logging.warning(
|
| 706 |
+
"For model inputs containing unsupported operations which cannot be "
|
| 707 |
+
"quantized, the `inference_input_type` attribute will default to the "
|
| 708 |
+
"original type."
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
# Modify model input type
|
| 712 |
+
if inference_input_type == dtypes.uint8:
|
| 713 |
+
# Change quant op (float to int8) to quant op (uint8 to int8)
|
| 714 |
+
for op in input_quant_ops:
|
| 715 |
+
int8_quantization = tensors[op.outputs[0]].quantization
|
| 716 |
+
uint8_quantization = schema_fb.QuantizationParametersT()
|
| 717 |
+
uint8_quantization.scale = [int8_quantization.scale[0]]
|
| 718 |
+
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
|
| 719 |
+
tensors[op.inputs[0]].quantization = uint8_quantization
|
| 720 |
+
tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
|
| 721 |
+
elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
|
| 722 |
+
# Remove the inputs and the quant operator
|
| 723 |
+
remove_tensors_idxs = set()
|
| 724 |
+
for op in input_quant_ops:
|
| 725 |
+
subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
|
| 726 |
+
if signature_index >= 0:
|
| 727 |
+
signature_def = model.signatureDefs[signature_index]
|
| 728 |
+
for i in range(len(signature_def.inputs)):
|
| 729 |
+
if signature_def.inputs[i].tensorIndex == op.inputs[0]:
|
| 730 |
+
signature_def.inputs[i].tensorIndex = op.outputs[0]
|
| 731 |
+
remove_tensors_idxs.add(op.inputs[0])
|
| 732 |
+
operators.remove(op)
|
| 733 |
+
# Remove tensors marked for deletion.
|
| 734 |
+
_remove_tensors_from_model(model, remove_tensors_idxs)
|
| 735 |
+
else:
|
| 736 |
+
raise ValueError(
|
| 737 |
+
"Unsupported `inference_input_type` value {}.".format(
|
| 738 |
+
get_tf_type_name(inference_input_type)))
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
|
| 742 |
+
"""Modify model output type."""
|
| 743 |
+
if inference_output_type == dtypes.float32:
|
| 744 |
+
return
|
| 745 |
+
|
| 746 |
+
if not model.signatureDefs:
|
| 747 |
+
_modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
|
| 748 |
+
return
|
| 749 |
+
|
| 750 |
+
for signature_index, signature_def in enumerate(model.signatureDefs):
|
| 751 |
+
_modify_model_output_type_per_subgraph(model, signature_def.subgraphIndex,
|
| 752 |
+
signature_index,
|
| 753 |
+
inference_output_type)
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
def _modify_model_output_type_per_subgraph(model, subgraph_index,
|
| 757 |
+
signature_index,
|
| 758 |
+
inference_output_type):
|
| 759 |
+
"""Modify model output type per subgraph."""
|
| 760 |
+
subgraph = model.subgraphs[subgraph_index]
|
| 761 |
+
tensors = subgraph.tensors
|
| 762 |
+
operators = subgraph.operators
|
| 763 |
+
|
| 764 |
+
# Find all dequantize operators
|
| 765 |
+
dequant_opcode_idxs = get_dequantize_opcode_idx(model)
|
| 766 |
+
if operators and not dequant_opcode_idxs:
|
| 767 |
+
for output in subgraph.outputs:
|
| 768 |
+
output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
|
| 769 |
+
if output_type == dtypes.float32:
|
| 770 |
+
raise ValueError("Model output is not dequantized.")
|
| 771 |
+
# None of the outputs have float32, then they must be int16, int8, or bool
|
| 772 |
+
return
|
| 773 |
+
|
| 774 |
+
# Validate that the model output is dequantized
|
| 775 |
+
output_dequant_ops = []
|
| 776 |
+
for op in operators:
|
| 777 |
+
# Find operators that dequantize model output
|
| 778 |
+
if (op.opcodeIndex in dequant_opcode_idxs and
|
| 779 |
+
op.outputs[0] in subgraph.outputs):
|
| 780 |
+
# If found, validate that the operator's output type is float
|
| 781 |
+
quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
|
| 782 |
+
float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
|
| 783 |
+
if float_type != dtypes.float32:
|
| 784 |
+
if float_type == inference_output_type:
|
| 785 |
+
continue
|
| 786 |
+
else:
|
| 787 |
+
raise ValueError(
|
| 788 |
+
"Initial model output type must be tf.float32. Expected type for "
|
| 789 |
+
"tensor with name '{}' is tf.float32, instead type is {}".format(
|
| 790 |
+
float_tensor.name, get_tf_type_name(float_type)))
|
| 791 |
+
# If found, validate that the operator input is quantized and compatible
|
| 792 |
+
# with the final model output type
|
| 793 |
+
quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
|
| 794 |
+
if quant_type not in _MAP_QUANT_TO_IO_TYPES:
|
| 795 |
+
raise ValueError(
|
| 796 |
+
"Initial model output is not dequantized. Expected type for "
|
| 797 |
+
"tensor with name '{}' should be in {}, instead type is {}".format(
|
| 798 |
+
quant_tensor.name,
|
| 799 |
+
tuple(get_tf_type_name(t) for t in
|
| 800 |
+
_MAP_QUANT_TO_IO_TYPES.keys()),
|
| 801 |
+
get_tf_type_name(quant_type)))
|
| 802 |
+
else:
|
| 803 |
+
inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
|
| 804 |
+
if inference_output_type not in inference_io_types:
|
| 805 |
+
raise ValueError(
|
| 806 |
+
"Unsupported `inference_output_type` value. Expected to be in "
|
| 807 |
+
"{}, instead got {}.".format(
|
| 808 |
+
tuple(get_tf_type_name(t) for t in inference_io_types),
|
| 809 |
+
get_tf_type_name(inference_output_type)))
|
| 810 |
+
output_dequant_ops.append(op)
|
| 811 |
+
|
| 812 |
+
if len(subgraph.outputs) != len(output_dequant_ops):
|
| 813 |
+
logging.warning(
|
| 814 |
+
"For model outputs containing unsupported operations which cannot be "
|
| 815 |
+
"quantized, the `inference_output_type` attribute will default to the "
|
| 816 |
+
"original type."
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
# Modify model output type
|
| 820 |
+
if inference_output_type == dtypes.uint8:
|
| 821 |
+
# Find a quantize operator
|
| 822 |
+
quant_opcode_idx = -1
|
| 823 |
+
for idx, opcode in enumerate(model.operatorCodes):
|
| 824 |
+
builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
|
| 825 |
+
if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
|
| 826 |
+
quant_opcode_idx = idx
|
| 827 |
+
break
|
| 828 |
+
# Create a quantize operator, if none exist
|
| 829 |
+
if quant_opcode_idx == -1:
|
| 830 |
+
quant_op = schema_fb.OperatorCodeT()
|
| 831 |
+
quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
|
| 832 |
+
quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
|
| 833 |
+
model.operatorCodes.append(quant_op)
|
| 834 |
+
quant_opcode_idx = len(model.operatorCodes) - 1
|
| 835 |
+
# Change dequant op (int8 to float) to quant op (int8 to uint8)
|
| 836 |
+
for op in output_dequant_ops:
|
| 837 |
+
op.opcodeIndex = quant_opcode_idx
|
| 838 |
+
int8_quantization = tensors[op.inputs[0]].quantization
|
| 839 |
+
uint8_quantization = schema_fb.QuantizationParametersT()
|
| 840 |
+
uint8_quantization.scale = [int8_quantization.scale[0]]
|
| 841 |
+
uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
|
| 842 |
+
tensors[op.outputs[0]].quantization = uint8_quantization
|
| 843 |
+
tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
|
| 844 |
+
elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
|
| 845 |
+
# Remove the outputs and the dequant operator
|
| 846 |
+
remove_tensors_idxs = set()
|
| 847 |
+
for op in output_dequant_ops:
|
| 848 |
+
subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
|
| 849 |
+
if signature_index >= 0:
|
| 850 |
+
signature_def = model.signatureDefs[signature_index]
|
| 851 |
+
for i in range(len(signature_def.outputs)):
|
| 852 |
+
if signature_def.outputs[i].tensorIndex == op.outputs[0]:
|
| 853 |
+
signature_def.outputs[i].tensorIndex = op.inputs[0]
|
| 854 |
+
remove_tensors_idxs.add(op.outputs[0])
|
| 855 |
+
operators.remove(op)
|
| 856 |
+
# Remove tensors marked for deletion.
|
| 857 |
+
_remove_tensors_from_model(model, remove_tensors_idxs)
|
| 858 |
+
else:
|
| 859 |
+
raise ValueError(
|
| 860 |
+
"Unsupported `inference_output_type` value {}.".format(
|
| 861 |
+
get_tf_type_name(inference_output_type)))
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
def _remove_redundant_quantize_ops(model):
|
| 865 |
+
"""Finds back to back quantize ops and remove the first quantize op."""
|
| 866 |
+
if not model.signatureDefs:
|
| 867 |
+
_remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
|
| 868 |
+
return
|
| 869 |
+
|
| 870 |
+
for signature_index, signature_def in enumerate(model.signatureDefs):
|
| 871 |
+
_remove_redundant_quantize_ops_per_subgraph(model,
|
| 872 |
+
signature_def.subgraphIndex,
|
| 873 |
+
signature_index)
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
|
| 877 |
+
signature_index):
|
| 878 |
+
"""Remove redundant quantize ops per subgraph."""
|
| 879 |
+
subgraph = model.subgraphs[subgraph_index]
|
| 880 |
+
tensors = subgraph.tensors
|
| 881 |
+
operators = subgraph.operators
|
| 882 |
+
|
| 883 |
+
# Find all quantize operators.
|
| 884 |
+
quant_opcode_idxs = get_quantize_opcode_idx(model)
|
| 885 |
+
dequant_opcode_idxs = get_dequantize_opcode_idx(model)
|
| 886 |
+
|
| 887 |
+
# Find all redundant quant tensors.
|
| 888 |
+
all_quant_ops = []
|
| 889 |
+
redundant_quant_tensors = {}
|
| 890 |
+
output_dequant_tensors = {}
|
| 891 |
+
for op in operators:
|
| 892 |
+
if op.opcodeIndex in quant_opcode_idxs:
|
| 893 |
+
all_quant_ops.append(op)
|
| 894 |
+
input_tensor = tensors[op.inputs[0]]
|
| 895 |
+
output_tensor = tensors[op.outputs[0]]
|
| 896 |
+
input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
|
| 897 |
+
output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
|
| 898 |
+
# This is a requantize op, so write down its input tensor index.
|
| 899 |
+
if input_type != dtypes.float32 and output_type != dtypes.float32:
|
| 900 |
+
redundant_quant_tensors[op.inputs[0]] = op
|
| 901 |
+
if (op.opcodeIndex in dequant_opcode_idxs and
|
| 902 |
+
op.outputs[0] in subgraph.outputs):
|
| 903 |
+
output_dequant_tensors[op.inputs[0]] = op
|
| 904 |
+
|
| 905 |
+
# Remove all the quant ops which produce the redundant quant tensors.
|
| 906 |
+
for op in all_quant_ops:
|
| 907 |
+
output_tensor_idx = op.outputs[0]
|
| 908 |
+
if output_tensor_idx in redundant_quant_tensors:
|
| 909 |
+
requantize_op = redundant_quant_tensors[output_tensor_idx]
|
| 910 |
+
if model.signatureDefs:
|
| 911 |
+
signature_def = model.signatureDefs[0]
|
| 912 |
+
for output in signature_def.outputs:
|
| 913 |
+
if output.tensorIndex == op.outputs[0]:
|
| 914 |
+
output.tensorIndex = op.inputs[0]
|
| 915 |
+
deleted_tensor = requantize_op.inputs[0]
|
| 916 |
+
# Reset the input of the requantize op to the float input
|
| 917 |
+
requantize_op.inputs[0] = op.inputs[0]
|
| 918 |
+
# Migrate other operator users to output tensor of requantize op
|
| 919 |
+
for op_user in operators:
|
| 920 |
+
if deleted_tensor in op_user.inputs and op_user != requantize_op:
|
| 921 |
+
for idx, input_tensor in enumerate(op_user.inputs):
|
| 922 |
+
if input_tensor == deleted_tensor:
|
| 923 |
+
op_user.inputs[idx] = requantize_op.outputs[0]
|
| 924 |
+
operators.remove(op)
|
| 925 |
+
|
| 926 |
+
# Remove all the quant ops which connect to the output dequant op.
|
| 927 |
+
for op in all_quant_ops:
|
| 928 |
+
output_tensor_idx = op.outputs[0]
|
| 929 |
+
if output_tensor_idx in output_dequant_tensors:
|
| 930 |
+
dequant_op = output_dequant_tensors[output_tensor_idx]
|
| 931 |
+
subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
|
| 932 |
+
if signature_index >= 0:
|
| 933 |
+
signature_def = model.signatureDefs[signature_index]
|
| 934 |
+
for output in signature_def.outputs:
|
| 935 |
+
if output.tensorIndex == dequant_op.outputs[0]:
|
| 936 |
+
output.tensorIndex = op.inputs[0]
|
| 937 |
+
operators.remove(op)
|
| 938 |
+
operators.remove(dequant_op)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
def modify_model_io_type(
|
| 942 |
+
model, inference_input_type=dtypes.float32,
|
| 943 |
+
inference_output_type=dtypes.float32):
|
| 944 |
+
"""Modify the input/output type of a tflite model.
|
| 945 |
+
|
| 946 |
+
Args:
|
| 947 |
+
model: A tflite model.
|
| 948 |
+
inference_input_type: tf.DType representing modified input type.
|
| 949 |
+
(default tf.float32. If model input is int8 quantized, it must be in
|
| 950 |
+
{tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized,
|
| 951 |
+
it must be in {tf.float32, tf.int16}, else it must be tf.float32)
|
| 952 |
+
inference_output_type: tf.DType representing modified output type.
|
| 953 |
+
(default tf.float32. If model output is int8 dequantized, it must be in
|
| 954 |
+
{tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
|
| 955 |
+
it must be in {tf.float32, tf.int16}, else it must be tf.float32)
|
| 956 |
+
Returns:
|
| 957 |
+
A tflite model with modified input/output type.
|
| 958 |
+
|
| 959 |
+
Raises:
|
| 960 |
+
ValueError: If `inference_input_type`/`inference_output_type` is unsupported
|
| 961 |
+
or a supported integer type is specified for a model whose input/output is
|
| 962 |
+
not quantized/dequantized.
|
| 963 |
+
RuntimeError: If the modification was unsuccessful.
|
| 964 |
+
|
| 965 |
+
"""
|
| 966 |
+
if (inference_input_type == dtypes.float32 and
|
| 967 |
+
inference_output_type == dtypes.float32):
|
| 968 |
+
return model
|
| 969 |
+
|
| 970 |
+
model_object = _convert_model_from_bytearray_to_object(model)
|
| 971 |
+
|
| 972 |
+
_modify_model_input_type(model_object, inference_input_type)
|
| 973 |
+
|
| 974 |
+
_modify_model_output_type(model_object, inference_output_type)
|
| 975 |
+
|
| 976 |
+
_remove_redundant_quantize_ops(model_object)
|
| 977 |
+
|
| 978 |
+
return _convert_model_from_object_to_bytearray(model_object)
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def get_sparsity_modes(model_object):
|
| 982 |
+
"""Get sparsity modes used in a tflite model.
|
| 983 |
+
|
| 984 |
+
The sparsity modes are listed in conversion_metadata.fbs file.
|
| 985 |
+
|
| 986 |
+
Args:
|
| 987 |
+
model_object: A tflite model in object form.
|
| 988 |
+
|
| 989 |
+
Returns:
|
| 990 |
+
The list of sparsity modes used in the model.
|
| 991 |
+
"""
|
| 992 |
+
if not model_object or not model_object.metadata:
|
| 993 |
+
return []
|
| 994 |
+
|
| 995 |
+
result = set()
|
| 996 |
+
for subgraph in model_object.subgraphs:
|
| 997 |
+
for tensor in subgraph.tensors:
|
| 998 |
+
if not tensor.sparsity:
|
| 999 |
+
continue
|
| 1000 |
+
|
| 1001 |
+
# Block map is the list if indexes where the block size is larger than 1.
|
| 1002 |
+
# So empty block map means it is random sparsity.
|
| 1003 |
+
if not tensor.sparsity.blockMap:
|
| 1004 |
+
result.add(
|
| 1005 |
+
conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY)
|
| 1006 |
+
else:
|
| 1007 |
+
result.add(
|
| 1008 |
+
conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY)
|
| 1009 |
+
|
| 1010 |
+
return list(result)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def get_model_hash(model):
|
| 1014 |
+
"""Calculate a 64-bit integer hash for a TensorFlow Lite model based on its structure.
|
| 1015 |
+
|
| 1016 |
+
Args:
|
| 1017 |
+
model: A TensorFlow Lite model object.
|
| 1018 |
+
|
| 1019 |
+
Returns:
|
| 1020 |
+
int: A 64-bit integer hash value representing the model structure.
|
| 1021 |
+
"""
|
| 1022 |
+
# TODO(b/344872922): Move the hashing implementation to C++ layer since not
|
| 1023 |
+
# all calls to the converter come via the Python API.
|
| 1024 |
+
hash_value = 0
|
| 1025 |
+
|
| 1026 |
+
for subgraph in model.subgraphs:
|
| 1027 |
+
if subgraph.operators is not None:
|
| 1028 |
+
hash_value = update_hash_with_primitive_value(
|
| 1029 |
+
hash_value, len(subgraph.operators)
|
| 1030 |
+
)
|
| 1031 |
+
|
| 1032 |
+
for operator in subgraph.operators:
|
| 1033 |
+
if operator.inputs is not None:
|
| 1034 |
+
hash_value = update_hash_with_array(hash_value, operator.inputs)
|
| 1035 |
+
|
| 1036 |
+
if operator.outputs is not None:
|
| 1037 |
+
hash_value = update_hash_with_array(hash_value, operator.outputs)
|
| 1038 |
+
|
| 1039 |
+
if subgraph.tensors is not None:
|
| 1040 |
+
hash_value = update_hash_with_primitive_value(
|
| 1041 |
+
hash_value, len(subgraph.tensors)
|
| 1042 |
+
)
|
| 1043 |
+
|
| 1044 |
+
for tensor in subgraph.tensors:
|
| 1045 |
+
if tensor.buffer is not None:
|
| 1046 |
+
buffer = model.buffers[tensor.buffer]
|
| 1047 |
+
if buffer.data is not None:
|
| 1048 |
+
hash_value = update_hash_with_primitive_value(
|
| 1049 |
+
hash_value, len(buffer.data)
|
| 1050 |
+
)
|
| 1051 |
+
|
| 1052 |
+
if tensor.shape is not None:
|
| 1053 |
+
hash_value = update_hash_with_array(hash_value, tensor.shape)
|
| 1054 |
+
|
| 1055 |
+
if subgraph.inputs is not None:
|
| 1056 |
+
hash_value = update_hash_with_primitive_value(
|
| 1057 |
+
hash_value, len(subgraph.inputs)
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
if subgraph.outputs is not None:
|
| 1061 |
+
hash_value = update_hash_with_primitive_value(
|
| 1062 |
+
hash_value, len(subgraph.outputs)
|
| 1063 |
+
)
|
| 1064 |
+
|
| 1065 |
+
return hash_value
|
| 1066 |
+
|
| 1067 |
+
|
| 1068 |
+
def update_hash_with_primitive_value(hash_value, value):
|
| 1069 |
+
"""Update the hash value using a primitive value.
|
| 1070 |
+
|
| 1071 |
+
Args:
|
| 1072 |
+
hash_value (uint64): The current hash value.
|
| 1073 |
+
value: The primitive value to incorporate into the hash.
|
| 1074 |
+
|
| 1075 |
+
Returns:
|
| 1076 |
+
int: The updated hash value.
|
| 1077 |
+
"""
|
| 1078 |
+
hash_const = np.uint64(0x9E3779B97F4A7800)
|
| 1079 |
+
hash_value = np.uint64(hash_value)
|
| 1080 |
+
value = np.uint64(value)
|
| 1081 |
+
|
| 1082 |
+
# Convert to arrays before shifting.
|
| 1083 |
+
hash_value = np.array([hash_value])
|
| 1084 |
+
value = np.array([value])
|
| 1085 |
+
|
| 1086 |
+
# Shift the values, then take the value from the first index.
|
| 1087 |
+
hash_value = np.bitwise_xor(
|
| 1088 |
+
hash_value,
|
| 1089 |
+
(
|
| 1090 |
+
value
|
| 1091 |
+
+ hash_const
|
| 1092 |
+
+ np.left_shift(hash_value, 10)
|
| 1093 |
+
+ np.right_shift(hash_value, 4)
|
| 1094 |
+
),
|
| 1095 |
+
)[0]
|
| 1096 |
+
|
| 1097 |
+
return hash_value
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
def update_hash_with_array(hash_value, int_array):
|
| 1101 |
+
"""Update the hash value using a TFLite int array.
|
| 1102 |
+
|
| 1103 |
+
Args:
|
| 1104 |
+
hash_value (int): The current hash value.
|
| 1105 |
+
int_array: A TFLite int array to incorporate into the hash.
|
| 1106 |
+
|
| 1107 |
+
Returns:
|
| 1108 |
+
int: The updated hash value.
|
| 1109 |
+
"""
|
| 1110 |
+
if int_array is not None:
|
| 1111 |
+
for i in int_array:
|
| 1112 |
+
hash_value = update_hash_with_primitive_value(hash_value, i)
|
| 1113 |
+
return hash_value
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
def populate_conversion_metadata(model_object, metadata):
|
| 1117 |
+
"""Add or update conversion metadata to a tflite model.
|
| 1118 |
+
|
| 1119 |
+
Args:
|
| 1120 |
+
model_object: A tflite model in object form.
|
| 1121 |
+
metadata: The conversion metadata.
|
| 1122 |
+
|
| 1123 |
+
Returns:
|
| 1124 |
+
A tflite model object with embedded conversion metadata.
|
| 1125 |
+
"""
|
| 1126 |
+
try:
|
| 1127 |
+
metadata_builder = flatbuffers.Builder(0)
|
| 1128 |
+
metadata_builder.Finish(metadata.Pack(metadata_builder))
|
| 1129 |
+
buffer_field = schema_fb.BufferT()
|
| 1130 |
+
buffer_field.data = metadata_builder.Output()
|
| 1131 |
+
|
| 1132 |
+
if not model_object.metadata:
|
| 1133 |
+
model_object.metadata = []
|
| 1134 |
+
else:
|
| 1135 |
+
# Check if metadata has already been populated.
|
| 1136 |
+
for meta in model_object.metadata:
|
| 1137 |
+
if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
|
| 1138 |
+
model_object.buffers[meta.buffer] = buffer_field
|
| 1139 |
+
return model_object
|
| 1140 |
+
|
| 1141 |
+
if not model_object.buffers:
|
| 1142 |
+
model_object.buffers = []
|
| 1143 |
+
model_object.buffers.append(buffer_field)
|
| 1144 |
+
# Creates a new metadata field.
|
| 1145 |
+
metadata_field = schema_fb.MetadataT()
|
| 1146 |
+
metadata_field.name = CONVERSION_METADATA_FIELD_NAME
|
| 1147 |
+
metadata_field.buffer = len(model_object.buffers) - 1
|
| 1148 |
+
model_object.metadata.append(metadata_field)
|
| 1149 |
+
|
| 1150 |
+
return model_object
|
| 1151 |
+
except Exception: # pylint: disable=broad-except
|
| 1152 |
+
return model_object
|
| 1153 |
+
|
| 1154 |
+
|
| 1155 |
+
def get_conversion_metadata(model_buffer):
|
| 1156 |
+
"""Read conversion metadata from a tflite model.
|
| 1157 |
+
|
| 1158 |
+
Args:
|
| 1159 |
+
model_buffer: A tflite model.
|
| 1160 |
+
|
| 1161 |
+
Returns:
|
| 1162 |
+
The conversion metadata or None if it is not populated.
|
| 1163 |
+
"""
|
| 1164 |
+
model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer)
|
| 1165 |
+
if not model_object or not model_object.metadata:
|
| 1166 |
+
return None
|
| 1167 |
+
|
| 1168 |
+
for meta in model_object.metadata:
|
| 1169 |
+
if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
|
| 1170 |
+
metadata_buf = model_object.buffers[meta.buffer].data.tobytes()
|
| 1171 |
+
return conversion_metadata_fb.ConversionMetadataT.InitFromObj(
|
| 1172 |
+
conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata(
|
| 1173 |
+
metadata_buf, 0
|
| 1174 |
+
)
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
return None
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/flatbuffer_utils.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/visualize.cpython-310.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/flatbuffer_utils.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utility functions for FlatBuffers.
|
| 16 |
+
|
| 17 |
+
All functions that are commonly used to work with FlatBuffers.
|
| 18 |
+
|
| 19 |
+
Refer to the tensorflow lite flatbuffer schema here:
|
| 20 |
+
tensorflow/lite/schema/schema.fbs
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import copy
|
| 24 |
+
import random
|
| 25 |
+
import re
|
| 26 |
+
import struct
|
| 27 |
+
import sys
|
| 28 |
+
|
| 29 |
+
import flatbuffers
|
| 30 |
+
|
| 31 |
+
from tensorflow.lite.python import schema_py_generated as schema_fb
|
| 32 |
+
from tensorflow.lite.python import schema_util
|
| 33 |
+
from tensorflow.python.platform import gfile
|
| 34 |
+
|
| 35 |
+
_TFLITE_FILE_IDENTIFIER = b'TFL3'
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def convert_bytearray_to_object(model_bytearray):
|
| 39 |
+
"""Converts a tflite model from a bytearray to an object for parsing."""
|
| 40 |
+
model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
|
| 41 |
+
return schema_fb.ModelT.InitFromObj(model_object)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def read_model(input_tflite_file):
|
| 45 |
+
"""Reads a tflite model as a python object.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
input_tflite_file: Full path name to the input tflite file
|
| 49 |
+
|
| 50 |
+
Raises:
|
| 51 |
+
RuntimeError: If input_tflite_file path is invalid.
|
| 52 |
+
IOError: If input_tflite_file cannot be opened.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
A python object corresponding to the input tflite file.
|
| 56 |
+
"""
|
| 57 |
+
if not gfile.Exists(input_tflite_file):
|
| 58 |
+
raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
|
| 59 |
+
with gfile.GFile(input_tflite_file, 'rb') as input_file_handle:
|
| 60 |
+
model_bytearray = bytearray(input_file_handle.read())
|
| 61 |
+
return read_model_from_bytearray(model_bytearray)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def read_model_from_bytearray(model_bytearray):
|
| 65 |
+
"""Reads a tflite model as a python object.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
model_bytearray: TFLite model in bytearray format.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
A python object corresponding to the input tflite file.
|
| 72 |
+
"""
|
| 73 |
+
model = convert_bytearray_to_object(model_bytearray)
|
| 74 |
+
if sys.byteorder == 'big':
|
| 75 |
+
byte_swap_tflite_model_obj(model, 'little', 'big')
|
| 76 |
+
|
| 77 |
+
# Offset handling for models > 2GB
|
| 78 |
+
for buffer in model.buffers:
|
| 79 |
+
if buffer.offset:
|
| 80 |
+
buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size]
|
| 81 |
+
buffer.offset = 0
|
| 82 |
+
buffer.size = 0
|
| 83 |
+
for subgraph in model.subgraphs:
|
| 84 |
+
for op in subgraph.operators:
|
| 85 |
+
if op.largeCustomOptionsOffset:
|
| 86 |
+
op.customOptions = model_bytearray[
|
| 87 |
+
op.largeCustomOptionsOffset : op.largeCustomOptionsOffset
|
| 88 |
+
+ op.largeCustomOptionsSize
|
| 89 |
+
]
|
| 90 |
+
op.largeCustomOptionsOffset = 0
|
| 91 |
+
op.largeCustomOptionsSize = 0
|
| 92 |
+
|
| 93 |
+
return model
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def read_model_with_mutable_tensors(input_tflite_file):
|
| 97 |
+
"""Reads a tflite model as a python object with mutable tensors.
|
| 98 |
+
|
| 99 |
+
Similar to read_model() with the addition that the returned object has
|
| 100 |
+
mutable tensors (read_model() returns an object with immutable tensors).
|
| 101 |
+
|
| 102 |
+
NOTE: This API only works for TFLite generated with
|
| 103 |
+
_experimental_use_buffer_offset=false
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
input_tflite_file: Full path name to the input tflite file
|
| 107 |
+
|
| 108 |
+
Raises:
|
| 109 |
+
RuntimeError: If input_tflite_file path is invalid.
|
| 110 |
+
IOError: If input_tflite_file cannot be opened.
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
A mutable python object corresponding to the input tflite file.
|
| 114 |
+
"""
|
| 115 |
+
return copy.deepcopy(read_model(input_tflite_file))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def convert_object_to_bytearray(model_object, extra_buffer=b''):
|
| 119 |
+
"""Converts a tflite model from an object to a immutable bytearray."""
|
| 120 |
+
# Initial size of the buffer, which will grow automatically if needed
|
| 121 |
+
builder = flatbuffers.Builder(1024)
|
| 122 |
+
model_offset = model_object.Pack(builder)
|
| 123 |
+
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
|
| 124 |
+
model_bytearray = bytes(builder.Output())
|
| 125 |
+
model_bytearray = model_bytearray + extra_buffer
|
| 126 |
+
return model_bytearray
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def write_model(model_object, output_tflite_file):
|
| 130 |
+
"""Writes the tflite model, a python object, into the output file.
|
| 131 |
+
|
| 132 |
+
NOTE: This API only works for TFLite generated with
|
| 133 |
+
_experimental_use_buffer_offset=false
|
| 134 |
+
|
| 135 |
+
Args:
|
| 136 |
+
model_object: A tflite model as a python object
|
| 137 |
+
output_tflite_file: Full path name to the output tflite file.
|
| 138 |
+
|
| 139 |
+
Raises:
|
| 140 |
+
IOError: If output_tflite_file path is invalid or cannot be opened.
|
| 141 |
+
"""
|
| 142 |
+
if sys.byteorder == 'big':
|
| 143 |
+
model_object = copy.deepcopy(model_object)
|
| 144 |
+
byte_swap_tflite_model_obj(model_object, 'big', 'little')
|
| 145 |
+
model_bytearray = convert_object_to_bytearray(model_object)
|
| 146 |
+
with gfile.GFile(output_tflite_file, 'wb') as output_file_handle:
|
| 147 |
+
output_file_handle.write(model_bytearray)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def strip_strings(model):
|
| 151 |
+
"""Strips all nonessential strings from the model to reduce model size.
|
| 152 |
+
|
| 153 |
+
We remove the following strings:
|
| 154 |
+
(find strings by searching ":string" in the tensorflow lite flatbuffer schema)
|
| 155 |
+
1. Model description
|
| 156 |
+
2. SubGraph name
|
| 157 |
+
3. Tensor names
|
| 158 |
+
We retain OperatorCode custom_code and Metadata name.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
model: The model from which to remove nonessential strings.
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
model.description = None
|
| 165 |
+
for subgraph in model.subgraphs:
|
| 166 |
+
subgraph.name = None
|
| 167 |
+
for tensor in subgraph.tensors:
|
| 168 |
+
tensor.name = None
|
| 169 |
+
# We clear all signature_def structure, since without names it is useless.
|
| 170 |
+
model.signatureDefs = None
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def type_to_name(tensor_type):
|
| 174 |
+
"""Converts a numerical enum to a readable tensor type."""
|
| 175 |
+
for name, value in schema_fb.TensorType.__dict__.items():
|
| 176 |
+
if value == tensor_type:
|
| 177 |
+
return name
|
| 178 |
+
return None
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
|
| 182 |
+
"""Randomize weights in a model.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
model: The model in which to randomize weights.
|
| 186 |
+
random_seed: The input to the random number generator (default value is 0).
|
| 187 |
+
buffers_to_skip: The list of buffer indices to skip. The weights in these
|
| 188 |
+
buffers are left unmodified.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
# The input to the random seed generator. The default value is 0.
|
| 192 |
+
random.seed(random_seed)
|
| 193 |
+
|
| 194 |
+
# Parse model buffers which store the model weights
|
| 195 |
+
buffers = model.buffers
|
| 196 |
+
buffer_ids = range(1, len(buffers)) # ignore index 0 as it's always None
|
| 197 |
+
if buffers_to_skip is not None:
|
| 198 |
+
buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip]
|
| 199 |
+
|
| 200 |
+
buffer_types = {}
|
| 201 |
+
for graph in model.subgraphs:
|
| 202 |
+
for op in graph.operators:
|
| 203 |
+
if op.inputs is None:
|
| 204 |
+
break
|
| 205 |
+
for input_idx in op.inputs:
|
| 206 |
+
tensor = graph.tensors[input_idx]
|
| 207 |
+
buffer_types[tensor.buffer] = type_to_name(tensor.type)
|
| 208 |
+
|
| 209 |
+
for i in buffer_ids:
|
| 210 |
+
buffer_i_data = buffers[i].data
|
| 211 |
+
buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
|
| 212 |
+
if buffer_i_size == 0:
|
| 213 |
+
continue
|
| 214 |
+
|
| 215 |
+
# Raw data buffers are of type ubyte (or uint8) whose values lie in the
|
| 216 |
+
# range [0, 255]. Those ubytes (or unint8s) are the underlying
|
| 217 |
+
# representation of each datatype. For example, a bias tensor of type
|
| 218 |
+
# int32 appears as a buffer 4 times it's length of type ubyte (or uint8).
|
| 219 |
+
# For floats, we need to generate a valid float and then pack it into
|
| 220 |
+
# the raw bytes in place.
|
| 221 |
+
buffer_type = buffer_types.get(i, 'INT8')
|
| 222 |
+
if buffer_type.startswith('FLOAT'):
|
| 223 |
+
format_code = 'e' if buffer_type == 'FLOAT16' else 'f'
|
| 224 |
+
for offset in range(0, buffer_i_size, struct.calcsize(format_code)):
|
| 225 |
+
value = random.uniform(-0.5, 0.5) # See http://b/152324470#comment2
|
| 226 |
+
struct.pack_into(format_code, buffer_i_data, offset, value)
|
| 227 |
+
else:
|
| 228 |
+
for j in range(buffer_i_size):
|
| 229 |
+
buffer_i_data[j] = random.randint(0, 255)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def rename_custom_ops(model, map_custom_op_renames):
|
| 233 |
+
"""Rename custom ops so they use the same naming style as builtin ops.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
model: The input tflite model.
|
| 237 |
+
map_custom_op_renames: A mapping from old to new custom op names.
|
| 238 |
+
"""
|
| 239 |
+
for op_code in model.operatorCodes:
|
| 240 |
+
if op_code.customCode:
|
| 241 |
+
op_code_str = op_code.customCode.decode('ascii')
|
| 242 |
+
if op_code_str in map_custom_op_renames:
|
| 243 |
+
op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii')
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def opcode_to_name(model, op_code):
|
| 247 |
+
"""Converts a TFLite op_code to the human readable name.
|
| 248 |
+
|
| 249 |
+
Args:
|
| 250 |
+
model: The input tflite model.
|
| 251 |
+
op_code: The op_code to resolve to a readable name.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
A string containing the human readable op name, or None if not resolvable.
|
| 255 |
+
"""
|
| 256 |
+
op = model.operatorCodes[op_code]
|
| 257 |
+
code = max(op.builtinCode, op.deprecatedBuiltinCode)
|
| 258 |
+
for name, value in vars(schema_fb.BuiltinOperator).items():
|
| 259 |
+
if value == code:
|
| 260 |
+
return name
|
| 261 |
+
return None
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def xxd_output_to_bytes(input_cc_file):
|
| 265 |
+
"""Converts xxd output C++ source file to bytes (immutable).
|
| 266 |
+
|
| 267 |
+
Args:
|
| 268 |
+
input_cc_file: Full path name to th C++ source file dumped by xxd
|
| 269 |
+
|
| 270 |
+
Raises:
|
| 271 |
+
RuntimeError: If input_cc_file path is invalid.
|
| 272 |
+
IOError: If input_cc_file cannot be opened.
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
A bytearray corresponding to the input cc file array.
|
| 276 |
+
"""
|
| 277 |
+
# Match hex values in the string with comma as separator
|
| 278 |
+
pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')
|
| 279 |
+
|
| 280 |
+
model_bytearray = bytearray()
|
| 281 |
+
|
| 282 |
+
with open(input_cc_file) as file_handle:
|
| 283 |
+
for line in file_handle:
|
| 284 |
+
values_match = pattern.match(line)
|
| 285 |
+
|
| 286 |
+
if values_match is None:
|
| 287 |
+
continue
|
| 288 |
+
|
| 289 |
+
# Match in the parentheses (hex array only)
|
| 290 |
+
list_text = values_match.group(1)
|
| 291 |
+
|
| 292 |
+
# Extract hex values (text) from the line
|
| 293 |
+
# e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
|
| 294 |
+
values_text = filter(None, list_text.split(','))
|
| 295 |
+
|
| 296 |
+
# Convert to hex
|
| 297 |
+
values = [int(x, base=16) for x in values_text]
|
| 298 |
+
model_bytearray.extend(values)
|
| 299 |
+
|
| 300 |
+
return bytes(model_bytearray)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def xxd_output_to_object(input_cc_file):
|
| 304 |
+
"""Converts xxd output C++ source file to object.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
input_cc_file: Full path name to th C++ source file dumped by xxd
|
| 308 |
+
|
| 309 |
+
Raises:
|
| 310 |
+
RuntimeError: If input_cc_file path is invalid.
|
| 311 |
+
IOError: If input_cc_file cannot be opened.
|
| 312 |
+
|
| 313 |
+
Returns:
|
| 314 |
+
A python object corresponding to the input tflite file.
|
| 315 |
+
"""
|
| 316 |
+
model_bytes = xxd_output_to_bytes(input_cc_file)
|
| 317 |
+
return convert_bytearray_to_object(model_bytes)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
|
| 321 |
+
"""Helper function for byte-swapping the buffers field."""
|
| 322 |
+
to_swap = [
|
| 323 |
+
buffer.data[i : i + chunksize]
|
| 324 |
+
for i in range(0, len(buffer.data), chunksize)
|
| 325 |
+
]
|
| 326 |
+
buffer.data = b''.join([
|
| 327 |
+
int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness)
|
| 328 |
+
for byteswap in to_swap
|
| 329 |
+
])
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def byte_swap_string_content(buffer, from_endiness, to_endiness):
|
| 333 |
+
"""Helper function for byte-swapping the string buffer.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
buffer: TFLite string buffer of from_endiness format.
|
| 337 |
+
from_endiness: The original endianness format of the string buffer.
|
| 338 |
+
to_endiness: The destined endianness format of the string buffer.
|
| 339 |
+
"""
|
| 340 |
+
num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness)
|
| 341 |
+
string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :])
|
| 342 |
+
prefix_data = b''.join([
|
| 343 |
+
int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes(
|
| 344 |
+
4, to_endiness
|
| 345 |
+
)
|
| 346 |
+
for i in range(0, (num_of_strings + 1) * 4 + 1, 4)
|
| 347 |
+
])
|
| 348 |
+
buffer.data = prefix_data + string_content
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
|
| 352 |
+
"""Byte swaps the buffers field in a TFLite model.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
model: TFLite model object of from_endiness format.
|
| 356 |
+
from_endiness: The original endianness format of the buffers in model.
|
| 357 |
+
to_endiness: The destined endianness format of the buffers in model.
|
| 358 |
+
"""
|
| 359 |
+
if model is None:
|
| 360 |
+
return
|
| 361 |
+
# Get all the constant buffers, byte swapping them as per their data types
|
| 362 |
+
buffer_swapped = []
|
| 363 |
+
types_of_16_bits = [
|
| 364 |
+
schema_fb.TensorType.FLOAT16,
|
| 365 |
+
schema_fb.TensorType.INT16,
|
| 366 |
+
schema_fb.TensorType.UINT16,
|
| 367 |
+
]
|
| 368 |
+
types_of_32_bits = [
|
| 369 |
+
schema_fb.TensorType.FLOAT32,
|
| 370 |
+
schema_fb.TensorType.INT32,
|
| 371 |
+
schema_fb.TensorType.COMPLEX64,
|
| 372 |
+
schema_fb.TensorType.UINT32,
|
| 373 |
+
]
|
| 374 |
+
types_of_64_bits = [
|
| 375 |
+
schema_fb.TensorType.INT64,
|
| 376 |
+
schema_fb.TensorType.FLOAT64,
|
| 377 |
+
schema_fb.TensorType.COMPLEX128,
|
| 378 |
+
schema_fb.TensorType.UINT64,
|
| 379 |
+
]
|
| 380 |
+
for subgraph in model.subgraphs:
|
| 381 |
+
for tensor in subgraph.tensors:
|
| 382 |
+
if (
|
| 383 |
+
tensor.buffer > 0
|
| 384 |
+
and tensor.buffer < len(model.buffers)
|
| 385 |
+
and tensor.buffer not in buffer_swapped
|
| 386 |
+
and model.buffers[tensor.buffer].data is not None
|
| 387 |
+
):
|
| 388 |
+
if tensor.type == schema_fb.TensorType.STRING:
|
| 389 |
+
byte_swap_string_content(
|
| 390 |
+
model.buffers[tensor.buffer], from_endiness, to_endiness
|
| 391 |
+
)
|
| 392 |
+
elif tensor.type in types_of_16_bits:
|
| 393 |
+
byte_swap_buffer_content(
|
| 394 |
+
model.buffers[tensor.buffer], 2, from_endiness, to_endiness
|
| 395 |
+
)
|
| 396 |
+
elif tensor.type in types_of_32_bits:
|
| 397 |
+
byte_swap_buffer_content(
|
| 398 |
+
model.buffers[tensor.buffer], 4, from_endiness, to_endiness
|
| 399 |
+
)
|
| 400 |
+
elif tensor.type in types_of_64_bits:
|
| 401 |
+
byte_swap_buffer_content(
|
| 402 |
+
model.buffers[tensor.buffer], 8, from_endiness, to_endiness
|
| 403 |
+
)
|
| 404 |
+
else:
|
| 405 |
+
continue
|
| 406 |
+
buffer_swapped.append(tensor.buffer)
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
|
| 410 |
+
"""Generates a new model byte array after byte swapping its buffers field.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
tflite_model: TFLite flatbuffer in a byte array.
|
| 414 |
+
from_endiness: The original endianness format of the buffers in
|
| 415 |
+
tflite_model.
|
| 416 |
+
to_endiness: The destined endianness format of the buffers in tflite_model.
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
|
| 420 |
+
format.
|
| 421 |
+
"""
|
| 422 |
+
if tflite_model is None:
|
| 423 |
+
return None
|
| 424 |
+
# Load TFLite Flatbuffer byte array into an object.
|
| 425 |
+
model = convert_bytearray_to_object(tflite_model)
|
| 426 |
+
|
| 427 |
+
# Byte swapping the constant buffers as per their data types
|
| 428 |
+
byte_swap_tflite_model_obj(model, from_endiness, to_endiness)
|
| 429 |
+
|
| 430 |
+
# Return a TFLite flatbuffer as a byte array.
|
| 431 |
+
return convert_object_to_bytearray(model)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def count_resource_variables(model):
|
| 435 |
+
"""Calculates the number of unique resource variables in a model.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
model: the input tflite model, either as bytearray or object.
|
| 439 |
+
|
| 440 |
+
Returns:
|
| 441 |
+
An integer number representing the number of unique resource variables.
|
| 442 |
+
"""
|
| 443 |
+
if not isinstance(model, schema_fb.ModelT):
|
| 444 |
+
model = convert_bytearray_to_object(model)
|
| 445 |
+
unique_shared_names = set()
|
| 446 |
+
for subgraph in model.subgraphs:
|
| 447 |
+
if subgraph.operators is None:
|
| 448 |
+
continue
|
| 449 |
+
for op in subgraph.operators:
|
| 450 |
+
builtin_code = schema_util.get_builtin_code_from_operator_code(
|
| 451 |
+
model.operatorCodes[op.opcodeIndex]
|
| 452 |
+
)
|
| 453 |
+
if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE:
|
| 454 |
+
unique_shared_names.add(op.builtinOptions.sharedName)
|
| 455 |
+
return len(unique_shared_names)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (206 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (216 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (223 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/debugger.cpython-310.pyc
ADDED
|
Binary file (18.4 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/debugger.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Python TF-Lite QuantizationDebugger."""
|
| 16 |
+
import collections
|
| 17 |
+
import csv
|
| 18 |
+
import re
|
| 19 |
+
from typing import (Any, Callable, Dict, IO, Iterable, List, Mapping, Optional,
|
| 20 |
+
Sequence, Tuple)
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
from tensorflow.lite.python import convert
|
| 25 |
+
from tensorflow.lite.python import interpreter as _interpreter
|
| 26 |
+
from tensorflow.lite.python.metrics import metrics as metrics_stub # type: ignore
|
| 27 |
+
from tensorflow.python.util import tf_export
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# TODO(b/198099651): move converter implementation out of lite.py
|
| 31 |
+
TFLiteConverter = Any # importing tf.lite creates circular dependency
|
| 32 |
+
|
| 33 |
+
# Returns metrics based on difference of values for quantized/float ops.
|
| 34 |
+
_DEFAULT_LAYER_DEBUG_METRICS = {
|
| 35 |
+
'num_elements': lambda diffs: diffs.size,
|
| 36 |
+
'stddev': np.std,
|
| 37 |
+
'mean_error': np.average,
|
| 38 |
+
'max_abs_error': lambda diffs: np.max(np.abs(diffs)),
|
| 39 |
+
'mean_squared_error': lambda diffs: np.average(diffs**2),
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
_NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _get_quant_params(
|
| 46 |
+
tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
|
| 47 |
+
"""Returns first scale and zero point from tensor detail, if present."""
|
| 48 |
+
quant_params = tensor_detail['quantization_parameters']
|
| 49 |
+
if not quant_params:
|
| 50 |
+
return None
|
| 51 |
+
if quant_params['scales'] and quant_params['zero_points']:
|
| 52 |
+
return (quant_params['scales'][0], quant_params['zero_points'][0])
|
| 53 |
+
return None
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@tf_export.tf_export('lite.experimental.QuantizationDebugOptions')
|
| 57 |
+
class QuantizationDebugOptions:
|
| 58 |
+
"""Debug options to set up a given QuantizationDebugger."""
|
| 59 |
+
|
| 60 |
+
def __init__(self,
|
| 61 |
+
layer_debug_metrics: Optional[Mapping[str,
|
| 62 |
+
Callable[[np.ndarray],
|
| 63 |
+
float]]] = None,
|
| 64 |
+
model_debug_metrics: Optional[Mapping[
|
| 65 |
+
str, Callable[[Sequence[np.ndarray], Sequence[np.ndarray]],
|
| 66 |
+
float]]] = None,
|
| 67 |
+
layer_direct_compare_metrics: Optional[Mapping[str, Callable[
|
| 68 |
+
[Sequence[np.ndarray], Sequence[np.ndarray], float, int],
|
| 69 |
+
float]]] = None,
|
| 70 |
+
denylisted_ops: Optional[List[str]] = None,
|
| 71 |
+
denylisted_nodes: Optional[List[str]] = None,
|
| 72 |
+
fully_quantize: bool = False) -> None:
|
| 73 |
+
"""Initializes debugger options.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
layer_debug_metrics: a dict to specify layer debug functions
|
| 77 |
+
{function_name_str: function} where the function accepts result of
|
| 78 |
+
NumericVerify Op, which is value difference between float and
|
| 79 |
+
dequantized op results. The function returns single scalar value.
|
| 80 |
+
model_debug_metrics: a dict to specify model debug functions
|
| 81 |
+
{function_name_str: function} where the function accepts outputs from
|
| 82 |
+
two models, and returns single scalar value for a metric. (e.g.
|
| 83 |
+
accuracy, IoU)
|
| 84 |
+
layer_direct_compare_metrics: a dict to specify layer debug functions
|
| 85 |
+
{function_name_str: function}. The signature is different from that of
|
| 86 |
+
`layer_debug_metrics`, and this one gets passed (original float value,
|
| 87 |
+
original quantized value, scale, zero point). The function's
|
| 88 |
+
implementation is responsible for correctly dequantize the quantized
|
| 89 |
+
value to compare. Use this one when comparing diff is not enough.
|
| 90 |
+
(Note) quantized value is passed as int8, so cast to int32 is needed.
|
| 91 |
+
denylisted_ops: a list of op names which is expected to be removed from
|
| 92 |
+
quantization.
|
| 93 |
+
denylisted_nodes: a list of op's output tensor names to be removed from
|
| 94 |
+
quantization.
|
| 95 |
+
fully_quantize: Bool indicating whether to fully quantize the model.
|
| 96 |
+
Besides model body, the input/output will be quantized as well.
|
| 97 |
+
Corresponding to mlir_quantize's fully_quantize parameter.
|
| 98 |
+
|
| 99 |
+
Raises:
|
| 100 |
+
ValueError: when there are duplicate keys
|
| 101 |
+
"""
|
| 102 |
+
self.layer_debug_metrics = layer_debug_metrics
|
| 103 |
+
self.model_debug_metrics = model_debug_metrics
|
| 104 |
+
self.layer_direct_compare_metrics = layer_direct_compare_metrics
|
| 105 |
+
|
| 106 |
+
keys = []
|
| 107 |
+
for metrics in [
|
| 108 |
+
layer_debug_metrics, model_debug_metrics, layer_direct_compare_metrics
|
| 109 |
+
]:
|
| 110 |
+
if metrics is not None:
|
| 111 |
+
keys.extend(metrics.keys())
|
| 112 |
+
if len(keys) != len(set(keys)):
|
| 113 |
+
raise ValueError('Provided metrics have duplicate keys.')
|
| 114 |
+
|
| 115 |
+
self.denylisted_ops = denylisted_ops
|
| 116 |
+
self.denylisted_nodes = denylisted_nodes
|
| 117 |
+
self.fully_quantize = fully_quantize
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@tf_export.tf_export('lite.experimental.QuantizationDebugger')
|
| 121 |
+
class QuantizationDebugger:
|
| 122 |
+
"""Debugger for Quantized TensorFlow Lite debug mode models.
|
| 123 |
+
|
| 124 |
+
This can run the TensorFlow Lite converted models equipped with debug ops and
|
| 125 |
+
collect debug information. This debugger calculates statistics from
|
| 126 |
+
user-defined post-processing functions as well as default ones.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self,
|
| 130 |
+
quant_debug_model_path: Optional[str] = None,
|
| 131 |
+
quant_debug_model_content: Optional[bytes] = None,
|
| 132 |
+
float_model_path: Optional[str] = None,
|
| 133 |
+
float_model_content: Optional[bytes] = None,
|
| 134 |
+
debug_dataset: Optional[Callable[
|
| 135 |
+
[], Iterable[Sequence[np.ndarray]]]] = None,
|
| 136 |
+
debug_options: Optional[QuantizationDebugOptions] = None,
|
| 137 |
+
converter: Optional[TFLiteConverter] = None) -> None:
|
| 138 |
+
"""Runs the TFLite debugging model with given debug options.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
quant_debug_model_path: Path to the quantized debug TFLite model file.
|
| 142 |
+
quant_debug_model_content: Content of the quantized debug TFLite model.
|
| 143 |
+
float_model_path: Path to float TFLite model file.
|
| 144 |
+
float_model_content: Content of the float TFLite model.
|
| 145 |
+
debug_dataset: a factory function that returns dataset generator which is
|
| 146 |
+
used to generate input samples (list of np.ndarray) for the model. The
|
| 147 |
+
generated elements must have same types and shape as inputs to the
|
| 148 |
+
model.
|
| 149 |
+
debug_options: Debug options to debug the given model.
|
| 150 |
+
converter: Optional, use converter instead of quantized model.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
ValueError: If the debugger was unable to be created.
|
| 154 |
+
|
| 155 |
+
Attributes:
|
| 156 |
+
layer_statistics: results of error metrics for each NumericVerify op
|
| 157 |
+
results. in {layer_name: {metric_name: metric}} format.
|
| 158 |
+
model_statistics: results of error metrics for difference between float
|
| 159 |
+
and quantized models. in {metric_name: metric} format.
|
| 160 |
+
"""
|
| 161 |
+
self._data_gen = debug_dataset
|
| 162 |
+
self._debug_options = debug_options or QuantizationDebugOptions()
|
| 163 |
+
self.converter = None
|
| 164 |
+
self.calibrated_model = None
|
| 165 |
+
self.float_model = None
|
| 166 |
+
self._float_interpreter = None
|
| 167 |
+
if converter is not None:
|
| 168 |
+
if self._debug_options.model_debug_metrics:
|
| 169 |
+
old_optimizations = converter.optimizations
|
| 170 |
+
self.converter = self._set_converter_options_for_float(converter)
|
| 171 |
+
self.float_model = self.converter.convert()
|
| 172 |
+
converter.optimizations = old_optimizations
|
| 173 |
+
|
| 174 |
+
self.converter = self._set_converter_options_for_calibration(converter)
|
| 175 |
+
self.calibrated_model = self.converter.convert()
|
| 176 |
+
# Converter should be already set up with all options
|
| 177 |
+
self._init_from_converter(
|
| 178 |
+
self._debug_options,
|
| 179 |
+
self.converter,
|
| 180 |
+
self.calibrated_model,
|
| 181 |
+
float_model=self.float_model)
|
| 182 |
+
else:
|
| 183 |
+
self._quant_interpreter = _interpreter.Interpreter(
|
| 184 |
+
quant_debug_model_path,
|
| 185 |
+
quant_debug_model_content,
|
| 186 |
+
experimental_preserve_all_tensors=(
|
| 187 |
+
self._debug_options.layer_direct_compare_metrics is not None))
|
| 188 |
+
if self._debug_options.model_debug_metrics:
|
| 189 |
+
self._float_interpreter = _interpreter.Interpreter(
|
| 190 |
+
float_model_path, float_model_content)
|
| 191 |
+
self._initialize_stats()
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def options(self) -> QuantizationDebugOptions:
|
| 195 |
+
return self._debug_options
|
| 196 |
+
|
| 197 |
+
@options.setter
|
| 198 |
+
def options(self, options: QuantizationDebugOptions) -> None:
|
| 199 |
+
self._debug_options = options
|
| 200 |
+
if not self.converter or not self.calibrated_model:
|
| 201 |
+
return
|
| 202 |
+
self._init_from_converter(
|
| 203 |
+
self._debug_options,
|
| 204 |
+
self.converter,
|
| 205 |
+
self.calibrated_model,
|
| 206 |
+
float_model=self.float_model)
|
| 207 |
+
self._initialize_stats()
|
| 208 |
+
|
| 209 |
+
def _initialize_stats(self):
|
| 210 |
+
"""Helper function initializes stats."""
|
| 211 |
+
# TODO(b/177749613) : Fix the dependency on tf.lite._get_ops_details()
|
| 212 |
+
# Following code is needed to get op's name from the output tensor index,
|
| 213 |
+
# since NumericVerify op only provides its quantized input tensor index.
|
| 214 |
+
self._defining_op = dict()
|
| 215 |
+
for op_info in self._quant_interpreter._get_ops_details(): # pylint: disable=protected-access
|
| 216 |
+
self._defining_op.update(
|
| 217 |
+
{tensor_idx: op_info['index'] for tensor_idx in op_info['outputs']})
|
| 218 |
+
|
| 219 |
+
self._numeric_verify_tensor_details = None
|
| 220 |
+
self._numeric_verify_op_details = None
|
| 221 |
+
if not self._get_numeric_verify_tensor_details():
|
| 222 |
+
raise ValueError('Please check if the quantized model is in debug mode')
|
| 223 |
+
|
| 224 |
+
self._layer_debug_metrics = _DEFAULT_LAYER_DEBUG_METRICS.copy()
|
| 225 |
+
if self._debug_options.layer_debug_metrics:
|
| 226 |
+
self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)
|
| 227 |
+
|
| 228 |
+
self.layer_statistics = None
|
| 229 |
+
self.model_statistics = None
|
| 230 |
+
|
| 231 |
+
self._metrics = metrics_stub.TFLiteMetrics()
|
| 232 |
+
self._metrics.increase_counter_debugger_creation()
|
| 233 |
+
|
| 234 |
+
def _get_quantized_model(self, is_debug: bool) -> bytes:
|
| 235 |
+
if not self.converter:
|
| 236 |
+
raise ValueError('No converter found, use this function with the '
|
| 237 |
+
'converter option in the constructor.')
|
| 238 |
+
|
| 239 |
+
return convert.mlir_quantize(
|
| 240 |
+
self.calibrated_model,
|
| 241 |
+
disable_per_channel=self.converter._experimental_disable_per_channel, # pylint: disable=protected-access
|
| 242 |
+
fully_quantize=self._debug_options.fully_quantize,
|
| 243 |
+
enable_numeric_verify=is_debug,
|
| 244 |
+
denylisted_ops=self._debug_options.denylisted_ops,
|
| 245 |
+
denylisted_nodes=self._debug_options.denylisted_nodes)
|
| 246 |
+
|
| 247 |
+
def get_nondebug_quantized_model(self) -> bytes:
|
| 248 |
+
"""Returns a non-instrumented quantized model.
|
| 249 |
+
|
| 250 |
+
Convert the quantized model with the initialized converter and
|
| 251 |
+
return bytes for nondebug model. The model will not be instrumented with
|
| 252 |
+
numeric verification operations.
|
| 253 |
+
|
| 254 |
+
Returns:
|
| 255 |
+
Model bytes corresponding to the model.
|
| 256 |
+
Raises:
|
| 257 |
+
ValueError: if converter is not passed to the debugger.
|
| 258 |
+
"""
|
| 259 |
+
return self._get_quantized_model(is_debug=False)
|
| 260 |
+
|
| 261 |
+
def get_debug_quantized_model(self) -> bytes:
|
| 262 |
+
"""Returns an instrumented quantized model.
|
| 263 |
+
|
| 264 |
+
Convert the quantized model with the initialized converter and
|
| 265 |
+
return bytes for model. The model will be instrumented with numeric
|
| 266 |
+
verification operations and should only be used for debugging.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Model bytes corresponding to the model.
|
| 270 |
+
Raises:
|
| 271 |
+
ValueError: if converter is not passed to the debugger.
|
| 272 |
+
"""
|
| 273 |
+
return self._get_quantized_model(is_debug=True)
|
| 274 |
+
|
| 275 |
+
def _init_from_converter(self,
|
| 276 |
+
options: QuantizationDebugOptions,
|
| 277 |
+
converter: TFLiteConverter,
|
| 278 |
+
calibrated_model: Optional[bytes] = None,
|
| 279 |
+
float_model: Optional[bytes] = None) -> None:
|
| 280 |
+
"""Convert the model and apply options.
|
| 281 |
+
|
| 282 |
+
Converts the quantized model and initializes a quantized model interpreter
|
| 283 |
+
with the quantized model. Returns a float model interpreter if float model
|
| 284 |
+
is provided.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
options: a QuantizationDebugOptions object.
|
| 288 |
+
converter: an initialized tf.lite.TFLiteConverter.
|
| 289 |
+
calibrated_model: Calibrated model bytes.
|
| 290 |
+
float_model: Float model bytes.
|
| 291 |
+
"""
|
| 292 |
+
self.quant_model = convert.mlir_quantize(
|
| 293 |
+
calibrated_model,
|
| 294 |
+
disable_per_channel=converter._experimental_disable_per_channel, # pylint: disable=protected-access
|
| 295 |
+
fully_quantize=options.fully_quantize,
|
| 296 |
+
enable_numeric_verify=True,
|
| 297 |
+
denylisted_ops=options.denylisted_ops,
|
| 298 |
+
denylisted_nodes=options.denylisted_nodes)
|
| 299 |
+
self._quant_interpreter = _interpreter.Interpreter(
|
| 300 |
+
model_content=self.quant_model)
|
| 301 |
+
self._float_interpreter = None
|
| 302 |
+
if float_model is not None:
|
| 303 |
+
self._float_interpreter = _interpreter.Interpreter(
|
| 304 |
+
model_content=float_model)
|
| 305 |
+
|
| 306 |
+
def _set_converter_options_for_float(
|
| 307 |
+
self, converter: TFLiteConverter) -> TFLiteConverter:
|
| 308 |
+
"""Verify converter options and set required experimental options."""
|
| 309 |
+
if converter.optimizations:
|
| 310 |
+
converter.optimizations = []
|
| 311 |
+
return converter
|
| 312 |
+
|
| 313 |
+
def _set_converter_options_for_calibration(
|
| 314 |
+
self, converter: TFLiteConverter) -> TFLiteConverter:
|
| 315 |
+
"""Verify converter options and set required experimental options."""
|
| 316 |
+
if not converter.optimizations:
|
| 317 |
+
raise ValueError(
|
| 318 |
+
'converter object must set optimizations to lite.Optimize.DEFAULT')
|
| 319 |
+
if not converter.representative_dataset:
|
| 320 |
+
raise ValueError('converter object must set representative_dataset')
|
| 321 |
+
|
| 322 |
+
converter.experimental_mlir_quantizer = True
|
| 323 |
+
converter._experimental_calibrate_only = True # pylint: disable=protected-access
|
| 324 |
+
return converter
|
| 325 |
+
|
| 326 |
+
def run(self) -> None:
|
| 327 |
+
"""Runs models and gets metrics."""
|
| 328 |
+
self.layer_statistics = self._collect_layer_statistics()
|
| 329 |
+
if self._debug_options.model_debug_metrics:
|
| 330 |
+
self.model_statistics = self._collect_model_statistics()
|
| 331 |
+
|
| 332 |
+
def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
|
| 333 |
+
"""Collects layer statistics by applying layer debug metrics.
|
| 334 |
+
|
| 335 |
+
For all data from the given RepresentativeDataset, collect statistics per
|
| 336 |
+
example by getting the NumericVerify op results in _quant_interpreter
|
| 337 |
+
and calculating layer debug metrics on the results.
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
aggregated per-layer statistics of NumericVerify results.
|
| 341 |
+
{layer_name: {metric_name: metric}}
|
| 342 |
+
"""
|
| 343 |
+
layer_statistics = collections.defaultdict(
|
| 344 |
+
lambda: collections.defaultdict(list))
|
| 345 |
+
|
| 346 |
+
initialize = True
|
| 347 |
+
for tensor_data in self._data_gen():
|
| 348 |
+
self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
|
| 349 |
+
initialize = False
|
| 350 |
+
|
| 351 |
+
# Run the model.
|
| 352 |
+
self._quant_interpreter.invoke()
|
| 353 |
+
|
| 354 |
+
# Collect the statistics of this invoke result.
|
| 355 |
+
for tensor_detail in self._get_numeric_verify_tensor_details():
|
| 356 |
+
tensor_name = tensor_detail['name'] # pytype: disable=unsupported-operands # dynamic-method-lookup
|
| 357 |
+
diffs = self._quant_interpreter.get_tensor(tensor_detail['index']) # pytype: disable=unsupported-operands # dynamic-method-lookup
|
| 358 |
+
for metric_name, metric_fn in self._layer_debug_metrics.items():
|
| 359 |
+
layer_statistics[tensor_name][metric_name].append(metric_fn(diffs))
|
| 360 |
+
|
| 361 |
+
if self._debug_options.layer_direct_compare_metrics is not None:
|
| 362 |
+
for tensor_detail in self._get_numeric_verify_tensor_details():
|
| 363 |
+
tensor_name = tensor_detail['name'] # pytype: disable=unsupported-operands # dynamic-method-lookup
|
| 364 |
+
op_idx = self._defining_op[tensor_detail['index']] # pytype: disable=unsupported-operands # dynamic-method-lookup
|
| 365 |
+
op_detail = self._quant_interpreter._get_op_details(op_idx) # pylint: disable=protected-access
|
| 366 |
+
q_idx, f_idx = op_detail['inputs']
|
| 367 |
+
quant_input_detail = self._quant_interpreter._get_tensor_details( # pylint: disable=protected-access
|
| 368 |
+
q_idx, subgraph_index=0)
|
| 369 |
+
for (metric_name, metric_fn
|
| 370 |
+
) in self._debug_options.layer_direct_compare_metrics.items():
|
| 371 |
+
layer_statistics[tensor_name][metric_name].append(
|
| 372 |
+
metric_fn(
|
| 373 |
+
self._quant_interpreter.get_tensor(f_idx),
|
| 374 |
+
self._quant_interpreter.get_tensor(q_idx),
|
| 375 |
+
quant_input_detail['quantization_parameters']['scales'][0],
|
| 376 |
+
quant_input_detail['quantization_parameters']['zero_points']
|
| 377 |
+
[0]))
|
| 378 |
+
|
| 379 |
+
# Calculate final aggregated metrics for each layer.
|
| 380 |
+
for metrics in layer_statistics.values():
|
| 381 |
+
for metric_name in metrics:
|
| 382 |
+
metrics[metric_name] = np.nanmean(metrics[metric_name])
|
| 383 |
+
|
| 384 |
+
return layer_statistics
|
| 385 |
+
|
| 386 |
+
def _collect_model_statistics(self) -> Dict[str, float]:
|
| 387 |
+
"""Collects model output metrics.
|
| 388 |
+
|
| 389 |
+
For all data from the given RepresentativeDataset, collect all model output
|
| 390 |
+
results from float model & quantized debug model, and calculate metrics
|
| 391 |
+
by using model output functions. As a result, self.model_results is filled,
|
| 392 |
+
|
| 393 |
+
where self.model_results[model_output_function_name] = `aggregated model
|
| 394 |
+
output function value` (a scalar).
|
| 395 |
+
|
| 396 |
+
Returns:
|
| 397 |
+
aggregated per-model output discrepancy metrics.
|
| 398 |
+
{metric_name: aggregated_metric}
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
model_statistics = collections.defaultdict(list)
|
| 402 |
+
|
| 403 |
+
initialize = True
|
| 404 |
+
for tensor_data in self._data_gen():
|
| 405 |
+
# Run quantized debug model and collect output results.
|
| 406 |
+
self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
|
| 407 |
+
self._quant_interpreter.invoke()
|
| 408 |
+
quant_tensor_data = self._get_output_tensors(self._quant_interpreter)
|
| 409 |
+
|
| 410 |
+
# Run float model if it's initialized.
|
| 411 |
+
float_tensor_data = []
|
| 412 |
+
if self._float_interpreter:
|
| 413 |
+
self._set_input_tensors(
|
| 414 |
+
self._float_interpreter, tensor_data, initialize)
|
| 415 |
+
self._float_interpreter.invoke()
|
| 416 |
+
float_tensor_data = self._get_output_tensors(self._float_interpreter)
|
| 417 |
+
|
| 418 |
+
initialize = False
|
| 419 |
+
|
| 420 |
+
# Calculate the metrics.
|
| 421 |
+
for (metric_name,
|
| 422 |
+
metric_fn) in self._debug_options.model_debug_metrics.items():
|
| 423 |
+
model_statistics[metric_name].append(
|
| 424 |
+
metric_fn(float_tensor_data, quant_tensor_data))
|
| 425 |
+
|
| 426 |
+
# Calculate final aggregated metrics for each outputs.
|
| 427 |
+
return {
|
| 428 |
+
metric_name: np.mean(metric)
|
| 429 |
+
for metric_name, metric in model_statistics.items()
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
def _set_input_tensors(self, interpreter: _interpreter.Interpreter,
|
| 433 |
+
tensor_data: Sequence[np.ndarray],
|
| 434 |
+
initialize: bool) -> None:
|
| 435 |
+
"""Sets input tensors into TFLite model Interpreter.
|
| 436 |
+
|
| 437 |
+
Args:
|
| 438 |
+
interpreter: a tf.lite.Interpreter object with allocated tensors.
|
| 439 |
+
tensor_data: a list of Numpy array data.
|
| 440 |
+
initialize: set to true when input is first set for the interpreter, to
|
| 441 |
+
set input shapes and allocate tensors.
|
| 442 |
+
|
| 443 |
+
Raises:
|
| 444 |
+
ValueError: when inputs can't be set, or size of provided inputs does not
|
| 445 |
+
match size of model inputs.
|
| 446 |
+
"""
|
| 447 |
+
input_details = interpreter.get_input_details()
|
| 448 |
+
if len(input_details) != len(tensor_data):
|
| 449 |
+
raise ValueError(
|
| 450 |
+
'Number of inputs provided ({}) does not match number of inputs to '
|
| 451 |
+
'the model ({})'.format(len(tensor_data), len(input_details)))
|
| 452 |
+
|
| 453 |
+
if initialize:
|
| 454 |
+
for input_detail, tensor in zip(input_details, tensor_data):
|
| 455 |
+
interpreter.resize_tensor_input(input_detail['index'], tensor.shape)
|
| 456 |
+
interpreter.allocate_tensors()
|
| 457 |
+
|
| 458 |
+
for input_detail, tensor in zip(input_details, tensor_data):
|
| 459 |
+
if tensor.dtype == np.float32 and input_detail['dtype'] == np.int8:
|
| 460 |
+
quant_params = _get_quant_params(input_detail)
|
| 461 |
+
if quant_params:
|
| 462 |
+
scale, zero_point = quant_params
|
| 463 |
+
tensor = np.round((tensor / scale) + zero_point).astype(np.int8)
|
| 464 |
+
interpreter.set_tensor(input_detail['index'], tensor)
|
| 465 |
+
|
| 466 |
+
def _get_output_tensors(
|
| 467 |
+
self, interpreter: _interpreter.Interpreter) -> List[np.ndarray]:
|
| 468 |
+
"""Returns output tensors of given TFLite model Interpreter.
|
| 469 |
+
|
| 470 |
+
Args:
|
| 471 |
+
interpreter: a tf.lite.Interpreter object with allocated tensors.
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
a list of numpy arrays representing output tensor results.
|
| 475 |
+
"""
|
| 476 |
+
|
| 477 |
+
outputs = []
|
| 478 |
+
for output_detail in interpreter.get_output_details():
|
| 479 |
+
tensor = interpreter.get_tensor(output_detail['index'])
|
| 480 |
+
if output_detail['dtype'] == np.int8:
|
| 481 |
+
quant_params = _get_quant_params(output_detail)
|
| 482 |
+
if quant_params:
|
| 483 |
+
scale, zero_point = quant_params
|
| 484 |
+
tensor = ((tensor.astype(np.float32) - zero_point) * scale).astype(
|
| 485 |
+
np.float32)
|
| 486 |
+
outputs.append(tensor)
|
| 487 |
+
|
| 488 |
+
return outputs
|
| 489 |
+
|
| 490 |
+
def _get_numeric_verify_tensor_details(self) -> List[str]:
|
| 491 |
+
"""Returns all names of all tensors from NumericVerify op."""
|
| 492 |
+
# pylint: disable=protected-access
|
| 493 |
+
if not self._numeric_verify_tensor_details:
|
| 494 |
+
self._numeric_verify_tensor_details = []
|
| 495 |
+
self._numeric_verify_op_details = {}
|
| 496 |
+
for op_info in self._quant_interpreter._get_ops_details():
|
| 497 |
+
if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
|
| 498 |
+
self._numeric_verify_tensor_details.append(
|
| 499 |
+
self._quant_interpreter._get_tensor_details(
|
| 500 |
+
op_info['outputs'][0], subgraph_index=0))
|
| 501 |
+
tensor_name = self._numeric_verify_tensor_details[-1]['name']
|
| 502 |
+
self._numeric_verify_op_details[tensor_name] = op_info
|
| 503 |
+
# pylint: enable=protected-access
|
| 504 |
+
return self._numeric_verify_tensor_details
|
| 505 |
+
|
| 506 |
+
def _get_operand_name_and_index(self,
|
| 507 |
+
numeric_verify_name: str) -> Tuple[str, int]:
|
| 508 |
+
"""Gets the index and name of NumericVerify Op's quantized input tensor.
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
numeric_verify_name: name of the NumericVerify op's output tensor. It has
|
| 512 |
+
format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
|
| 516 |
+
"""
|
| 517 |
+
tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
|
| 518 |
+
float_tensor_name = tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:]
|
| 519 |
+
if re.match(r'\d', float_tensor_name[-1]):
|
| 520 |
+
float_tensor_name = float_tensor_name[:-1]
|
| 521 |
+
|
| 522 |
+
return (float_tensor_name, int(tensor_idx))
|
| 523 |
+
|
| 524 |
+
def layer_statistics_dump(self, file: IO[str]) -> None:
|
| 525 |
+
"""Dumps layer statistics into file, in csv format.
|
| 526 |
+
|
| 527 |
+
Args:
|
| 528 |
+
file: file, or file-like object to write.
|
| 529 |
+
"""
|
| 530 |
+
# order of `fields` is the order of fields in csv.
|
| 531 |
+
fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys())
|
| 532 |
+
if self._debug_options.layer_direct_compare_metrics is not None:
|
| 533 |
+
fields += list(self._debug_options.layer_direct_compare_metrics.keys())
|
| 534 |
+
fields += ['scale', 'zero_point', 'tensor_name']
|
| 535 |
+
writer = csv.DictWriter(file, fields)
|
| 536 |
+
writer.writeheader()
|
| 537 |
+
if self.layer_statistics:
|
| 538 |
+
for name, metrics in self.layer_statistics.items():
|
| 539 |
+
data = metrics.copy()
|
| 540 |
+
(data['tensor_name'], _) = self._get_operand_name_and_index(name)
|
| 541 |
+
data['tensor_idx'] = self._numeric_verify_op_details[name]['inputs'][0]
|
| 542 |
+
data['op_name'] = self._quant_interpreter._get_op_details( # pylint: disable=protected-access
|
| 543 |
+
self._defining_op[data['tensor_idx']])['op_name']
|
| 544 |
+
details = self._quant_interpreter._get_tensor_details( # pylint: disable=protected-access
|
| 545 |
+
data['tensor_idx'], subgraph_index=0)
|
| 546 |
+
data['scale'], data['zero_point'] = (
|
| 547 |
+
details['quantization_parameters']['scales'][0],
|
| 548 |
+
details['quantization_parameters']['zero_points'][0])
|
| 549 |
+
writer.writerow(data)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/visualize.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ==============================================================================
|
| 16 |
+
"""This tool creates an html visualization of a TensorFlow Lite graph.
|
| 17 |
+
|
| 18 |
+
Example usage:
|
| 19 |
+
|
| 20 |
+
python visualize.py foo.tflite foo.html
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
import sys
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
# pylint: disable=g-import-not-at-top
|
| 30 |
+
if not os.path.splitext(__file__)[0].endswith(
|
| 31 |
+
os.path.join("tflite_runtime", "visualize")):
|
| 32 |
+
# This file is part of tensorflow package.
|
| 33 |
+
from tensorflow.lite.python import schema_py_generated as schema_fb
|
| 34 |
+
else:
|
| 35 |
+
# This file is part of tflite_runtime package.
|
| 36 |
+
from tflite_runtime import schema_py_generated as schema_fb
|
| 37 |
+
|
| 38 |
+
# A CSS description for making the visualizer
|
| 39 |
+
_CSS = """
|
| 40 |
+
<html>
|
| 41 |
+
<head>
|
| 42 |
+
<style>
|
| 43 |
+
body {font-family: sans-serif; background-color: #fa0;}
|
| 44 |
+
table {background-color: #eca;}
|
| 45 |
+
th {background-color: black; color: white;}
|
| 46 |
+
h1 {
|
| 47 |
+
background-color: ffaa00;
|
| 48 |
+
padding:5px;
|
| 49 |
+
color: black;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
svg {
|
| 53 |
+
margin: 10px;
|
| 54 |
+
border: 2px;
|
| 55 |
+
border-style: solid;
|
| 56 |
+
border-color: black;
|
| 57 |
+
background: white;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
div {
|
| 61 |
+
border-radius: 5px;
|
| 62 |
+
background-color: #fec;
|
| 63 |
+
padding:5px;
|
| 64 |
+
margin:5px;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
.tooltip {color: blue;}
|
| 68 |
+
.tooltip .tooltipcontent {
|
| 69 |
+
visibility: hidden;
|
| 70 |
+
color: black;
|
| 71 |
+
background-color: yellow;
|
| 72 |
+
padding: 5px;
|
| 73 |
+
border-radius: 4px;
|
| 74 |
+
position: absolute;
|
| 75 |
+
z-index: 1;
|
| 76 |
+
}
|
| 77 |
+
.tooltip:hover .tooltipcontent {
|
| 78 |
+
visibility: visible;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
.edges line {
|
| 82 |
+
stroke: #333;
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
text {
|
| 86 |
+
font-weight: bold;
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
.nodes text {
|
| 90 |
+
color: black;
|
| 91 |
+
pointer-events: none;
|
| 92 |
+
font-family: sans-serif;
|
| 93 |
+
font-size: 11px;
|
| 94 |
+
}
|
| 95 |
+
</style>
|
| 96 |
+
|
| 97 |
+
<script src="https://d3js.org/d3.v4.min.js"></script>
|
| 98 |
+
|
| 99 |
+
</head>
|
| 100 |
+
<body>
|
| 101 |
+
"""
|
| 102 |
+
|
| 103 |
+
_D3_HTML_TEMPLATE = """
|
| 104 |
+
<script>
|
| 105 |
+
function buildGraph() {
|
| 106 |
+
// Build graph data
|
| 107 |
+
var graph = %s;
|
| 108 |
+
|
| 109 |
+
var svg = d3.select("#subgraph%d")
|
| 110 |
+
var width = svg.attr("width");
|
| 111 |
+
var height = svg.attr("height");
|
| 112 |
+
// Make the graph scrollable.
|
| 113 |
+
svg = svg.call(d3.zoom().on("zoom", function() {
|
| 114 |
+
svg.attr("transform", d3.event.transform);
|
| 115 |
+
})).append("g");
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
var color = d3.scaleOrdinal(d3.schemeDark2);
|
| 119 |
+
|
| 120 |
+
var simulation = d3.forceSimulation()
|
| 121 |
+
.force("link", d3.forceLink().id(function(d) {return d.id;}))
|
| 122 |
+
.force("charge", d3.forceManyBody())
|
| 123 |
+
.force("center", d3.forceCenter(0.5 * width, 0.5 * height));
|
| 124 |
+
|
| 125 |
+
var edge = svg.append("g").attr("class", "edges").selectAll("line")
|
| 126 |
+
.data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
|
| 127 |
+
|
| 128 |
+
// Make the node group
|
| 129 |
+
var node = svg.selectAll(".nodes")
|
| 130 |
+
.data(graph.nodes)
|
| 131 |
+
.enter().append("g")
|
| 132 |
+
.attr("x", function(d){return d.x})
|
| 133 |
+
.attr("y", function(d){return d.y})
|
| 134 |
+
.attr("transform", function(d) {
|
| 135 |
+
return "translate( " + d.x + ", " + d.y + ")"
|
| 136 |
+
})
|
| 137 |
+
.attr("class", "nodes")
|
| 138 |
+
.call(d3.drag()
|
| 139 |
+
.on("start", function(d) {
|
| 140 |
+
if(!d3.event.active) simulation.alphaTarget(1.0).restart();
|
| 141 |
+
d.fx = d.x;d.fy = d.y;
|
| 142 |
+
})
|
| 143 |
+
.on("drag", function(d) {
|
| 144 |
+
d.fx = d3.event.x; d.fy = d3.event.y;
|
| 145 |
+
})
|
| 146 |
+
.on("end", function(d) {
|
| 147 |
+
if (!d3.event.active) simulation.alphaTarget(0);
|
| 148 |
+
d.fx = d.fy = null;
|
| 149 |
+
}));
|
| 150 |
+
// Within the group, draw a box for the node position and text
|
| 151 |
+
// on the side.
|
| 152 |
+
|
| 153 |
+
var node_width = 150;
|
| 154 |
+
var node_height = 30;
|
| 155 |
+
|
| 156 |
+
node.append("rect")
|
| 157 |
+
.attr("r", "5px")
|
| 158 |
+
.attr("width", node_width)
|
| 159 |
+
.attr("height", node_height)
|
| 160 |
+
.attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
|
| 161 |
+
.attr("stroke", "#000000")
|
| 162 |
+
.attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
|
| 163 |
+
node.append("text")
|
| 164 |
+
.text(function(d) { return d.name; })
|
| 165 |
+
.attr("x", 5)
|
| 166 |
+
.attr("y", 20)
|
| 167 |
+
.attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
|
| 168 |
+
// Setup force parameters and update position callback
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
var node = svg.selectAll(".nodes")
|
| 172 |
+
.data(graph.nodes);
|
| 173 |
+
|
| 174 |
+
// Bind the links
|
| 175 |
+
var name_to_g = {}
|
| 176 |
+
node.each(function(data, index, nodes) {
|
| 177 |
+
console.log(data.id)
|
| 178 |
+
name_to_g[data.id] = this;
|
| 179 |
+
});
|
| 180 |
+
|
| 181 |
+
function proc(w, t) {
|
| 182 |
+
return parseInt(w.getAttribute(t));
|
| 183 |
+
}
|
| 184 |
+
edge.attr("d", function(d) {
|
| 185 |
+
function lerp(t, a, b) {
|
| 186 |
+
return (1.0-t) * a + t * b;
|
| 187 |
+
}
|
| 188 |
+
var x1 = proc(name_to_g[d.source],"x") + node_width /2;
|
| 189 |
+
var y1 = proc(name_to_g[d.source],"y") + node_height;
|
| 190 |
+
var x2 = proc(name_to_g[d.target],"x") + node_width /2;
|
| 191 |
+
var y2 = proc(name_to_g[d.target],"y");
|
| 192 |
+
var s = "M " + x1 + " " + y1
|
| 193 |
+
+ " C " + x1 + " " + lerp(.5, y1, y2)
|
| 194 |
+
+ " " + x2 + " " + lerp(.5, y1, y2)
|
| 195 |
+
+ " " + x2 + " " + y2
|
| 196 |
+
return s;
|
| 197 |
+
});
|
| 198 |
+
|
| 199 |
+
}
|
| 200 |
+
buildGraph()
|
| 201 |
+
</script>
|
| 202 |
+
"""
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def TensorTypeToName(tensor_type):
|
| 206 |
+
"""Converts a numerical enum to a readable tensor type."""
|
| 207 |
+
for name, value in schema_fb.TensorType.__dict__.items():
|
| 208 |
+
if value == tensor_type:
|
| 209 |
+
return name
|
| 210 |
+
return None
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def BuiltinCodeToName(code):
|
| 214 |
+
"""Converts a builtin op code enum to a readable name."""
|
| 215 |
+
for name, value in schema_fb.BuiltinOperator.__dict__.items():
|
| 216 |
+
if value == code:
|
| 217 |
+
return name
|
| 218 |
+
return None
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def NameListToString(name_list):
|
| 222 |
+
"""Converts a list of integers to the equivalent ASCII string."""
|
| 223 |
+
if isinstance(name_list, str):
|
| 224 |
+
return name_list
|
| 225 |
+
else:
|
| 226 |
+
result = ""
|
| 227 |
+
if name_list is not None:
|
| 228 |
+
for val in name_list:
|
| 229 |
+
result = result + chr(int(val))
|
| 230 |
+
return result
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class OpCodeMapper:
|
| 234 |
+
"""Maps an opcode index to an op name."""
|
| 235 |
+
|
| 236 |
+
def __init__(self, data):
|
| 237 |
+
self.code_to_name = {}
|
| 238 |
+
for idx, d in enumerate(data["operator_codes"]):
|
| 239 |
+
self.code_to_name[idx] = BuiltinCodeToName(d["builtin_code"])
|
| 240 |
+
if self.code_to_name[idx] == "CUSTOM":
|
| 241 |
+
self.code_to_name[idx] = NameListToString(d["custom_code"])
|
| 242 |
+
|
| 243 |
+
def __call__(self, x):
|
| 244 |
+
if x not in self.code_to_name:
|
| 245 |
+
s = "<UNKNOWN>"
|
| 246 |
+
else:
|
| 247 |
+
s = self.code_to_name[x]
|
| 248 |
+
return "%s (%d)" % (s, x)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class DataSizeMapper:
|
| 252 |
+
"""For buffers, report the number of bytes."""
|
| 253 |
+
|
| 254 |
+
def __call__(self, x):
|
| 255 |
+
if x is not None:
|
| 256 |
+
return "%d bytes" % len(x)
|
| 257 |
+
else:
|
| 258 |
+
return "--"
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class TensorMapper:
|
| 262 |
+
"""Maps a list of tensor indices to a tooltip hoverable indicator of more."""
|
| 263 |
+
|
| 264 |
+
def __init__(self, subgraph_data):
|
| 265 |
+
self.data = subgraph_data
|
| 266 |
+
|
| 267 |
+
def __call__(self, x):
|
| 268 |
+
html = ""
|
| 269 |
+
if x is None:
|
| 270 |
+
return html
|
| 271 |
+
|
| 272 |
+
html += "<span class='tooltip'><span class='tooltipcontent'>"
|
| 273 |
+
for i in x:
|
| 274 |
+
tensor = self.data["tensors"][i]
|
| 275 |
+
html += str(i) + " "
|
| 276 |
+
html += NameListToString(tensor["name"]) + " "
|
| 277 |
+
html += TensorTypeToName(tensor["type"]) + " "
|
| 278 |
+
html += (repr(tensor["shape"]) if "shape" in tensor else "[]")
|
| 279 |
+
html += (repr(tensor["shape_signature"])
|
| 280 |
+
if "shape_signature" in tensor else "[]") + "<br>"
|
| 281 |
+
html += "</span>"
|
| 282 |
+
html += repr(x)
|
| 283 |
+
html += "</span>"
|
| 284 |
+
return html
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def GenerateGraph(subgraph_idx, g, opcode_mapper):
|
| 288 |
+
"""Produces the HTML required to have a d3 visualization of the dag."""
|
| 289 |
+
|
| 290 |
+
def TensorName(idx):
|
| 291 |
+
return "t%d" % idx
|
| 292 |
+
|
| 293 |
+
def OpName(idx):
|
| 294 |
+
return "o%d" % idx
|
| 295 |
+
|
| 296 |
+
edges = []
|
| 297 |
+
nodes = []
|
| 298 |
+
first = {}
|
| 299 |
+
second = {}
|
| 300 |
+
pixel_mult = 200 # TODO(aselle): multiplier for initial placement
|
| 301 |
+
width_mult = 170 # TODO(aselle): multiplier for initial placement
|
| 302 |
+
for op_index, op in enumerate(g["operators"] or []):
|
| 303 |
+
if op["inputs"] is not None:
|
| 304 |
+
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
|
| 305 |
+
if tensor_index not in first:
|
| 306 |
+
first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
|
| 307 |
+
(tensor_input_position + 1) * width_mult)
|
| 308 |
+
edges.append({
|
| 309 |
+
"source": TensorName(tensor_index),
|
| 310 |
+
"target": OpName(op_index)
|
| 311 |
+
})
|
| 312 |
+
if op["outputs"] is not None:
|
| 313 |
+
for tensor_output_position, tensor_index in enumerate(op["outputs"]):
|
| 314 |
+
if tensor_index not in second:
|
| 315 |
+
second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
|
| 316 |
+
(tensor_output_position + 1) * width_mult)
|
| 317 |
+
edges.append({
|
| 318 |
+
"target": TensorName(tensor_index),
|
| 319 |
+
"source": OpName(op_index)
|
| 320 |
+
})
|
| 321 |
+
|
| 322 |
+
nodes.append({
|
| 323 |
+
"id": OpName(op_index),
|
| 324 |
+
"name": opcode_mapper(op["opcode_index"]),
|
| 325 |
+
"group": 2,
|
| 326 |
+
"x": pixel_mult,
|
| 327 |
+
"y": (op_index + 1) * pixel_mult
|
| 328 |
+
})
|
| 329 |
+
for tensor_index, tensor in enumerate(g["tensors"]):
|
| 330 |
+
initial_y = (
|
| 331 |
+
first[tensor_index] if tensor_index in first else
|
| 332 |
+
second[tensor_index] if tensor_index in second else (0, 0))
|
| 333 |
+
|
| 334 |
+
nodes.append({
|
| 335 |
+
"id": TensorName(tensor_index),
|
| 336 |
+
"name": "%r (%d)" % (getattr(tensor, "shape", []), tensor_index),
|
| 337 |
+
"group": 1,
|
| 338 |
+
"x": initial_y[1],
|
| 339 |
+
"y": initial_y[0]
|
| 340 |
+
})
|
| 341 |
+
graph_str = json.dumps({"nodes": nodes, "edges": edges})
|
| 342 |
+
|
| 343 |
+
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
|
| 344 |
+
return html
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def GenerateTableHtml(items, keys_to_print, display_index=True):
|
| 348 |
+
"""Given a list of object values and keys to print, make an HTML table.
|
| 349 |
+
|
| 350 |
+
Args:
|
| 351 |
+
items: Items to print an array of dicts.
|
| 352 |
+
keys_to_print: (key, display_fn). `key` is a key in the object. i.e.
|
| 353 |
+
items[0][key] should exist. display_fn is the mapping function on display.
|
| 354 |
+
i.e. the displayed html cell will have the string returned by
|
| 355 |
+
`mapping_fn(items[0][key])`.
|
| 356 |
+
display_index: add a column which is the index of each row in `items`.
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
An html table.
|
| 360 |
+
"""
|
| 361 |
+
html = ""
|
| 362 |
+
# Print the list of items
|
| 363 |
+
html += "<table><tr>\n"
|
| 364 |
+
html += "<tr>\n"
|
| 365 |
+
if display_index:
|
| 366 |
+
html += "<th>index</th>"
|
| 367 |
+
for h, mapper in keys_to_print:
|
| 368 |
+
html += "<th>%s</th>" % h
|
| 369 |
+
html += "</tr>\n"
|
| 370 |
+
for idx, tensor in enumerate(items):
|
| 371 |
+
html += "<tr>\n"
|
| 372 |
+
if display_index:
|
| 373 |
+
html += "<td>%d</td>" % idx
|
| 374 |
+
# print tensor.keys()
|
| 375 |
+
for h, mapper in keys_to_print:
|
| 376 |
+
val = tensor[h] if h in tensor else None
|
| 377 |
+
val = val if mapper is None else mapper(val)
|
| 378 |
+
html += "<td>%s</td>\n" % val
|
| 379 |
+
|
| 380 |
+
html += "</tr>\n"
|
| 381 |
+
html += "</table>\n"
|
| 382 |
+
return html
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def CamelCaseToSnakeCase(camel_case_input):
|
| 386 |
+
"""Converts an identifier in CamelCase to snake_case."""
|
| 387 |
+
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
|
| 388 |
+
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def FlatbufferToDict(fb, preserve_as_numpy):
|
| 392 |
+
"""Converts a hierarchy of FB objects into a nested dict.
|
| 393 |
+
|
| 394 |
+
We avoid transforming big parts of the flat buffer into python arrays. This
|
| 395 |
+
speeds conversion from ten minutes to a few seconds on big graphs.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
fb: a flat buffer structure. (i.e. ModelT)
|
| 399 |
+
preserve_as_numpy: true if all downstream np.arrays should be preserved.
|
| 400 |
+
false if all downstream np.array should become python arrays
|
| 401 |
+
Returns:
|
| 402 |
+
A dictionary representing the flatbuffer rather than a flatbuffer object.
|
| 403 |
+
"""
|
| 404 |
+
if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str):
|
| 405 |
+
return fb
|
| 406 |
+
elif hasattr(fb, "__dict__"):
|
| 407 |
+
result = {}
|
| 408 |
+
for attribute_name in dir(fb):
|
| 409 |
+
attribute = fb.__getattribute__(attribute_name)
|
| 410 |
+
if not callable(attribute) and attribute_name[0] != "_":
|
| 411 |
+
snake_name = CamelCaseToSnakeCase(attribute_name)
|
| 412 |
+
preserve = True if attribute_name == "buffers" else preserve_as_numpy
|
| 413 |
+
result[snake_name] = FlatbufferToDict(attribute, preserve)
|
| 414 |
+
return result
|
| 415 |
+
elif isinstance(fb, np.ndarray):
|
| 416 |
+
return fb if preserve_as_numpy else fb.tolist()
|
| 417 |
+
elif hasattr(fb, "__len__"):
|
| 418 |
+
return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
|
| 419 |
+
else:
|
| 420 |
+
return fb
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def CreateDictFromFlatbuffer(buffer_data):
|
| 424 |
+
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
|
| 425 |
+
model = schema_fb.ModelT.InitFromObj(model_obj)
|
| 426 |
+
return FlatbufferToDict(model, preserve_as_numpy=False)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def create_html(tflite_input, input_is_filepath=True): # pylint: disable=invalid-name
|
| 430 |
+
"""Returns html description with the given tflite model.
|
| 431 |
+
|
| 432 |
+
Args:
|
| 433 |
+
tflite_input: TFLite flatbuffer model path or model object.
|
| 434 |
+
input_is_filepath: Tells if tflite_input is a model path or a model object.
|
| 435 |
+
|
| 436 |
+
Returns:
|
| 437 |
+
Dump of the given tflite model in HTML format.
|
| 438 |
+
|
| 439 |
+
Raises:
|
| 440 |
+
RuntimeError: If the input is not valid.
|
| 441 |
+
"""
|
| 442 |
+
|
| 443 |
+
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
|
| 444 |
+
# exist.
|
| 445 |
+
if input_is_filepath:
|
| 446 |
+
if not os.path.exists(tflite_input):
|
| 447 |
+
raise RuntimeError("Invalid filename %r" % tflite_input)
|
| 448 |
+
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
|
| 449 |
+
with open(tflite_input, "rb") as file_handle:
|
| 450 |
+
file_data = bytearray(file_handle.read())
|
| 451 |
+
data = CreateDictFromFlatbuffer(file_data)
|
| 452 |
+
elif tflite_input.endswith(".json"):
|
| 453 |
+
data = json.load(open(tflite_input))
|
| 454 |
+
else:
|
| 455 |
+
raise RuntimeError("Input file was not .tflite or .json")
|
| 456 |
+
else:
|
| 457 |
+
data = CreateDictFromFlatbuffer(tflite_input)
|
| 458 |
+
html = ""
|
| 459 |
+
html += _CSS
|
| 460 |
+
html += "<h1>TensorFlow Lite Model</h2>"
|
| 461 |
+
|
| 462 |
+
data["filename"] = tflite_input if input_is_filepath else (
|
| 463 |
+
"Null (used model object)") # Avoid special case
|
| 464 |
+
|
| 465 |
+
toplevel_stuff = [("filename", None), ("version", None),
|
| 466 |
+
("description", None)]
|
| 467 |
+
|
| 468 |
+
html += "<table>\n"
|
| 469 |
+
for key, mapping in toplevel_stuff:
|
| 470 |
+
if not mapping:
|
| 471 |
+
mapping = lambda x: x
|
| 472 |
+
html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data.get(key)))
|
| 473 |
+
html += "</table>\n"
|
| 474 |
+
|
| 475 |
+
# Spec on what keys to display
|
| 476 |
+
buffer_keys_to_display = [("data", DataSizeMapper())]
|
| 477 |
+
operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
|
| 478 |
+
("custom_code", NameListToString),
|
| 479 |
+
("version", None)]
|
| 480 |
+
|
| 481 |
+
# Update builtin code fields.
|
| 482 |
+
for d in data["operator_codes"]:
|
| 483 |
+
d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"])
|
| 484 |
+
|
| 485 |
+
for subgraph_idx, g in enumerate(data["subgraphs"]):
|
| 486 |
+
# Subgraph local specs on what to display
|
| 487 |
+
html += "<div class='subgraph'>"
|
| 488 |
+
tensor_mapper = TensorMapper(g)
|
| 489 |
+
opcode_mapper = OpCodeMapper(data)
|
| 490 |
+
op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
|
| 491 |
+
("builtin_options", None),
|
| 492 |
+
("opcode_index", opcode_mapper)]
|
| 493 |
+
tensor_keys_to_display = [("name", NameListToString),
|
| 494 |
+
("type", TensorTypeToName), ("shape", None),
|
| 495 |
+
("shape_signature", None), ("buffer", None),
|
| 496 |
+
("quantization", None)]
|
| 497 |
+
|
| 498 |
+
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
|
| 499 |
+
|
| 500 |
+
# Inputs and outputs.
|
| 501 |
+
html += "<h3>Inputs/Outputs</h3>\n"
|
| 502 |
+
html += GenerateTableHtml([{
|
| 503 |
+
"inputs": g["inputs"],
|
| 504 |
+
"outputs": g["outputs"]
|
| 505 |
+
}], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
|
| 506 |
+
display_index=False)
|
| 507 |
+
|
| 508 |
+
# Print the tensors.
|
| 509 |
+
html += "<h3>Tensors</h3>\n"
|
| 510 |
+
html += GenerateTableHtml(g["tensors"], tensor_keys_to_display)
|
| 511 |
+
|
| 512 |
+
# Print the ops.
|
| 513 |
+
if g["operators"]:
|
| 514 |
+
html += "<h3>Ops</h3>\n"
|
| 515 |
+
html += GenerateTableHtml(g["operators"], op_keys_to_display)
|
| 516 |
+
|
| 517 |
+
# Visual graph.
|
| 518 |
+
html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
|
| 519 |
+
subgraph_idx,)
|
| 520 |
+
html += GenerateGraph(subgraph_idx, g, opcode_mapper)
|
| 521 |
+
html += "</div>"
|
| 522 |
+
|
| 523 |
+
# Buffers have no data, but maybe in the future they will
|
| 524 |
+
html += "<h2>Buffers</h2>\n"
|
| 525 |
+
html += GenerateTableHtml(data["buffers"], buffer_keys_to_display)
|
| 526 |
+
|
| 527 |
+
# Operator codes
|
| 528 |
+
html += "<h2>Operator Codes</h2>\n"
|
| 529 |
+
html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
|
| 530 |
+
|
| 531 |
+
html += "</body></html>\n"
|
| 532 |
+
|
| 533 |
+
return html
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def main(argv):
|
| 537 |
+
try:
|
| 538 |
+
tflite_input = argv[1]
|
| 539 |
+
html_output = argv[2]
|
| 540 |
+
except IndexError:
|
| 541 |
+
print("Usage: %s <input tflite> <output html>" % (argv[0]))
|
| 542 |
+
else:
|
| 543 |
+
html = create_html(tflite_input)
|
| 544 |
+
with open(html_output, "w") as output_file:
|
| 545 |
+
output_file.write(html)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
if __name__ == "__main__":
|
| 549 |
+
main(sys.argv)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (216 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.pyi
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
class TestClassDef:
|
| 17 |
+
def __init__(self) -> None: ...
|
| 18 |
+
def method(self) -> object: ...
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7dd76e74055bba4c02308da5f57791117799704b278e153aef7741edbae230b2
|
| 3 |
+
size 1072920
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__init__.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""This module implements operators that AutoGraph overloads.
|
| 16 |
+
|
| 17 |
+
Note that "operator" is used loosely here, and includes control structures like
|
| 18 |
+
conditionals and loops, implemented in functional form, using for example
|
| 19 |
+
closures for the body.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
# Naming conventions:
|
| 23 |
+
# * operator names match the name usually used for the respective Python
|
| 24 |
+
# idiom; examples: for_stmt, list_append
|
| 25 |
+
# * operator arguments match either of:
|
| 26 |
+
# - the corresponding Python AST attribute (e.g. the condition of an if
|
| 27 |
+
# statement is called test) if the operator represents an AST construct
|
| 28 |
+
# - the names used in the Python docs, if the operator is a function (e.g.
|
| 29 |
+
# list_ and x for append, see
|
| 30 |
+
# https://docs.python.org/3.7/tutorial/datastructures.html)
|
| 31 |
+
#
|
| 32 |
+
# All operators may accept a final argument named "opts", of a type that
|
| 33 |
+
# subclasses namedtuple and contains any arguments that are only required
|
| 34 |
+
# for some specializations of the operator.
|
| 35 |
+
|
| 36 |
+
from tensorflow.python.autograph.operators.conditional_expressions import if_exp
|
| 37 |
+
from tensorflow.python.autograph.operators.control_flow import for_stmt
|
| 38 |
+
from tensorflow.python.autograph.operators.control_flow import if_stmt
|
| 39 |
+
from tensorflow.python.autograph.operators.control_flow import while_stmt
|
| 40 |
+
from tensorflow.python.autograph.operators.data_structures import list_append
|
| 41 |
+
from tensorflow.python.autograph.operators.data_structures import list_pop
|
| 42 |
+
from tensorflow.python.autograph.operators.data_structures import list_stack
|
| 43 |
+
from tensorflow.python.autograph.operators.data_structures import ListPopOpts
|
| 44 |
+
from tensorflow.python.autograph.operators.data_structures import ListStackOpts
|
| 45 |
+
from tensorflow.python.autograph.operators.data_structures import new_list
|
| 46 |
+
from tensorflow.python.autograph.operators.exceptions import assert_stmt
|
| 47 |
+
from tensorflow.python.autograph.operators.logical import and_
|
| 48 |
+
from tensorflow.python.autograph.operators.logical import eq
|
| 49 |
+
from tensorflow.python.autograph.operators.logical import not_
|
| 50 |
+
from tensorflow.python.autograph.operators.logical import not_eq
|
| 51 |
+
from tensorflow.python.autograph.operators.logical import or_
|
| 52 |
+
from tensorflow.python.autograph.operators.py_builtins import float_
|
| 53 |
+
from tensorflow.python.autograph.operators.py_builtins import int_
|
| 54 |
+
from tensorflow.python.autograph.operators.py_builtins import len_
|
| 55 |
+
from tensorflow.python.autograph.operators.py_builtins import print_
|
| 56 |
+
from tensorflow.python.autograph.operators.py_builtins import range_
|
| 57 |
+
from tensorflow.python.autograph.operators.slices import get_item
|
| 58 |
+
from tensorflow.python.autograph.operators.slices import GetItemOpts
|
| 59 |
+
from tensorflow.python.autograph.operators.slices import set_item
|
| 60 |
+
from tensorflow.python.autograph.operators.variables import ld
|
| 61 |
+
from tensorflow.python.autograph.operators.variables import ldu
|
| 62 |
+
from tensorflow.python.autograph.operators.variables import Undefined
|
| 63 |
+
from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.73 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/conditional_expressions.cpython-310.pyc
ADDED
|
Binary file (1.5 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/control_flow.cpython-310.pyc
ADDED
|
Binary file (36.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/data_structures.cpython-310.pyc
ADDED
|
Binary file (9.65 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/exceptions.cpython-310.pyc
ADDED
|
Binary file (2.57 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/logical.cpython-310.pyc
ADDED
|
Binary file (2.87 kB). View file
|
|
|