AIDUDE0541 commited on
Commit
7d39123
·
verified ·
1 Parent(s): b682ec9

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__init__.py +0 -0
  3. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__pycache__/__init__.cpython-310.pyc +0 -0
  4. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__init__.py +0 -0
  5. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__pycache__/__init__.cpython-310.pyc +0 -0
  6. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__init__.py +0 -0
  7. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__pycache__/__init__.cpython-310.pyc +0 -0
  8. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__init__.py +0 -0
  9. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/__init__.cpython-310.pyc +0 -0
  10. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/gen_rpc_ops.cpython-310.pyc +0 -0
  11. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/gen_rpc_ops.py +763 -0
  12. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__init__.py +0 -0
  13. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/__init__.cpython-310.pyc +0 -0
  14. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2.cpython-310.pyc +0 -0
  15. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2_grpc.cpython-310.pyc +0 -0
  16. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2.py +37 -0
  17. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2_grpc.py +63 -0
  18. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__init__.py +0 -0
  19. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__pycache__/__init__.cpython-310.pyc +0 -0
  20. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/analyzer.py +107 -0
  21. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/conversion_metadata_schema_py_generated.py +568 -0
  22. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py +219 -0
  23. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/schema_util.py +45 -0
  24. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/tflite_convert.py +696 -0
  25. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/util.py +1177 -0
  26. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__init__.py +0 -0
  27. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/__init__.cpython-310.pyc +0 -0
  28. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/flatbuffer_utils.cpython-310.pyc +0 -0
  29. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/visualize.cpython-310.pyc +0 -0
  30. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/flatbuffer_utils.py +455 -0
  31. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__init__.py +0 -0
  32. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__pycache__/__init__.cpython-310.pyc +0 -0
  33. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__init__.py +0 -0
  34. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__pycache__/__init__.cpython-310.pyc +0 -0
  35. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__init__.py +0 -0
  36. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/__init__.cpython-310.pyc +0 -0
  37. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/debugger.cpython-310.pyc +0 -0
  38. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/debugger.py +549 -0
  39. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/visualize.py +549 -0
  40. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__init__.py +0 -0
  41. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__pycache__/__init__.cpython-310.pyc +0 -0
  42. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.pyi +18 -0
  43. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so +3 -0
  44. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__init__.py +63 -0
  45. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/__init__.cpython-310.pyc +0 -0
  46. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/conditional_expressions.cpython-310.pyc +0 -0
  47. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/control_flow.cpython-310.pyc +0 -0
  48. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/data_structures.cpython-310.pyc +0 -0
  49. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/exceptions.cpython-310.pyc +0 -0
  50. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/logical.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -197,3 +197,4 @@ SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/gr
197
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/grappler/_pywrap_tf_item.so filter=lfs diff=lfs merge=lfs -text
198
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/profiler/internal/_pywrap_profiler.so filter=lfs diff=lfs merge=lfs -text
199
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/client/_pywrap_tf_session.so filter=lfs diff=lfs merge=lfs -text
 
 
197
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/grappler/_pywrap_tf_item.so filter=lfs diff=lfs merge=lfs -text
198
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/profiler/internal/_pywrap_profiler.so filter=lfs diff=lfs merge=lfs -text
199
  SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/client/_pywrap_tf_session.so filter=lfs diff=lfs merge=lfs -text
200
+ SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so filter=lfs diff=lfs merge=lfs -text
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (197 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (210 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (214 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (222 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/__pycache__/gen_rpc_ops.cpython-310.pyc ADDED
Binary file (16.6 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/kernels/gen_rpc_ops.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python wrappers around TensorFlow ops.
2
+
3
+ This file is MACHINE GENERATED! Do not edit.
4
+ """
5
+
6
+ import collections
7
+
8
+ from tensorflow.python import pywrap_tfe as pywrap_tfe
9
+ from tensorflow.python.eager import context as _context
10
+ from tensorflow.python.eager import core as _core
11
+ from tensorflow.python.eager import execute as _execute
12
+ from tensorflow.python.framework import dtypes as _dtypes
13
+ from tensorflow.security.fuzzing.py import annotation_types as _atypes
14
+
15
+ from tensorflow.python.framework import op_def_registry as _op_def_registry
16
+ from tensorflow.python.framework import ops as _ops
17
+ from tensorflow.python.framework import op_def_library as _op_def_library
18
+ from tensorflow.python.util.deprecation import deprecated_endpoints
19
+ from tensorflow.python.util import dispatch as _dispatch
20
+ from tensorflow.python.util.tf_export import tf_export
21
+
22
+ from typing import TypeVar, List, Any
23
+ from typing_extensions import Annotated
24
+
25
+ @_dispatch.add_fallback_dispatch_list
26
+ @_dispatch.add_type_based_api_dispatcher
27
+ @tf_export('delete_rpc_future_resource')
28
+ def delete_rpc_future_resource(handle: Annotated[Any, _atypes.Resource], deleter: Annotated[Any, _atypes.Variant], name=None):
29
+ r"""TODO: add doc.
30
+
31
+ Args:
32
+ handle: A `Tensor` of type `resource`.
33
+ deleter: A `Tensor` of type `variant`.
34
+ name: A name for the operation (optional).
35
+
36
+ Returns:
37
+ The created Operation.
38
+ """
39
+ _ctx = _context._context or _context.context()
40
+ tld = _ctx._thread_local_data
41
+ if tld.is_eager:
42
+ try:
43
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
44
+ _ctx, "DeleteRpcFutureResource", name, handle, deleter)
45
+ return _result
46
+ except _core._NotOkStatusException as e:
47
+ _ops.raise_from_not_ok_status(e, name)
48
+ except _core._FallbackException:
49
+ pass
50
+ try:
51
+ _result = _dispatcher_for_delete_rpc_future_resource(
52
+ (handle, deleter, name,), None)
53
+ if _result is not NotImplemented:
54
+ return _result
55
+ return delete_rpc_future_resource_eager_fallback(
56
+ handle, deleter, name=name, ctx=_ctx)
57
+ except _core._SymbolicException:
58
+ pass # Add nodes to the TensorFlow graph.
59
+ except (TypeError, ValueError):
60
+ _result = _dispatch.dispatch(
61
+ delete_rpc_future_resource, (), dict(handle=handle,
62
+ deleter=deleter, name=name)
63
+ )
64
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
65
+ return _result
66
+ raise
67
+ else:
68
+ _result = _dispatcher_for_delete_rpc_future_resource(
69
+ (handle, deleter, name,), None)
70
+ if _result is not NotImplemented:
71
+ return _result
72
+ # Add nodes to the TensorFlow graph.
73
+ try:
74
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
75
+ "DeleteRpcFutureResource", handle=handle, deleter=deleter, name=name)
76
+ except (TypeError, ValueError):
77
+ _result = _dispatch.dispatch(
78
+ delete_rpc_future_resource, (), dict(handle=handle, deleter=deleter,
79
+ name=name)
80
+ )
81
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
82
+ return _result
83
+ raise
84
+ return _op
85
+ DeleteRpcFutureResource = tf_export("raw_ops.DeleteRpcFutureResource")(_ops.to_raw_op(delete_rpc_future_resource))
86
+ _dispatcher_for_delete_rpc_future_resource = delete_rpc_future_resource._tf_type_based_dispatcher.Dispatch
87
+
88
+
89
+ def delete_rpc_future_resource_eager_fallback(handle: Annotated[Any, _atypes.Resource], deleter: Annotated[Any, _atypes.Variant], name, ctx):
90
+ handle = _ops.convert_to_tensor(handle, _dtypes.resource)
91
+ deleter = _ops.convert_to_tensor(deleter, _dtypes.variant)
92
+ _inputs_flat = [handle, deleter]
93
+ _attrs = None
94
+ _result = _execute.execute(b"DeleteRpcFutureResource", 0,
95
+ inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
96
+ name=name)
97
+ _result = None
98
+ return _result
99
+
100
+ _RpcCallOutput = collections.namedtuple(
101
+ "RpcCall",
102
+ ["future", "deleter"])
103
+
104
+
105
+ @_dispatch.add_fallback_dispatch_list
106
+ @_dispatch.add_type_based_api_dispatcher
107
+ @tf_export('rpc_call')
108
+ def rpc_call(client: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], args, timeout_in_ms: Annotated[Any, _atypes.Int64], name=None):
109
+ r"""TODO: add doc.
110
+
111
+ Args:
112
+ client: A `Tensor` of type `resource`.
113
+ method_name: A `Tensor` of type `string`.
114
+ args: A list of `Tensor` objects.
115
+ timeout_in_ms: A `Tensor` of type `int64`.
116
+ name: A name for the operation (optional).
117
+
118
+ Returns:
119
+ A tuple of `Tensor` objects (future, deleter).
120
+
121
+ future: A `Tensor` of type `resource`.
122
+ deleter: A `Tensor` of type `variant`.
123
+ """
124
+ _ctx = _context._context or _context.context()
125
+ tld = _ctx._thread_local_data
126
+ if tld.is_eager:
127
+ try:
128
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
129
+ _ctx, "RpcCall", name, client, method_name, args, timeout_in_ms)
130
+ _result = _RpcCallOutput._make(_result)
131
+ return _result
132
+ except _core._NotOkStatusException as e:
133
+ _ops.raise_from_not_ok_status(e, name)
134
+ except _core._FallbackException:
135
+ pass
136
+ try:
137
+ _result = _dispatcher_for_rpc_call(
138
+ (client, method_name, args, timeout_in_ms, name,), None)
139
+ if _result is not NotImplemented:
140
+ return _result
141
+ return rpc_call_eager_fallback(
142
+ client, method_name, args, timeout_in_ms, name=name, ctx=_ctx)
143
+ except _core._SymbolicException:
144
+ pass # Add nodes to the TensorFlow graph.
145
+ except (TypeError, ValueError):
146
+ _result = _dispatch.dispatch(
147
+ rpc_call, (), dict(client=client, method_name=method_name,
148
+ args=args, timeout_in_ms=timeout_in_ms,
149
+ name=name)
150
+ )
151
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
152
+ return _result
153
+ raise
154
+ else:
155
+ _result = _dispatcher_for_rpc_call(
156
+ (client, method_name, args, timeout_in_ms, name,), None)
157
+ if _result is not NotImplemented:
158
+ return _result
159
+ # Add nodes to the TensorFlow graph.
160
+ try:
161
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
162
+ "RpcCall", client=client, method_name=method_name, args=args,
163
+ timeout_in_ms=timeout_in_ms, name=name)
164
+ except (TypeError, ValueError):
165
+ _result = _dispatch.dispatch(
166
+ rpc_call, (), dict(client=client, method_name=method_name,
167
+ args=args, timeout_in_ms=timeout_in_ms,
168
+ name=name)
169
+ )
170
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
171
+ return _result
172
+ raise
173
+ _result = _outputs[:]
174
+ if _execute.must_record_gradient():
175
+ _attrs = ("Tin", _op.get_attr("Tin"))
176
+ _inputs_flat = _op.inputs
177
+ _execute.record_gradient(
178
+ "RpcCall", _inputs_flat, _attrs, _result)
179
+ _result = _RpcCallOutput._make(_result)
180
+ return _result
181
+
182
+ RpcCall = tf_export("raw_ops.RpcCall")(_ops.to_raw_op(rpc_call))
183
+ _dispatcher_for_rpc_call = rpc_call._tf_type_based_dispatcher.Dispatch
184
+
185
+
186
+ def rpc_call_eager_fallback(client: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], args, timeout_in_ms: Annotated[Any, _atypes.Int64], name, ctx):
187
+ _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
188
+ client = _ops.convert_to_tensor(client, _dtypes.resource)
189
+ method_name = _ops.convert_to_tensor(method_name, _dtypes.string)
190
+ timeout_in_ms = _ops.convert_to_tensor(timeout_in_ms, _dtypes.int64)
191
+ _inputs_flat = [client, method_name] + list(args) + [timeout_in_ms]
192
+ _attrs = ("Tin", _attr_Tin)
193
+ _result = _execute.execute(b"RpcCall", 2, inputs=_inputs_flat, attrs=_attrs,
194
+ ctx=ctx, name=name)
195
+ if _execute.must_record_gradient():
196
+ _execute.record_gradient(
197
+ "RpcCall", _inputs_flat, _attrs, _result)
198
+ _result = _RpcCallOutput._make(_result)
199
+ return _result
200
+
201
+ _RpcCheckStatusOutput = collections.namedtuple(
202
+ "RpcCheckStatus",
203
+ ["error_code", "error"])
204
+
205
+
206
+ @_dispatch.add_fallback_dispatch_list
207
+ @_dispatch.add_type_based_api_dispatcher
208
+ @tf_export('rpc_check_status')
209
+ def rpc_check_status(status_or: Annotated[Any, _atypes.Resource], name=None):
210
+ r"""TODO: add doc.
211
+
212
+ Args:
213
+ status_or: A `Tensor` of type `resource`.
214
+ name: A name for the operation (optional).
215
+
216
+ Returns:
217
+ A tuple of `Tensor` objects (error_code, error).
218
+
219
+ error_code: A `Tensor` of type `int64`.
220
+ error: A `Tensor` of type `string`.
221
+ """
222
+ _ctx = _context._context or _context.context()
223
+ tld = _ctx._thread_local_data
224
+ if tld.is_eager:
225
+ try:
226
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
227
+ _ctx, "RpcCheckStatus", name, status_or)
228
+ _result = _RpcCheckStatusOutput._make(_result)
229
+ return _result
230
+ except _core._NotOkStatusException as e:
231
+ _ops.raise_from_not_ok_status(e, name)
232
+ except _core._FallbackException:
233
+ pass
234
+ try:
235
+ _result = _dispatcher_for_rpc_check_status(
236
+ (status_or, name,), None)
237
+ if _result is not NotImplemented:
238
+ return _result
239
+ return rpc_check_status_eager_fallback(
240
+ status_or, name=name, ctx=_ctx)
241
+ except _core._SymbolicException:
242
+ pass # Add nodes to the TensorFlow graph.
243
+ except (TypeError, ValueError):
244
+ _result = _dispatch.dispatch(
245
+ rpc_check_status, (), dict(status_or=status_or, name=name)
246
+ )
247
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
248
+ return _result
249
+ raise
250
+ else:
251
+ _result = _dispatcher_for_rpc_check_status(
252
+ (status_or, name,), None)
253
+ if _result is not NotImplemented:
254
+ return _result
255
+ # Add nodes to the TensorFlow graph.
256
+ try:
257
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
258
+ "RpcCheckStatus", status_or=status_or, name=name)
259
+ except (TypeError, ValueError):
260
+ _result = _dispatch.dispatch(
261
+ rpc_check_status, (), dict(status_or=status_or, name=name)
262
+ )
263
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
264
+ return _result
265
+ raise
266
+ _result = _outputs[:]
267
+ if _execute.must_record_gradient():
268
+ _attrs = ()
269
+ _inputs_flat = _op.inputs
270
+ _execute.record_gradient(
271
+ "RpcCheckStatus", _inputs_flat, _attrs, _result)
272
+ _result = _RpcCheckStatusOutput._make(_result)
273
+ return _result
274
+
275
+ RpcCheckStatus = tf_export("raw_ops.RpcCheckStatus")(_ops.to_raw_op(rpc_check_status))
276
+ _dispatcher_for_rpc_check_status = rpc_check_status._tf_type_based_dispatcher.Dispatch
277
+
278
+
279
+ def rpc_check_status_eager_fallback(status_or: Annotated[Any, _atypes.Resource], name, ctx):
280
+ status_or = _ops.convert_to_tensor(status_or, _dtypes.resource)
281
+ _inputs_flat = [status_or]
282
+ _attrs = None
283
+ _result = _execute.execute(b"RpcCheckStatus", 2, inputs=_inputs_flat,
284
+ attrs=_attrs, ctx=ctx, name=name)
285
+ if _execute.must_record_gradient():
286
+ _execute.record_gradient(
287
+ "RpcCheckStatus", _inputs_flat, _attrs, _result)
288
+ _result = _RpcCheckStatusOutput._make(_result)
289
+ return _result
290
+
291
+ _RpcClientOutput = collections.namedtuple(
292
+ "RpcClient",
293
+ ["client", "method_specs"])
294
+
295
+
296
+ @_dispatch.add_fallback_dispatch_list
297
+ @_dispatch.add_type_based_api_dispatcher
298
+ @tf_export('rpc_client')
299
+ def rpc_client(server_address: Annotated[Any, _atypes.String], timeout_in_ms: Annotated[Any, _atypes.Int64], shared_name:str="", list_registered_methods:bool=False, name=None):
300
+ r"""TODO: add doc.
301
+
302
+ Args:
303
+ server_address: A `Tensor` of type `string`.
304
+ timeout_in_ms: A `Tensor` of type `int64`.
305
+ shared_name: An optional `string`. Defaults to `""`.
306
+ list_registered_methods: An optional `bool`. Defaults to `False`.
307
+ name: A name for the operation (optional).
308
+
309
+ Returns:
310
+ A tuple of `Tensor` objects (client, method_specs).
311
+
312
+ client: A `Tensor` of type `resource`.
313
+ method_specs: A `Tensor` of type `string`.
314
+ """
315
+ _ctx = _context._context or _context.context()
316
+ tld = _ctx._thread_local_data
317
+ if tld.is_eager:
318
+ try:
319
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
320
+ _ctx, "RpcClient", name, server_address, timeout_in_ms, "shared_name",
321
+ shared_name, "list_registered_methods", list_registered_methods)
322
+ _result = _RpcClientOutput._make(_result)
323
+ return _result
324
+ except _core._NotOkStatusException as e:
325
+ _ops.raise_from_not_ok_status(e, name)
326
+ except _core._FallbackException:
327
+ pass
328
+ try:
329
+ _result = _dispatcher_for_rpc_client(
330
+ (server_address, timeout_in_ms, shared_name,
331
+ list_registered_methods, name,), None)
332
+ if _result is not NotImplemented:
333
+ return _result
334
+ return rpc_client_eager_fallback(
335
+ server_address, timeout_in_ms, shared_name=shared_name,
336
+ list_registered_methods=list_registered_methods, name=name,
337
+ ctx=_ctx)
338
+ except _core._SymbolicException:
339
+ pass # Add nodes to the TensorFlow graph.
340
+ except (TypeError, ValueError):
341
+ _result = _dispatch.dispatch(
342
+ rpc_client, (), dict(server_address=server_address,
343
+ timeout_in_ms=timeout_in_ms,
344
+ shared_name=shared_name,
345
+ list_registered_methods=list_registered_methods,
346
+ name=name)
347
+ )
348
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
349
+ return _result
350
+ raise
351
+ else:
352
+ _result = _dispatcher_for_rpc_client(
353
+ (server_address, timeout_in_ms, shared_name, list_registered_methods,
354
+ name,), None)
355
+ if _result is not NotImplemented:
356
+ return _result
357
+ # Add nodes to the TensorFlow graph.
358
+ if shared_name is None:
359
+ shared_name = ""
360
+ shared_name = _execute.make_str(shared_name, "shared_name")
361
+ if list_registered_methods is None:
362
+ list_registered_methods = False
363
+ list_registered_methods = _execute.make_bool(list_registered_methods, "list_registered_methods")
364
+ try:
365
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
366
+ "RpcClient", server_address=server_address,
367
+ timeout_in_ms=timeout_in_ms, shared_name=shared_name,
368
+ list_registered_methods=list_registered_methods,
369
+ name=name)
370
+ except (TypeError, ValueError):
371
+ _result = _dispatch.dispatch(
372
+ rpc_client, (), dict(server_address=server_address,
373
+ timeout_in_ms=timeout_in_ms,
374
+ shared_name=shared_name,
375
+ list_registered_methods=list_registered_methods,
376
+ name=name)
377
+ )
378
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
379
+ return _result
380
+ raise
381
+ _result = _outputs[:]
382
+ if _execute.must_record_gradient():
383
+ _attrs = ("shared_name", _op.get_attr("shared_name"),
384
+ "list_registered_methods",
385
+ _op._get_attr_bool("list_registered_methods"))
386
+ _inputs_flat = _op.inputs
387
+ _execute.record_gradient(
388
+ "RpcClient", _inputs_flat, _attrs, _result)
389
+ _result = _RpcClientOutput._make(_result)
390
+ return _result
391
+
392
+ RpcClient = tf_export("raw_ops.RpcClient")(_ops.to_raw_op(rpc_client))
393
+ _dispatcher_for_rpc_client = rpc_client._tf_type_based_dispatcher.Dispatch
394
+
395
+
396
+ def rpc_client_eager_fallback(server_address: Annotated[Any, _atypes.String], timeout_in_ms: Annotated[Any, _atypes.Int64], shared_name: str, list_registered_methods: bool, name, ctx):
397
+ if shared_name is None:
398
+ shared_name = ""
399
+ shared_name = _execute.make_str(shared_name, "shared_name")
400
+ if list_registered_methods is None:
401
+ list_registered_methods = False
402
+ list_registered_methods = _execute.make_bool(list_registered_methods, "list_registered_methods")
403
+ server_address = _ops.convert_to_tensor(server_address, _dtypes.string)
404
+ timeout_in_ms = _ops.convert_to_tensor(timeout_in_ms, _dtypes.int64)
405
+ _inputs_flat = [server_address, timeout_in_ms]
406
+ _attrs = ("shared_name", shared_name, "list_registered_methods",
407
+ list_registered_methods)
408
+ _result = _execute.execute(b"RpcClient", 2, inputs=_inputs_flat,
409
+ attrs=_attrs, ctx=ctx, name=name)
410
+ if _execute.must_record_gradient():
411
+ _execute.record_gradient(
412
+ "RpcClient", _inputs_flat, _attrs, _result)
413
+ _result = _RpcClientOutput._make(_result)
414
+ return _result
415
+
416
+
417
+ @_dispatch.add_fallback_dispatch_list
418
+ @_dispatch.add_type_based_api_dispatcher
419
+ @tf_export('rpc_get_value')
420
+ def rpc_get_value(status_or: Annotated[Any, _atypes.Resource], Tout, name=None):
421
+ r"""TODO: add doc.
422
+
423
+ Args:
424
+ status_or: A `Tensor` of type `resource`.
425
+ Tout: A list of `tf.DTypes`.
426
+ name: A name for the operation (optional).
427
+
428
+ Returns:
429
+ A list of `Tensor` objects of type `Tout`.
430
+ """
431
+ _ctx = _context._context or _context.context()
432
+ tld = _ctx._thread_local_data
433
+ if tld.is_eager:
434
+ try:
435
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
436
+ _ctx, "RpcGetValue", name, status_or, "Tout", Tout)
437
+ return _result
438
+ except _core._NotOkStatusException as e:
439
+ _ops.raise_from_not_ok_status(e, name)
440
+ except _core._FallbackException:
441
+ pass
442
+ try:
443
+ _result = _dispatcher_for_rpc_get_value(
444
+ (status_or, Tout, name,), None)
445
+ if _result is not NotImplemented:
446
+ return _result
447
+ return rpc_get_value_eager_fallback(
448
+ status_or, Tout=Tout, name=name, ctx=_ctx)
449
+ except _core._SymbolicException:
450
+ pass # Add nodes to the TensorFlow graph.
451
+ except (TypeError, ValueError):
452
+ _result = _dispatch.dispatch(
453
+ rpc_get_value, (), dict(status_or=status_or, Tout=Tout, name=name)
454
+ )
455
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
456
+ return _result
457
+ raise
458
+ else:
459
+ _result = _dispatcher_for_rpc_get_value(
460
+ (status_or, Tout, name,), None)
461
+ if _result is not NotImplemented:
462
+ return _result
463
+ # Add nodes to the TensorFlow graph.
464
+ if not isinstance(Tout, (list, tuple)):
465
+ raise TypeError(
466
+ "Expected list for 'Tout' argument to "
467
+ "'rpc_get_value' Op, not %r." % Tout)
468
+ Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
469
+ try:
470
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
471
+ "RpcGetValue", status_or=status_or, Tout=Tout, name=name)
472
+ except (TypeError, ValueError):
473
+ _result = _dispatch.dispatch(
474
+ rpc_get_value, (), dict(status_or=status_or, Tout=Tout, name=name)
475
+ )
476
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
477
+ return _result
478
+ raise
479
+ _result = _outputs[:]
480
+ if not _result:
481
+ return _op
482
+ if _execute.must_record_gradient():
483
+ _attrs = ("Tout", _op.get_attr("Tout"))
484
+ _inputs_flat = _op.inputs
485
+ _execute.record_gradient(
486
+ "RpcGetValue", _inputs_flat, _attrs, _result)
487
+ return _result
488
+
489
+ RpcGetValue = tf_export("raw_ops.RpcGetValue")(_ops.to_raw_op(rpc_get_value))
490
+ _dispatcher_for_rpc_get_value = rpc_get_value._tf_type_based_dispatcher.Dispatch
491
+
492
+
493
+ def rpc_get_value_eager_fallback(status_or: Annotated[Any, _atypes.Resource], Tout, name, ctx):
494
+ if not isinstance(Tout, (list, tuple)):
495
+ raise TypeError(
496
+ "Expected list for 'Tout' argument to "
497
+ "'rpc_get_value' Op, not %r." % Tout)
498
+ Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
499
+ status_or = _ops.convert_to_tensor(status_or, _dtypes.resource)
500
+ _inputs_flat = [status_or]
501
+ _attrs = ("Tout", Tout)
502
+ _result = _execute.execute(b"RpcGetValue", len(Tout), inputs=_inputs_flat,
503
+ attrs=_attrs, ctx=ctx, name=name)
504
+ if _execute.must_record_gradient():
505
+ _execute.record_gradient(
506
+ "RpcGetValue", _inputs_flat, _attrs, _result)
507
+ return _result
508
+
509
+
510
+ @_dispatch.add_fallback_dispatch_list
511
+ @_dispatch.add_type_based_api_dispatcher
512
+ @tf_export('rpc_server')
513
+ def rpc_server(server_address: Annotated[Any, _atypes.String], name=None) -> Annotated[Any, _atypes.Resource]:
514
+ r"""TODO: add doc.
515
+
516
+ Args:
517
+ server_address: A `Tensor` of type `string`.
518
+ name: A name for the operation (optional).
519
+
520
+ Returns:
521
+ A `Tensor` of type `resource`.
522
+ """
523
+ _ctx = _context._context or _context.context()
524
+ tld = _ctx._thread_local_data
525
+ if tld.is_eager:
526
+ try:
527
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
528
+ _ctx, "RpcServer", name, server_address)
529
+ return _result
530
+ except _core._NotOkStatusException as e:
531
+ _ops.raise_from_not_ok_status(e, name)
532
+ except _core._FallbackException:
533
+ pass
534
+ try:
535
+ _result = _dispatcher_for_rpc_server(
536
+ (server_address, name,), None)
537
+ if _result is not NotImplemented:
538
+ return _result
539
+ return rpc_server_eager_fallback(
540
+ server_address, name=name, ctx=_ctx)
541
+ except _core._SymbolicException:
542
+ pass # Add nodes to the TensorFlow graph.
543
+ except (TypeError, ValueError):
544
+ _result = _dispatch.dispatch(
545
+ rpc_server, (), dict(server_address=server_address, name=name)
546
+ )
547
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
548
+ return _result
549
+ raise
550
+ else:
551
+ _result = _dispatcher_for_rpc_server(
552
+ (server_address, name,), None)
553
+ if _result is not NotImplemented:
554
+ return _result
555
+ # Add nodes to the TensorFlow graph.
556
+ try:
557
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
558
+ "RpcServer", server_address=server_address, name=name)
559
+ except (TypeError, ValueError):
560
+ _result = _dispatch.dispatch(
561
+ rpc_server, (), dict(server_address=server_address, name=name)
562
+ )
563
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
564
+ return _result
565
+ raise
566
+ _result = _outputs[:]
567
+ if _execute.must_record_gradient():
568
+ _attrs = ()
569
+ _inputs_flat = _op.inputs
570
+ _execute.record_gradient(
571
+ "RpcServer", _inputs_flat, _attrs, _result)
572
+ _result, = _result
573
+ return _result
574
+
575
+ RpcServer = tf_export("raw_ops.RpcServer")(_ops.to_raw_op(rpc_server))
576
+ _dispatcher_for_rpc_server = rpc_server._tf_type_based_dispatcher.Dispatch
577
+
578
+
579
+ def rpc_server_eager_fallback(server_address: Annotated[Any, _atypes.String], name, ctx) -> Annotated[Any, _atypes.Resource]:
580
+ server_address = _ops.convert_to_tensor(server_address, _dtypes.string)
581
+ _inputs_flat = [server_address]
582
+ _attrs = None
583
+ _result = _execute.execute(b"RpcServer", 1, inputs=_inputs_flat,
584
+ attrs=_attrs, ctx=ctx, name=name)
585
+ if _execute.must_record_gradient():
586
+ _execute.record_gradient(
587
+ "RpcServer", _inputs_flat, _attrs, _result)
588
+ _result, = _result
589
+ return _result
590
+
591
+
592
+ @_dispatch.add_fallback_dispatch_list
593
+ @_dispatch.add_type_based_api_dispatcher
594
+ @tf_export('rpc_server_register')
595
+ def rpc_server_register(server: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], captured_inputs, f, output_specs: str, input_specs:str="", name=None):
596
+ r"""TODO: add doc.
597
+
598
+ Args:
599
+ server: A `Tensor` of type `resource`.
600
+ method_name: A `Tensor` of type `string`.
601
+ captured_inputs: A list of `Tensor` objects.
602
+ f: A function decorated with @Defun.
603
+ output_specs: A `string`.
604
+ input_specs: An optional `string`. Defaults to `""`.
605
+ name: A name for the operation (optional).
606
+
607
+ Returns:
608
+ The created Operation.
609
+ """
610
+ _ctx = _context._context or _context.context()
611
+ tld = _ctx._thread_local_data
612
+ if tld.is_eager:
613
+ try:
614
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
615
+ _ctx, "RpcServerRegister", name, server, method_name, captured_inputs,
616
+ "f", f, "input_specs", input_specs, "output_specs", output_specs)
617
+ return _result
618
+ except _core._NotOkStatusException as e:
619
+ _ops.raise_from_not_ok_status(e, name)
620
+ except _core._FallbackException:
621
+ pass
622
+ try:
623
+ _result = _dispatcher_for_rpc_server_register(
624
+ (server, method_name, captured_inputs, f, output_specs, input_specs,
625
+ name,), None)
626
+ if _result is not NotImplemented:
627
+ return _result
628
+ return rpc_server_register_eager_fallback(
629
+ server, method_name, captured_inputs, f=f, input_specs=input_specs,
630
+ output_specs=output_specs, name=name, ctx=_ctx)
631
+ except _core._SymbolicException:
632
+ pass # Add nodes to the TensorFlow graph.
633
+ except (TypeError, ValueError):
634
+ _result = _dispatch.dispatch(
635
+ rpc_server_register, (), dict(server=server,
636
+ method_name=method_name,
637
+ captured_inputs=captured_inputs,
638
+ f=f, output_specs=output_specs,
639
+ input_specs=input_specs, name=name)
640
+ )
641
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
642
+ return _result
643
+ raise
644
+ else:
645
+ _result = _dispatcher_for_rpc_server_register(
646
+ (server, method_name, captured_inputs, f, output_specs, input_specs,
647
+ name,), None)
648
+ if _result is not NotImplemented:
649
+ return _result
650
+ # Add nodes to the TensorFlow graph.
651
+ output_specs = _execute.make_str(output_specs, "output_specs")
652
+ if input_specs is None:
653
+ input_specs = ""
654
+ input_specs = _execute.make_str(input_specs, "input_specs")
655
+ try:
656
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
657
+ "RpcServerRegister", server=server, method_name=method_name,
658
+ captured_inputs=captured_inputs, f=f,
659
+ output_specs=output_specs,
660
+ input_specs=input_specs, name=name)
661
+ except (TypeError, ValueError):
662
+ _result = _dispatch.dispatch(
663
+ rpc_server_register, (), dict(server=server,
664
+ method_name=method_name,
665
+ captured_inputs=captured_inputs, f=f,
666
+ output_specs=output_specs,
667
+ input_specs=input_specs, name=name)
668
+ )
669
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
670
+ return _result
671
+ raise
672
+ return _op
673
+ RpcServerRegister = tf_export("raw_ops.RpcServerRegister")(_ops.to_raw_op(rpc_server_register))
674
+ _dispatcher_for_rpc_server_register = rpc_server_register._tf_type_based_dispatcher.Dispatch
675
+
676
+
677
+ def rpc_server_register_eager_fallback(server: Annotated[Any, _atypes.Resource], method_name: Annotated[Any, _atypes.String], captured_inputs, f, output_specs: str, input_specs: str, name, ctx):
678
+ output_specs = _execute.make_str(output_specs, "output_specs")
679
+ if input_specs is None:
680
+ input_specs = ""
681
+ input_specs = _execute.make_str(input_specs, "input_specs")
682
+ _attr_Tin, captured_inputs = _execute.convert_to_mixed_eager_tensors(captured_inputs, ctx)
683
+ server = _ops.convert_to_tensor(server, _dtypes.resource)
684
+ method_name = _ops.convert_to_tensor(method_name, _dtypes.string)
685
+ _inputs_flat = [server, method_name] + list(captured_inputs)
686
+ _attrs = ("Tin", _attr_Tin, "f", f, "input_specs", input_specs,
687
+ "output_specs", output_specs)
688
+ _result = _execute.execute(b"RpcServerRegister", 0, inputs=_inputs_flat,
689
+ attrs=_attrs, ctx=ctx, name=name)
690
+ _result = None
691
+ return _result
692
+
693
+
694
+ @_dispatch.add_fallback_dispatch_list
695
+ @_dispatch.add_type_based_api_dispatcher
696
+ @tf_export('rpc_server_start')
697
+ def rpc_server_start(server: Annotated[Any, _atypes.Resource], name=None):
698
+ r"""TODO: add doc.
699
+
700
+ Args:
701
+ server: A `Tensor` of type `resource`.
702
+ name: A name for the operation (optional).
703
+
704
+ Returns:
705
+ The created Operation.
706
+ """
707
+ _ctx = _context._context or _context.context()
708
+ tld = _ctx._thread_local_data
709
+ if tld.is_eager:
710
+ try:
711
+ _result = pywrap_tfe.TFE_Py_FastPathExecute(
712
+ _ctx, "RpcServerStart", name, server)
713
+ return _result
714
+ except _core._NotOkStatusException as e:
715
+ _ops.raise_from_not_ok_status(e, name)
716
+ except _core._FallbackException:
717
+ pass
718
+ try:
719
+ _result = _dispatcher_for_rpc_server_start(
720
+ (server, name,), None)
721
+ if _result is not NotImplemented:
722
+ return _result
723
+ return rpc_server_start_eager_fallback(
724
+ server, name=name, ctx=_ctx)
725
+ except _core._SymbolicException:
726
+ pass # Add nodes to the TensorFlow graph.
727
+ except (TypeError, ValueError):
728
+ _result = _dispatch.dispatch(
729
+ rpc_server_start, (), dict(server=server, name=name)
730
+ )
731
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
732
+ return _result
733
+ raise
734
+ else:
735
+ _result = _dispatcher_for_rpc_server_start(
736
+ (server, name,), None)
737
+ if _result is not NotImplemented:
738
+ return _result
739
+ # Add nodes to the TensorFlow graph.
740
+ try:
741
+ _, _, _op, _outputs = _op_def_library._apply_op_helper(
742
+ "RpcServerStart", server=server, name=name)
743
+ except (TypeError, ValueError):
744
+ _result = _dispatch.dispatch(
745
+ rpc_server_start, (), dict(server=server, name=name)
746
+ )
747
+ if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
748
+ return _result
749
+ raise
750
+ return _op
751
+ RpcServerStart = tf_export("raw_ops.RpcServerStart")(_ops.to_raw_op(rpc_server_start))
752
+ _dispatcher_for_rpc_server_start = rpc_server_start._tf_type_based_dispatcher.Dispatch
753
+
754
+
755
+ def rpc_server_start_eager_fallback(server: Annotated[Any, _atypes.Resource], name, ctx):
756
+ server = _ops.convert_to_tensor(server, _dtypes.resource)
757
+ _inputs_flat = [server]
758
+ _attrs = None
759
+ _result = _execute.execute(b"RpcServerStart", 0, inputs=_inputs_flat,
760
+ attrs=_attrs, ctx=ctx, name=name)
761
+ _result = None
762
+ return _result
763
+
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (220 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2.cpython-310.pyc ADDED
Binary file (2.03 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/__pycache__/tf_rpc_service_pb2_grpc.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: tensorflow/distribute/experimental/rpc/proto/tf_rpc_service.proto
4
+ """Generated protocol buffer code."""
5
+ from google.protobuf.internal import builder as _builder
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ # @@protoc_insertion_point(imports)
10
+
11
+ _sym_db = _symbol_database.Default()
12
+
13
+
14
+ from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
15
+ from tensorflow.core.protobuf import struct_pb2 as tensorflow_dot_core_dot_protobuf_dot_struct__pb2
16
+
17
+
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nAtensorflow/distribute/experimental/rpc/proto/tf_rpc_service.proto\x12\x0etensorflow.rpc\x1a&tensorflow/core/framework/tensor.proto\x1a%tensorflow/core/protobuf/struct.proto\"M\n\x0b\x43\x61llRequest\x12\x0e\n\x06method\x18\x01 \x01(\t\x12.\n\rinput_tensors\x18\x02 \x03(\x0b\x32\x17.tensorflow.TensorProto\"?\n\x0c\x43\x61llResponse\x12/\n\x0eoutput_tensors\x18\x01 \x03(\x0b\x32\x17.tensorflow.TensorProto\"\r\n\x0bListRequest\"\x87\x01\n\x10RegisteredMethod\x12\x0e\n\x06method\x18\x01 \x01(\t\x12\x30\n\x0binput_specs\x18\x02 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\x12\x31\n\x0coutput_specs\x18\x03 \x01(\x0b\x32\x1b.tensorflow.StructuredValue\"L\n\x0cListResponse\x12<\n\x12registered_methods\x18\x01 \x03(\x0b\x32 .tensorflow.rpc.RegisteredMethod2\x96\x01\n\nRpcService\x12\x43\n\x04\x43\x61ll\x12\x1b.tensorflow.rpc.CallRequest\x1a\x1c.tensorflow.rpc.CallResponse\"\x00\x12\x43\n\x04List\x12\x1b.tensorflow.rpc.ListRequest\x1a\x1c.tensorflow.rpc.ListResponse\"\x00\x62\x06proto3')
19
+
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.distribute.experimental.rpc.proto.tf_rpc_service_pb2', globals())
22
+ if _descriptor._USE_C_DESCRIPTORS == False:
23
+
24
+ DESCRIPTOR._options = None
25
+ _CALLREQUEST._serialized_start=164
26
+ _CALLREQUEST._serialized_end=241
27
+ _CALLRESPONSE._serialized_start=243
28
+ _CALLRESPONSE._serialized_end=306
29
+ _LISTREQUEST._serialized_start=308
30
+ _LISTREQUEST._serialized_end=321
31
+ _REGISTEREDMETHOD._serialized_start=324
32
+ _REGISTEREDMETHOD._serialized_end=459
33
+ _LISTRESPONSE._serialized_start=461
34
+ _LISTRESPONSE._serialized_end=537
35
+ _RPCSERVICE._serialized_start=540
36
+ _RPCSERVICE._serialized_end=690
37
+ # @@protoc_insertion_point(module_scope)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/distribute/experimental/rpc/proto/tf_rpc_service_pb2_grpc.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ import grpc
3
+
4
+ from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2
5
+
6
+
7
+ class RpcServiceStub(object):
8
+ # missing associated documentation comment in .proto file
9
+ pass
10
+
11
+ def __init__(self, channel):
12
+ """Constructor.
13
+
14
+ Args:
15
+ channel: A grpc.Channel.
16
+ """
17
+ self.Call = channel.unary_unary(
18
+ '/tensorflow.rpc.RpcService/Call',
19
+ request_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallRequest.SerializeToString,
20
+ response_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallResponse.FromString,
21
+ )
22
+ self.List = channel.unary_unary(
23
+ '/tensorflow.rpc.RpcService/List',
24
+ request_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListRequest.SerializeToString,
25
+ response_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListResponse.FromString,
26
+ )
27
+
28
+
29
+ class RpcServiceServicer(object):
30
+ # missing associated documentation comment in .proto file
31
+ pass
32
+
33
+ def Call(self, request, context):
34
+ """RPC for invoking a registered function on remote server.
35
+ """
36
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
37
+ context.set_details('Method not implemented!')
38
+ raise NotImplementedError('Method not implemented!')
39
+
40
+ def List(self, request, context):
41
+ """RPC for listing available methods in a server.
42
+ """
43
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
44
+ context.set_details('Method not implemented!')
45
+ raise NotImplementedError('Method not implemented!')
46
+
47
+
48
+ def add_RpcServiceServicer_to_server(servicer, server):
49
+ rpc_method_handlers = {
50
+ 'Call': grpc.unary_unary_rpc_method_handler(
51
+ servicer.Call,
52
+ request_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallRequest.FromString,
53
+ response_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.CallResponse.SerializeToString,
54
+ ),
55
+ 'List': grpc.unary_unary_rpc_method_handler(
56
+ servicer.List,
57
+ request_deserializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListRequest.FromString,
58
+ response_serializer=tensorflow_dot_distribute_dot_experimental_dot_rpc_dot_proto_dot_tf__rpc__service__pb2.ListResponse.SerializeToString,
59
+ ),
60
+ }
61
+ generic_handler = grpc.method_handlers_generic_handler(
62
+ 'tensorflow.rpc.RpcService', rpc_method_handlers)
63
+ server.add_generic_rpc_handlers((generic_handler,))
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (191 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/analyzer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """This tool analyzes a TensorFlow Lite graph."""
16
+
17
+ import os
18
+
19
+ # pylint: disable=g-import-not-at-top
20
+ if not os.path.splitext(__file__)[0].endswith(
21
+ os.path.join("tflite_runtime", "analyzer")):
22
+ # This file is part of tensorflow package.
23
+ from tensorflow.compiler.mlir.lite.python import wrap_converter
24
+ from tensorflow.lite.python.analyzer_wrapper import _pywrap_analyzer_wrapper as _analyzer_wrapper
25
+ from tensorflow.python.util.tf_export import tf_export as _tf_export
26
+ else:
27
+ # This file is part of tflite_runtime package.
28
+ from tflite_runtime import _pywrap_analyzer_wrapper as _analyzer_wrapper
29
+
30
+ def _tf_export(*x, **kwargs):
31
+ del x, kwargs
32
+ return lambda x: x
33
+
34
+
35
+ @_tf_export("lite.experimental.Analyzer")
36
+ class ModelAnalyzer():
37
+ """Provides a collection of TFLite model analyzer tools.
38
+
39
+ Example:
40
+
41
+ ```python
42
+ model = tf.keras.applications.MobileNetV3Large()
43
+ fb_model = tf.lite.TFLiteConverterV2.from_keras_model(model).convert()
44
+ tf.lite.experimental.Analyzer.analyze(model_content=fb_model)
45
+ # === TFLite ModelAnalyzer ===
46
+ #
47
+ # Your TFLite model has ‘1’ subgraph(s). In the subgraph description below,
48
+ # T# represents the Tensor numbers. For example, in Subgraph#0, the MUL op
49
+ # takes tensor #0 and tensor #19 as input and produces tensor #136 as output.
50
+ #
51
+ # Subgraph#0 main(T#0) -> [T#263]
52
+ # Op#0 MUL(T#0, T#19) -> [T#136]
53
+ # Op#1 ADD(T#136, T#18) -> [T#137]
54
+ # Op#2 CONV_2D(T#137, T#44, T#93) -> [T#138]
55
+ # Op#3 HARD_SWISH(T#138) -> [T#139]
56
+ # Op#4 DEPTHWISE_CONV_2D(T#139, T#94, T#24) -> [T#140]
57
+ # ...
58
+ ```
59
+
60
+ WARNING: Experimental interface, subject to change.
61
+ """
62
+
63
+ @staticmethod
64
+ def analyze(model_path=None,
65
+ model_content=None,
66
+ gpu_compatibility=False,
67
+ **kwargs):
68
+ """Analyzes the given tflite_model with dumping model structure.
69
+
70
+ This tool provides a way to understand users' TFLite flatbuffer model by
71
+ dumping internal graph structure. It also provides additional features
72
+ like checking GPU delegate compatibility.
73
+
74
+ WARNING: Experimental interface, subject to change.
75
+ The output format is not guaranteed to stay stable, so don't
76
+ write scripts to this.
77
+
78
+ Args:
79
+ model_path: TFLite flatbuffer model path.
80
+ model_content: TFLite flatbuffer model object.
81
+ gpu_compatibility: Whether to check GPU delegate compatibility.
82
+ **kwargs: Experimental keyword arguments to analyze API.
83
+
84
+ Returns:
85
+ Print analyzed report via console output.
86
+ """
87
+ if not model_path and not model_content:
88
+ raise ValueError("neither `model_path` nor `model_content` is provided")
89
+ if model_path:
90
+ print(f"=== {model_path} ===\n")
91
+ tflite_model = model_path
92
+ input_is_filepath = True
93
+ else:
94
+ print("=== TFLite ModelAnalyzer ===\n")
95
+ tflite_model = model_content
96
+ input_is_filepath = False
97
+
98
+ if kwargs.get("experimental_use_mlir", False):
99
+ print(
100
+ wrap_converter.wrapped_flat_buffer_file_to_mlir(
101
+ tflite_model, input_is_filepath
102
+ )
103
+ )
104
+ else:
105
+ print(
106
+ _analyzer_wrapper.ModelAnalyzer(tflite_model, input_is_filepath,
107
+ gpu_compatibility))
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/conversion_metadata_schema_py_generated.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flatbuffers
2
+
3
+ # automatically generated by the FlatBuffers compiler, do not modify
4
+
5
+ # namespace: tflite
6
+
7
+ from flatbuffers.compat import import_numpy
8
+ np = import_numpy()
9
+
10
+ class ModelType(object):
11
+ NONE = 0
12
+ TF_SAVED_MODEL = 1
13
+ KERAS_MODEL = 2
14
+ TF_CONCRETE_FUNCTIONS = 3
15
+ TF_GRAPH_DEF = 4
16
+ TF_SESSION = 5
17
+ JAX = 6
18
+ PYTORCH = 7
19
+
20
+
21
+ class ModelOptimizationMode(object):
22
+ PTQ_FLOAT16 = 1001
23
+ PTQ_DYNAMIC_RANGE = 1002
24
+ PTQ_FULL_INTEGER = 1003
25
+ PTQ_INT16 = 1004
26
+ QUANTIZATION_AWARE_TRAINING = 2000
27
+ RANDOM_SPARSITY = 3001
28
+ BLOCK_SPARSITY = 3002
29
+ STRUCTURED_SPARSITY = 3003
30
+
31
+
32
+ class Environment(object):
33
+ __slots__ = ['_tab']
34
+
35
+ @classmethod
36
+ def GetRootAs(cls, buf, offset=0):
37
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
38
+ x = Environment()
39
+ x.Init(buf, n + offset)
40
+ return x
41
+
42
+ @classmethod
43
+ def GetRootAsEnvironment(cls, buf, offset=0):
44
+ """This method is deprecated. Please switch to GetRootAs."""
45
+ return cls.GetRootAs(buf, offset)
46
+ # Environment
47
+ def Init(self, buf, pos):
48
+ self._tab = flatbuffers.table.Table(buf, pos)
49
+
50
+ # Environment
51
+ def TensorflowVersion(self):
52
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
53
+ if o != 0:
54
+ return self._tab.String(o + self._tab.Pos)
55
+ return None
56
+
57
+ # Environment
58
+ def ApiVersion(self):
59
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
60
+ if o != 0:
61
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
62
+ return 0
63
+
64
+ # Environment
65
+ def ModelType(self):
66
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
67
+ if o != 0:
68
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
69
+ return 0
70
+
71
+ # Environment
72
+ def ModelHash(self):
73
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
74
+ if o != 0:
75
+ return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
76
+ return 0
77
+
78
+ def EnvironmentStart(builder):
79
+ builder.StartObject(4)
80
+
81
+ def EnvironmentAddTensorflowVersion(builder, tensorflowVersion):
82
+ builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(tensorflowVersion), 0)
83
+
84
+ def EnvironmentAddApiVersion(builder, apiVersion):
85
+ builder.PrependUint32Slot(1, apiVersion, 0)
86
+
87
+ def EnvironmentAddModelType(builder, modelType):
88
+ builder.PrependInt32Slot(2, modelType, 0)
89
+
90
+ def EnvironmentAddModelHash(builder, modelHash):
91
+ builder.PrependUint64Slot(3, modelHash, 0)
92
+
93
+ def EnvironmentEnd(builder):
94
+ return builder.EndObject()
95
+
96
+
97
+
98
+ class EnvironmentT(object):
99
+
100
+ # EnvironmentT
101
+ def __init__(self):
102
+ self.tensorflowVersion = None # type: str
103
+ self.apiVersion = 0 # type: int
104
+ self.modelType = 0 # type: int
105
+ self.modelHash = 0 # type: int
106
+
107
+ @classmethod
108
+ def InitFromBuf(cls, buf, pos):
109
+ environment = Environment()
110
+ environment.Init(buf, pos)
111
+ return cls.InitFromObj(environment)
112
+
113
+ @classmethod
114
+ def InitFromPackedBuf(cls, buf, pos=0):
115
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
116
+ return cls.InitFromBuf(buf, pos+n)
117
+
118
+ @classmethod
119
+ def InitFromObj(cls, environment):
120
+ x = EnvironmentT()
121
+ x._UnPack(environment)
122
+ return x
123
+
124
+ # EnvironmentT
125
+ def _UnPack(self, environment):
126
+ if environment is None:
127
+ return
128
+ self.tensorflowVersion = environment.TensorflowVersion()
129
+ self.apiVersion = environment.ApiVersion()
130
+ self.modelType = environment.ModelType()
131
+ self.modelHash = environment.ModelHash()
132
+
133
+ # EnvironmentT
134
+ def Pack(self, builder):
135
+ if self.tensorflowVersion is not None:
136
+ tensorflowVersion = builder.CreateString(self.tensorflowVersion)
137
+ EnvironmentStart(builder)
138
+ if self.tensorflowVersion is not None:
139
+ EnvironmentAddTensorflowVersion(builder, tensorflowVersion)
140
+ EnvironmentAddApiVersion(builder, self.apiVersion)
141
+ EnvironmentAddModelType(builder, self.modelType)
142
+ EnvironmentAddModelHash(builder, self.modelHash)
143
+ environment = EnvironmentEnd(builder)
144
+ return environment
145
+
146
+
147
+ class SparsityBlockSize(object):
148
+ __slots__ = ['_tab']
149
+
150
+ @classmethod
151
+ def GetRootAs(cls, buf, offset=0):
152
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
153
+ x = SparsityBlockSize()
154
+ x.Init(buf, n + offset)
155
+ return x
156
+
157
+ @classmethod
158
+ def GetRootAsSparsityBlockSize(cls, buf, offset=0):
159
+ """This method is deprecated. Please switch to GetRootAs."""
160
+ return cls.GetRootAs(buf, offset)
161
+ # SparsityBlockSize
162
+ def Init(self, buf, pos):
163
+ self._tab = flatbuffers.table.Table(buf, pos)
164
+
165
+ # SparsityBlockSize
166
+ def Values(self, j):
167
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
168
+ if o != 0:
169
+ a = self._tab.Vector(o)
170
+ return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
171
+ return 0
172
+
173
+ # SparsityBlockSize
174
+ def ValuesAsNumpy(self):
175
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
176
+ if o != 0:
177
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
178
+ return 0
179
+
180
+ # SparsityBlockSize
181
+ def ValuesLength(self):
182
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
183
+ if o != 0:
184
+ return self._tab.VectorLen(o)
185
+ return 0
186
+
187
+ # SparsityBlockSize
188
+ def ValuesIsNone(self):
189
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
190
+ return o == 0
191
+
192
+ def SparsityBlockSizeStart(builder):
193
+ builder.StartObject(1)
194
+
195
+ def SparsityBlockSizeAddValues(builder, values):
196
+ builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0)
197
+
198
+ def SparsityBlockSizeStartValuesVector(builder, numElems):
199
+ return builder.StartVector(4, numElems, 4)
200
+
201
+ def SparsityBlockSizeEnd(builder):
202
+ return builder.EndObject()
203
+
204
+
205
+ try:
206
+ from typing import List
207
+ except:
208
+ pass
209
+
210
+ class SparsityBlockSizeT(object):
211
+
212
+ # SparsityBlockSizeT
213
+ def __init__(self):
214
+ self.values = None # type: List[int]
215
+
216
+ @classmethod
217
+ def InitFromBuf(cls, buf, pos):
218
+ sparsityBlockSize = SparsityBlockSize()
219
+ sparsityBlockSize.Init(buf, pos)
220
+ return cls.InitFromObj(sparsityBlockSize)
221
+
222
+ @classmethod
223
+ def InitFromPackedBuf(cls, buf, pos=0):
224
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
225
+ return cls.InitFromBuf(buf, pos+n)
226
+
227
+ @classmethod
228
+ def InitFromObj(cls, sparsityBlockSize):
229
+ x = SparsityBlockSizeT()
230
+ x._UnPack(sparsityBlockSize)
231
+ return x
232
+
233
+ # SparsityBlockSizeT
234
+ def _UnPack(self, sparsityBlockSize):
235
+ if sparsityBlockSize is None:
236
+ return
237
+ if not sparsityBlockSize.ValuesIsNone():
238
+ if np is None:
239
+ self.values = []
240
+ for i in range(sparsityBlockSize.ValuesLength()):
241
+ self.values.append(sparsityBlockSize.Values(i))
242
+ else:
243
+ self.values = sparsityBlockSize.ValuesAsNumpy()
244
+
245
+ # SparsityBlockSizeT
246
+ def Pack(self, builder):
247
+ if self.values is not None:
248
+ if np is not None and type(self.values) is np.ndarray:
249
+ values = builder.CreateNumpyVector(self.values)
250
+ else:
251
+ SparsityBlockSizeStartValuesVector(builder, len(self.values))
252
+ for i in reversed(range(len(self.values))):
253
+ builder.PrependUint32(self.values[i])
254
+ values = builder.EndVector()
255
+ SparsityBlockSizeStart(builder)
256
+ if self.values is not None:
257
+ SparsityBlockSizeAddValues(builder, values)
258
+ sparsityBlockSize = SparsityBlockSizeEnd(builder)
259
+ return sparsityBlockSize
260
+
261
+
262
+ class ConversionOptions(object):
263
+ __slots__ = ['_tab']
264
+
265
+ @classmethod
266
+ def GetRootAs(cls, buf, offset=0):
267
+ n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
268
+ x = ConversionOptions()
269
+ x.Init(buf, n + offset)
270
+ return x
271
+
272
+ @classmethod
273
+ def GetRootAsConversionOptions(cls, buf, offset=0):
274
+ """This method is deprecated. Please switch to GetRootAs."""
275
+ return cls.GetRootAs(buf, offset)
276
+ # ConversionOptions
277
+ def Init(self, buf, pos):
278
+ self._tab = flatbuffers.table.Table(buf, pos)
279
+
280
+ # ConversionOptions
281
+ def ModelOptimizationModes(self, j):
282
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
283
+ if o != 0:
284
+ a = self._tab.Vector(o)
285
+ return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
286
+ return 0
287
+
288
+ # ConversionOptions
289
+ def ModelOptimizationModesAsNumpy(self):
290
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
291
+ if o != 0:
292
+ return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
293
+ return 0
294
+
295
+ # ConversionOptions
296
+ def ModelOptimizationModesLength(self):
297
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
298
+ if o != 0:
299
+ return self._tab.VectorLen(o)
300
+ return 0
301
+
302
+ # ConversionOptions
303
+ def ModelOptimizationModesIsNone(self):
304
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
305
+ return o == 0
306
+
307
+ # ConversionOptions
308
+ def AllowCustomOps(self):
309
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
310
+ if o != 0:
311
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
312
+ return False
313
+
314
+ # ConversionOptions
315
+ def EnableSelectTfOps(self):
316
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
317
+ if o != 0:
318
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
319
+ return False
320
+
321
+ # ConversionOptions
322
+ def ForceSelectTfOps(self):
323
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
324
+ if o != 0:
325
+ return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
326
+ return False
327
+
328
+ # ConversionOptions
329
+ def SparsityBlockSizes(self, j):
330
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
331
+ if o != 0:
332
+ x = self._tab.Vector(o)
333
+ x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
334
+ x = self._tab.Indirect(x)
335
+ obj = SparsityBlockSize()
336
+ obj.Init(self._tab.Bytes, x)
337
+ return obj
338
+ return None
339
+
340
+ # ConversionOptions
341
+ def SparsityBlockSizesLength(self):
342
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
343
+ if o != 0:
344
+ return self._tab.VectorLen(o)
345
+ return 0
346
+
347
+ # ConversionOptions
348
+ def SparsityBlockSizesIsNone(self):
349
+ o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
350
+ return o == 0
351
+
352
def ConversionOptionsStart(builder):
    # Begin a ConversionOptions table; the schema declares 5 vtable slots.
    builder.StartObject(5)

def ConversionOptionsAddModelOptimizationModes(builder, modelOptimizationModes):
    # Slot 0: offset to the [int32] model_optimization_modes vector.
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(modelOptimizationModes), 0)

def ConversionOptionsStartModelOptimizationModesVector(builder, numElems):
    # int32 elements: 4 bytes each, 4-byte aligned.
    return builder.StartVector(4, numElems, 4)

def ConversionOptionsAddAllowCustomOps(builder, allowCustomOps):
    # Slot 1: bool allow_custom_ops, default False.
    builder.PrependBoolSlot(1, allowCustomOps, 0)

def ConversionOptionsAddEnableSelectTfOps(builder, enableSelectTfOps):
    # Slot 2: bool enable_select_tf_ops, default False.
    builder.PrependBoolSlot(2, enableSelectTfOps, 0)

def ConversionOptionsAddForceSelectTfOps(builder, forceSelectTfOps):
    # Slot 3: bool force_select_tf_ops, default False.
    builder.PrependBoolSlot(3, forceSelectTfOps, 0)

def ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes):
    # Slot 4: offset to the vector of SparsityBlockSize sub-tables.
    builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(sparsityBlockSizes), 0)

def ConversionOptionsStartSparsityBlockSizesVector(builder, numElems):
    # Table-offset elements: 4 bytes each, 4-byte aligned.
    return builder.StartVector(4, numElems, 4)

def ConversionOptionsEnd(builder):
    # Finish the table and return its offset within the builder.
    return builder.EndObject()
378
+
379
+
380
try:
    # `List` is used only in "# type:" comments below; tolerate its absence.
    from typing import List
except:
    pass
384
+
385
class ConversionOptionsT(object):
    """Plain-Python (object API) mirror of the ConversionOptions table."""

    # ConversionOptionsT
    def __init__(self):
        self.modelOptimizationModes = None  # type: List[int]
        self.allowCustomOps = False  # type: bool
        self.enableSelectTfOps = False  # type: bool
        self.forceSelectTfOps = False  # type: bool
        self.sparsityBlockSizes = None  # type: List[SparsityBlockSizeT]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        """Unpack from a buffer whose table starts at `pos`."""
        table = ConversionOptions()
        table.Init(buf, pos)
        return cls.InitFromObj(table)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        """Unpack from a buffer that begins with a root uoffset."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + root)

    @classmethod
    def InitFromObj(cls, conversionOptions):
        """Build a ConversionOptionsT from a table accessor object."""
        unpacked = ConversionOptionsT()
        unpacked._UnPack(conversionOptions)
        return unpacked

    # ConversionOptionsT
    def _UnPack(self, conversionOptions):
        """Copy every field out of the flatbuffer accessor into plain attributes."""
        if conversionOptions is None:
            return
        if not conversionOptions.ModelOptimizationModesIsNone():
            if np is None:
                self.modelOptimizationModes = [
                    conversionOptions.ModelOptimizationModes(i)
                    for i in range(conversionOptions.ModelOptimizationModesLength())
                ]
            else:
                self.modelOptimizationModes = (
                    conversionOptions.ModelOptimizationModesAsNumpy())
        self.allowCustomOps = conversionOptions.AllowCustomOps()
        self.enableSelectTfOps = conversionOptions.EnableSelectTfOps()
        self.forceSelectTfOps = conversionOptions.ForceSelectTfOps()
        if not conversionOptions.SparsityBlockSizesIsNone():
            self.sparsityBlockSizes = []
            for i in range(conversionOptions.SparsityBlockSizesLength()):
                entry = conversionOptions.SparsityBlockSizes(i)
                if entry is None:
                    self.sparsityBlockSizes.append(None)
                else:
                    self.sparsityBlockSizes.append(
                        SparsityBlockSizeT.InitFromObj(entry))

    # ConversionOptionsT
    def Pack(self, builder):
        """Serialize this object into `builder`; returns the table offset."""
        # Vectors must be finished before the enclosing table is started.
        if self.modelOptimizationModes is not None:
            if np is not None and type(self.modelOptimizationModes) is np.ndarray:
                modelOptimizationModes = builder.CreateNumpyVector(
                    self.modelOptimizationModes)
            else:
                ConversionOptionsStartModelOptimizationModesVector(
                    builder, len(self.modelOptimizationModes))
                for i in reversed(range(len(self.modelOptimizationModes))):
                    builder.PrependInt32(self.modelOptimizationModes[i])
                modelOptimizationModes = builder.EndVector()
        if self.sparsityBlockSizes is not None:
            packed_block_sizes = [entry.Pack(builder)
                                  for entry in self.sparsityBlockSizes]
            ConversionOptionsStartSparsityBlockSizesVector(
                builder, len(self.sparsityBlockSizes))
            for off in reversed(packed_block_sizes):
                builder.PrependUOffsetTRelative(off)
            sparsityBlockSizes = builder.EndVector()
        ConversionOptionsStart(builder)
        if self.modelOptimizationModes is not None:
            ConversionOptionsAddModelOptimizationModes(builder,
                                                       modelOptimizationModes)
        ConversionOptionsAddAllowCustomOps(builder, self.allowCustomOps)
        ConversionOptionsAddEnableSelectTfOps(builder, self.enableSelectTfOps)
        ConversionOptionsAddForceSelectTfOps(builder, self.forceSelectTfOps)
        if self.sparsityBlockSizes is not None:
            ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes)
        return ConversionOptionsEnd(builder)
463
+
464
+
465
class ConversionMetadata(object):
    """Flatbuffer table accessor for the ConversionMetadata root table."""
    __slots__ = ['_tab']

    @classmethod
    def GetRootAs(cls, buf, offset=0):
        """Return a ConversionMetadata view over the root table of `buf`."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        view = ConversionMetadata()
        view.Init(buf, offset + root)
        return view

    @classmethod
    def GetRootAsConversionMetadata(cls, buf, offset=0):
        """This method is deprecated. Please switch to GetRootAs."""
        return cls.GetRootAs(buf, offset)

    # ConversionMetadata
    def Init(self, buf, pos):
        """Wrap `buf` in a flatbuffers table view positioned at `pos`."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # ConversionMetadata
    def Environment(self):
        """Return the Environment sub-table, or None when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        env = Environment()
        env.Init(self._tab.Bytes, target)
        return env

    # ConversionMetadata
    def Options(self):
        """Return the ConversionOptions sub-table, or None when absent."""
        field = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if field == 0:
            return None
        target = self._tab.Indirect(field + self._tab.Pos)
        opts = ConversionOptions()
        opts.Init(self._tab.Bytes, target)
        return opts
502
+
503
def ConversionMetadataStart(builder):
    # Begin a ConversionMetadata table; the schema declares 2 vtable slots.
    builder.StartObject(2)

def ConversionMetadataAddEnvironment(builder, environment):
    # Slot 0: offset to the Environment sub-table.
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(environment), 0)

def ConversionMetadataAddOptions(builder, options):
    # Slot 1: offset to the ConversionOptions sub-table.
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(options), 0)

def ConversionMetadataEnd(builder):
    # Finish the table and return its offset within the builder.
    return builder.EndObject()
514
+
515
+
516
try:
    # `Optional` is used only in "# type:" comments below; tolerate its absence.
    from typing import Optional
except:
    pass
520
+
521
class ConversionMetadataT(object):
    """Plain-Python (object API) mirror of the ConversionMetadata table."""

    # ConversionMetadataT
    def __init__(self):
        self.environment = None  # type: Optional[EnvironmentT]
        self.options = None  # type: Optional[ConversionOptionsT]

    @classmethod
    def InitFromBuf(cls, buf, pos):
        """Unpack from a buffer whose table starts at `pos`."""
        table = ConversionMetadata()
        table.Init(buf, pos)
        return cls.InitFromObj(table)

    @classmethod
    def InitFromPackedBuf(cls, buf, pos=0):
        """Unpack from a buffer that begins with a root uoffset."""
        root = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
        return cls.InitFromBuf(buf, pos + root)

    @classmethod
    def InitFromObj(cls, conversionMetadata):
        """Build a ConversionMetadataT from a table accessor object."""
        unpacked = ConversionMetadataT()
        unpacked._UnPack(conversionMetadata)
        return unpacked

    # ConversionMetadataT
    def _UnPack(self, conversionMetadata):
        """Copy the two sub-tables out of the flatbuffer accessor."""
        if conversionMetadata is None:
            return
        env = conversionMetadata.Environment()
        if env is not None:
            self.environment = EnvironmentT.InitFromObj(env)
        opts = conversionMetadata.Options()
        if opts is not None:
            self.options = ConversionOptionsT.InitFromObj(opts)

    # ConversionMetadataT
    def Pack(self, builder):
        """Serialize this object into `builder`; returns the table offset."""
        # Sub-tables must be finished before the enclosing table is started.
        if self.environment is not None:
            environment = self.environment.Pack(builder)
        if self.options is not None:
            options = self.options.Pack(builder)
        ConversionMetadataStart(builder)
        if self.environment is not None:
            ConversionMetadataAddEnvironment(builder, environment)
        if self.options is not None:
            ConversionMetadataAddOptions(builder, options)
        return ConversionMetadataEnd(builder)
567
+
568
+
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/convert_phase.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utilities for collecting TFLite metrics."""
16
+
17
+ import collections
18
+ import enum
19
+ import functools
20
+ from typing import Text
21
+
22
+ from tensorflow.compiler.mlir.lite.metrics import converter_error_data_pb2
23
+ from tensorflow.lite.python.metrics import metrics
24
+
25
+
26
class Component(enum.Enum):
  """Names of the high-level converter components."""

  # Validates the given input and prepares/optimizes the TensorFlow model.
  PREPARE_TF_MODEL = "PREPARE_TF_MODEL"

  # Converts the TensorFlow model to the TFLite flatbuffer format.
  CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"

  # Runs post-conversion optimizations such as quantization and
  # sparsification.
  OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
36
+
37
+
38
# (name, component) pair carried by every SubComponent enum member.
SubComponentItem = collections.namedtuple("SubComponentItem",
                                          "name component")
40
+
41
+
42
class SubComponent(SubComponentItem, enum.Enum):
  """Enum class defining name of the converter subcomponents.

  This enum only defines the subcomponents in Python, there might be more
  subcomponents defined in C++.
  """

  def __str__(self):
    # Render as the bare subcomponent name rather than "SubComponent.X".
    return self.value.name

  @property
  def name(self):
    # Shadows Enum.name so `.name` yields the tuple's name field.
    return self.value.name

  @property
  def component(self):
    # The Component this subcomponent belongs to (None for UNSPECIFIED).
    return self.value.component

  # The subcomponent name is unspecified.
  UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)

  # Validate the given input and parameters.
  VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
                                     Component.PREPARE_TF_MODEL)

  # Load GraphDef from SavedModel.
  LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
                                      Component.PREPARE_TF_MODEL)

  # Convert a SavedModel to frozen graph.
  FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Save a Keras model to SavedModel.
  CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Save Concrete functions to SavedModel.
  CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
      "CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)

  # Convert a Keras model to a frozen graph.
  FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
                                        Component.PREPARE_TF_MODEL)

  # Replace all the variables with constants in a ConcreteFunction.
  FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
                                              Component.PREPARE_TF_MODEL)

  # Run grappler optimization.
  OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
                                       Component.PREPARE_TF_MODEL)

  # Convert using the old TOCO converter.
  CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
      "CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a GraphDef to TFLite model.
  CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
                                      Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a SavedModel to TFLite model.
  CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
                                         Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Convert a Jax HLO to TFLite model.
  CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
                                     Component.CONVERT_TF_TO_TFLITE_MODEL)

  # Do quantization by the deprecated quantizer.
  QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
      "QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)

  # Do calibration.
  CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)

  # Do quantization by MLIR.
  QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)

  # Do sparsification by MLIR.
  SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)
124
+
125
+
126
class ConverterError(Exception):
  """Raised when an error occurs during model conversion."""

  def __init__(self, message):
    super().__init__(message)
    # Structured ConverterErrorData protos attached to this exception.
    self.errors = []
    self._parse_error_message(message)

  def append_error(self,
                   error_data: converter_error_data_pb2.ConverterErrorData):
    """Attach one structured error record to this exception."""
    self.errors.append(error_data)

  def _parse_error_message(self, message):
    """If the message matches a known pattern, attach the matching error code.

    It is difficult to assign an error code to some errors on the MLIR side,
    e.g. errors thrown by components other than TFLite or not raised through
    mlir::emitError. This method recognizes such errors by message substring
    and attaches the corresponding error code.

    Args:
      message: The error message of this exception.
    """
    known_patterns = {
        "Failed to functionalize Control Flow V1 ops. Consider using Control "
        "Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
        "tf/compat/v1/enable_control_flow_v2.":
            converter_error_data_pb2.ConverterErrorData
            .ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
    }
    for pattern, code in known_patterns.items():
      if pattern not in message:
        continue
      error_data = converter_error_data_pb2.ConverterErrorData()
      error_data.error_message = message
      error_data.error_code = code
      self.append_error(error_data)
      return
163
+
164
+
165
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
  """The decorator to identify converter component and subcomponent.

  Args:
    component: Converter component name.
    subcomponent: Converter subcomponent name.

  Returns:
    Forward the result from the wrapped function.

  Raises:
    ValueError: if component and subcomponent name is not valid.
  """
  # Validate the (component, subcomponent) pairing once, at decoration time.
  if component not in Component:
    raise ValueError("Given component name not found")
  if subcomponent not in SubComponent:
    raise ValueError("Given subcomponent name not found")
  if (subcomponent != SubComponent.UNSPECIFIED and
      subcomponent.component != component):
    raise ValueError("component and subcomponent name don't match")

  def _export_error(error_data: converter_error_data_pb2.ConverterErrorData):
    # Always overwrite the component; only fill in the subcomponent when the
    # error record did not already carry one.
    error_data.component = component.value
    if not error_data.subcomponent:
      error_data.subcomponent = subcomponent.name
    metrics.TFLiteConverterMetrics().set_converter_error(error_data)

  def _export_error_message(error_message: Text):
    # Wrap a bare message in a ConverterErrorData proto and export it.
    error_data = converter_error_data_pb2.ConverterErrorData()
    error_data.error_message = error_message
    _export_error(error_data)

  def actual_decorator(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      try:
        return func(*args, **kwargs)
      except ConverterError as converter_error:
        # Prefer the structured errors already attached to the exception.
        if converter_error.errors:
          for error_data in converter_error.errors:
            _export_error(error_data)
        else:
          _export_error_message(str(converter_error))
        raise converter_error from None  # Re-throws the exception.
      except Exception as error:
        _export_error_message(str(error))
        raise error from None  # Re-throws the exception.

    return wrapper

  return actual_decorator
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/schema_util.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Schema utilities to get builtin code from operator code."""
16
+
17
+ from tensorflow.python.util import all_util
18
+
19
+
20
def get_builtin_code_from_operator_code(opcode):
  """Return the builtin code of the given operator code.

  The following method is introduced to resolve op builtin code shortage
  problem. The new builtin operator will be assigned to the extended builtin
  code field in the flatbuffer schema. Those methods helps to hide builtin
  code details.

  Args:
    opcode: Operator code.

  Returns:
    The builtin code of the given operator code.
  """
  # Flatbuffer table accessors expose BuiltinCode() as a method; object-API
  # instances expose plain attributes. Prefer the method form when present.
  accessor = getattr(opcode, 'BuiltinCode', None)
  if callable(accessor):
    return max(accessor(), opcode.DeprecatedBuiltinCode())

  return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
39
+
40
+
41
# Public API of this module; everything else is stripped below.
_allowed_symbols = [
    'get_builtin_code_from_operator_code',
]

# Hide all names not listed in _allowed_symbols from the documented API.
all_util.remove_undocumented(__name__, _allowed_symbols)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/tflite_convert.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Python command line interface for converting TF models to TFLite models."""
16
+
17
+ import argparse
18
+ import os
19
+ import sys
20
+ import warnings
21
+
22
+ from absl import app
23
+ import tensorflow as tf
24
+
25
+ from tensorflow.lite.python import lite
26
+ from tensorflow.lite.python.convert import register_custom_opdefs
27
+ from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
28
+ from tensorflow.lite.toco.logging import gen_html
29
+ from tensorflow.python import tf2
30
+ from tensorflow.python.framework import dtypes
31
+ from tensorflow.python.platform import gfile
32
+ from tensorflow.python.util import keras_deps
33
+
34
+ # Needed to enable TF2 by default.
35
+
36
+ _ = tf.keras.models.save_model # ensure necessary imports are executed
37
+
38
+
39
def _parse_array(values, type_fn=str):
  """Split a comma-separated string into a list of `type_fn` values.

  Empty items are dropped; `None` input passes through as `None`.
  """
  if values is None:
    return None
  return [type_fn(item) for item in values.split(",") if item]
43
+
44
+
45
def _parse_set(values):
  """Split a comma-separated string into a set of non-empty items.

  `None` input passes through as `None`.
  """
  if values is None:
    return None
  return {token for token in values.split(",") if token}
49
+
50
+
51
def _parse_inference_type(value, flag):
  """Converts the inference type to the value of the constant.

  Args:
    value: str representing the inference type.
    flag: str representing the flag name.

  Returns:
    tf.dtype.

  Raises:
    ValueError: Unsupported value.
  """
  # "QUANTIZED_UINT8" is the legacy spelling of "UINT8".
  dtype_by_name = {
      "FLOAT": dtypes.float32,
      "INT8": dtypes.int8,
      "UINT8": dtypes.uint8,
      "QUANTIZED_UINT8": dtypes.uint8,
  }
  if value in dtype_by_name:
    return dtype_by_name[value]
  raise ValueError(
      "Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
      "QUANTIZED_UINT8 instead got {}.".format(flag, value))
73
+
74
+
75
class _ParseBooleanFlag(argparse.Action):
  """Helper class to parse boolean flag that optionally accepts truth value."""

  def __init__(self, option_strings, dest, nargs=None, **kwargs):
    # This action is only ever registered with nargs="?" (0 or 1 extra args);
    # reject anything else loudly rather than mis-parsing.
    if nargs != "?":
      raise ValueError(
          "This parser only supports nargs='?' (0 or 1 additional arguments)")
    super(_ParseBooleanFlag, self).__init__(
        option_strings, dest, nargs=nargs, **kwargs)

  def __call__(self, parser, namespace, values, option_string=None):
    if values is None:
      # Bare `--boolean_flag` implies true.
      flag_value = True
    else:
      # `--boolean_flag=true/false`, case-insensitive after the equal sign.
      normalized = values.lower()
      if normalized == "true":
        flag_value = True
      elif normalized == "false":
        flag_value = False
      else:
        raise ValueError("Invalid argument to --{}. Must use flag alone,"
                         " or specify true/false.".format(self.dest))
    setattr(namespace, self.dest, flag_value)
104
+
105
+
106
def _get_tflite_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  input_shapes = None
  if flags.input_shapes:
    # Shapes are colon-separated per input, comma-separated per dimension.
    shape_lists = [
        _parse_array(shape, type_fn=int)
        for shape in flags.input_shapes.split(":")
    ]
    input_shapes = dict(zip(input_arrays, shape_lists))
  output_arrays = _parse_array(flags.output_arrays)

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays,
  }

  # Pick the factory that matches the supplied model format.
  if flags.graph_def_file:
    converter_fn = lite.TFLiteConverter.from_frozen_graph
    converter_kwargs["graph_def_file"] = flags.graph_def_file
  elif flags.saved_model_dir:
    converter_fn = lite.TFLiteConverter.from_saved_model
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
  elif flags.keras_model_file:
    converter_fn = lite.TFLiteConverter.from_keras_model_file
    converter_kwargs["model_file"] = flags.keras_model_file
  else:
    raise ValueError("--graph_def_file, --saved_model_dir, or "
                     "--keras_model_file must be specified.")

  return converter_fn(**converter_kwargs)
152
+
153
+
154
def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Register custom opdefs before converter object creation.
  if flags.custom_opdefs:
    register_custom_opdefs(_parse_array(flags.custom_opdefs))

  # Create converter.
  converter = _get_tflite_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == dtypes.float32:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops

  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in flags.target_ops.split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.experimental_select_user_tf_ops:
    if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
      raise ValueError("--experimental_select_user_tf_ops can only be set if "
                       "--target_ops contains SELECT_TF_OPS.")
    user_op_set = set()
    for op_name in flags.experimental_select_user_tf_ops.split(","):
      user_op_set.add(op_name)
    converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type != dtypes.float32:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference_type to FLOAT.")
      converter.inference_type = dtypes.float32

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [dtypes.float16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    # BUG FIX: previously assigned to a misspelled `dump_graphviz_vode`
    # attribute, which silently dropped the --dump_graphviz_video flag.
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  converter.experimental_new_converter = flags.experimental_new_converter

  if flags.experimental_new_quantizer is not None:
    converter.experimental_new_quantizer = flags.experimental_new_quantizer

  # Convert model.
  output_data = converter.convert()
  with gfile.GFile(flags.output_file, "wb") as f:
    f.write(output_data)
263
+
264
+
265
def _convert_tf2_model(flags):
  """Converts a TensorFlow 2.0 model into a TFLite model and writes it out.

  Loads either a SavedModel or a Keras HDF5 model (whichever flag is set),
  configures the V2 converter from the remaining flags, and writes the
  serialized flatbuffer to `flags.output_file`.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Unsupported file format.
  """
  # Build the converter from whichever input format was supplied. The flag
  # checks performed beforehand guarantee exactly one of these is set.
  if flags.saved_model_dir:
    converter = lite.TFLiteConverterV2.from_saved_model(
        flags.saved_model_dir,
        signature_keys=_parse_array(flags.saved_model_signature_key),
        tags=_parse_set(flags.saved_model_tag_set))
  elif flags.keras_model_file:
    keras_model = keras_deps.get_load_model_function()(flags.keras_model_file)
    converter = lite.TFLiteConverterV2.from_keras_model(keras_model)

  converter.experimental_new_converter = flags.experimental_new_converter

  # Only override the quantizer choice when the flag was explicitly given.
  new_quantizer = flags.experimental_new_quantizer
  if new_quantizer is not None:
    converter.experimental_new_quantizer = new_quantizer

  # Convert and persist the flatbuffer.
  serialized_model = converter.convert()
  with gfile.GFile(flags.output_file, "wb") as f:
    f.write(serialized_model)
293
+
294
+
295
def _check_tf1_flags(flags, unparsed):
  """Checks the parsed and unparsed flags to ensure they are valid in 1.X.

  Raises an error if previously support unparsed flags are found. Raises an
  error for parsed flags that don't meet the required conditions.

  Args:
    flags: argparse.Namespace object containing TFLite flags.
    unparsed: List of unparsed flags.

  Raises:
    ValueError: Invalid flags.
  """

  # Map of flags that existed in previous TOCO releases to their replacements;
  # used to produce a helpful hint when a stale flag shows up unparsed.
  def _hint_for_renamed_flag(flag, orig_flag, new_flag):
    if flag.startswith(orig_flag):
      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
    return ""

  if unparsed:
    renamed_flags = (
        ("--input_file", "--graph_def_file"),
        ("--savedmodel_directory", "--saved_model_dir"),
        ("--std_value", "--std_dev_values"),
        ("--batch_size", "--input_shapes"),
        ("--dump_graphviz", "--dump_graphviz_dir"),
    )
    hints = []
    for flag in unparsed:
      for old_name, new_name in renamed_flags:
        hints.append(_hint_for_renamed_flag(flag, old_name, new_name))
    output = "".join(hints)
    if output:
      raise ValueError(output)

  # A frozen GraphDef carries no signature, so the array names are mandatory.
  if flags.graph_def_file and (not flags.input_arrays or
                               not flags.output_arrays):
    raise ValueError("--input_arrays and --output_arrays are required with "
                     "--graph_def_file")

  if flags.input_shapes:
    if not flags.input_arrays:
      raise ValueError("--input_shapes must be used with --input_arrays")
    # Shapes are colon-separated, arrays comma-separated: equal separator
    # counts means equal item counts.
    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
      raise ValueError("--input_shapes and --input_arrays must have the same "
                       "number of items")

  if flags.std_dev_values or flags.mean_values:
    # Both quantization stats must be given, and one per input tensor.
    if bool(flags.std_dev_values) != bool(flags.mean_values):
      raise ValueError("--std_dev_values and --mean_values must be used "
                       "together")
    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
      raise ValueError("--std_dev_values, --mean_values must have the same "
                       "number of items")

  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
    raise ValueError("--default_ranges_min and --default_ranges_max must be "
                     "used together")

  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
    raise ValueError("--dump_graphviz_video must be used with "
                     "--dump_graphviz_dir")

  if flags.custom_opdefs and not flags.experimental_new_converter:
    raise ValueError("--custom_opdefs must be used with "
                     "--experimental_new_converter")
  if flags.custom_opdefs and not flags.allow_custom_ops:
    raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
  if (flags.experimental_select_user_tf_ops and
      not flags.experimental_new_converter):
    raise ValueError("--experimental_select_user_tf_ops must be used with "
                     "--experimental_new_converter")
366
+
367
+
368
def _check_tf2_flags(flags):
  """Checks the parsed and unparsed flags to ensure they are valid in 2.X.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Raises:
    ValueError: Invalid flags.
  """
  # The V2 converter needs exactly one model source; mutual exclusivity is
  # enforced by argparse, so only the "neither given" case is checked here.
  if not (flags.keras_model_file or flags.saved_model_dir):
    raise ValueError("one of the arguments --saved_model_dir "
                     "--keras_model_file is required")
380
+
381
+
382
def _get_tf1_flags(parser):
  """Returns ArgumentParser for tflite_convert for TensorFlow 1.X.

  Registers all 1.X converter command-line flags on `parser`: input file
  selection (mutually exclusive), model format, input/output arrays,
  SavedModel selection, quantization, graph manipulation, permitted ops,
  and logging flags.

  Args:
    parser: ArgumentParser
  """
  # Input file flags. Exactly one model source must be provided.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags. `type=str.upper` normalizes user input before the
  # `choices` check, so lowercase values are accepted.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      default="FLOAT",
      help=("Target data type of real-number arrays in the output file. "
            "Must be either FLOAT, INT8 or UINT8."))
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization. "
            "Must be either FLOAT, INT8 or UINT8."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED. It is kept as a hidden alias
  # (help=argparse.SUPPRESS) that writes into the same `post_training_quantize`
  # destination for backward compatibility.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))
  parser.add_argument(
      "--quantize_to_float16",
      dest="quantize_to_float16",
      action="store_true",
      help=("Boolean indicating whether to quantize weights to fp16 instead of "
            "the default int8 when post-training quantization "
            "(--post_training_quantize) is enabled. (default False)"))
  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags. `_ParseBooleanFlag` (defined elsewhere in this file)
  # lets these be passed either bare or as --flag=true/false.
  parser.add_argument(
      "--allow_custom_ops",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--custom_opdefs",
      type=str,
      help=("String representing a list of custom ops OpDefs delineated with "
            "commas that are included in the GraphDef. Required when using "
            "custom operations with --experimental_new_converter."))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))
  parser.add_argument(
      "--experimental_select_user_tf_ops",
      type=str,
      help=("Experimental flag, subject to change. Comma separated list of "
            "user's defined TensorFlow operators required in the runtime."))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))
  parser.add_argument(
      "--conversion_summary_dir",
      type=str,
      help=("Full filepath to store the conversion logs, which includes "
            "graphviz of the model before/after the conversion, an HTML report "
            "and the conversion proto buffers. This will only be generated "
            "when passing --experimental_new_converter"))
571
+
572
+
573
def _get_tf2_flags(parser):
  """Returns ArgumentParser for tflite_convert for TensorFlow 2.0.

  Registers the 2.0 converter command-line flags on `parser`: model source
  (SavedModel or Keras HDF5), SavedModel selection flags, and the switch
  back to the 1.X converter.

  Args:
    parser: ArgumentParser
  """
  # Input file flags. Note the group is NOT required here; the presence of a
  # model source is validated later by _check_tf2_flags so that
  # --enable_v1_converter can be handled first.
  input_file_group = parser.add_mutually_exclusive_group()
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full path of the directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")
  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Enables 1.X converter in 2.X.
  parser.add_argument(
      "--enable_v1_converter",
      action="store_true",
      help=("Enables the TensorFlow V1 converter in 2.0"))
607
+
608
+
609
def _get_parser(use_v2_converter):
  """Builds the ArgumentParser for tflite_convert.

  Args:
    use_v2_converter: Indicates which converter to return.
  Return: ArgumentParser.
  """
  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Converter."))

  # The output file is the only flag common to both converters that is
  # mandatory at parse time.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  # Register the version-specific flag set.
  register_flags = _get_tf2_flags if use_v2_converter else _get_tf1_flags
  register_flags(parser)

  # Converter-selection flags shared by both versions.
  parser.add_argument(
      "--experimental_new_converter",
      action=_ParseBooleanFlag,
      nargs="?",
      default=True,
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "conversion instead of TOCO conversion. (default True)"))

  parser.add_argument(
      "--experimental_new_quantizer",
      action=_ParseBooleanFlag,
      nargs="?",
      help=("Experimental flag, subject to change. Enables MLIR-based "
            "quantizer instead of flatbuffer conversion. (default True)"))
  return parser
646
+
647
+
648
def run_main(_):
  """Main in tflite_convert.py."""
  use_v2_converter = tf2.enabled()
  parser = _get_parser(use_v2_converter)
  flags, leftover = parser.parse_known_args(args=sys.argv[1:])

  # A TF 2.X user may explicitly ask for the 1.X converter; in that case the
  # command line must be re-parsed against the 1.X flag set.
  if tf2.enabled() and flags.enable_v1_converter:
    use_v2_converter = False
    parser = _get_parser(use_v2_converter)
    flags, leftover = parser.parse_known_args(args=sys.argv[1:])

  # Validate the flags, reporting problems in argparse's usual style.
  try:
    if use_v2_converter:
      _check_tf2_flags(flags)
    else:
      _check_tf1_flags(flags, leftover)
  except ValueError as e:
    parser.print_usage()
    program_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(program_name, str(e)))
    sys.exit(1)

  # Convert the model with whichever converter was selected.
  if use_v2_converter:
    _convert_tf2_model(flags)
    return

  try:
    _convert_tf1_model(flags)
  finally:
    # Emit the HTML conversion summary even if conversion failed, but only
    # when the new converter produced the logs it summarizes.
    if flags.conversion_summary_dir and flags.experimental_new_converter:
      gen_html.gen_conversion_log_html(flags.conversion_summary_dir,
                                       flags.post_training_quantize,
                                       flags.output_file)
    elif flags.conversion_summary_dir:
      warnings.warn(
          "Conversion summary will only be generated when enabling"
          " the new converter via --experimental_new_converter. ")
689
+
690
+
691
def main():
  """Entry point: delegates to `run_main` via absl's app runner."""
  # Pass only the program name; run_main re-reads sys.argv itself.
  program_name_only = sys.argv[:1]
  app.run(main=run_main, argv=program_name_only)


if __name__ == "__main__":
  main()
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/util.py ADDED
@@ -0,0 +1,1177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Functions used by multiple converter files."""
16
+
17
+ import copy
18
+ import datetime
19
+ import sys
20
+
21
+ from absl import logging
22
+ import flatbuffers
23
+ import numpy as np
24
+
25
+ from tensorflow.core.protobuf import config_pb2 as _config_pb2
26
+ from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
27
+ from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb
28
+ from tensorflow.lite.python import schema_py_generated as schema_fb
29
+ from tensorflow.lite.python import schema_util
30
+ from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
31
+ from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
32
+ from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
33
+ from tensorflow.lite.tools import flatbuffer_utils
34
+ from tensorflow.python.eager import function
35
+ from tensorflow.python.framework import convert_to_constants as _convert_to_constants
36
+ from tensorflow.python.framework import dtypes
37
+ from tensorflow.python.framework import error_interpolation as _error_interpolation
38
+ from tensorflow.python.grappler import tf_optimizer
39
+ from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
40
+
41
+ # The field name of conversion metadata in the flatbuffer file.
42
+ CONVERSION_METADATA_FIELD_NAME = "CONVERSION_METADATA"
43
+
44
+ # Keras functions used by TFLite
45
+ model_input_signature = _tflite_keras_util.model_input_signature
46
+ trace_model_call = _tflite_keras_util.trace_model_call
47
+ get_save_spec = _tflite_keras_util.get_save_spec
48
+
49
+ # Jax functions used by TFLite
50
+ # pylint: disable=g-import-not-at-top
51
+ # pylint: disable=unused-import
52
+ try:
53
+ from jax import jit as _jit
54
+ except ImportError:
55
+ _jit = None
56
+ # pylint: enable=g-import-not-at-top
57
+ # pylint: enable=unused-import
58
+
59
# Defined as per TFLite schema
# Maps the TFLite TensorType enum values (integers from the flatbuffer
# schema) to the corresponding TensorFlow dtypes. Enum values absent here
# (e.g. 12-15) have no supported TF counterpart in this map.
_MAP_TFLITE_ENUM_TO_TF_TYPES = {
    0: dtypes.float32,
    1: dtypes.float16,
    2: dtypes.int32,
    3: dtypes.uint8,
    4: dtypes.int64,
    5: dtypes.string,
    6: dtypes.bool,
    7: dtypes.int16,
    8: dtypes.complex64,
    9: dtypes.int8,
    10: dtypes.float64,
    11: dtypes.complex128,
    16: dtypes.uint32,
}

# Magic bytes identifying a TFLite flatbuffer file.
_TFLITE_FILE_IDENTIFIER = b"TFL3"

# For each supported quantized activation type, the set of I/O types a model
# quantized to that type may expose at its interface.
_MAP_QUANT_TO_IO_TYPES = {
    dtypes.int8: {dtypes.int8, dtypes.uint8},
    dtypes.int16: {dtypes.int16},
}
82
+
83
+
84
def _convert_tflite_enum_type_to_tf_type(tflite_enum_type):
  """Converts tflite enum type (eg: 0) to tf type (eg: tf.float32).

  Args:
    tflite_enum_type: tflite enum type (eg: 0, that corresponds to float32)

  Raises:
    ValueError: If an invalid tflite enum type is provided.

  Returns:
    tf type (eg: tf.float32)
  """
  mapped_type = _MAP_TFLITE_ENUM_TO_TF_TYPES.get(tflite_enum_type)
  if mapped_type is not None:
    return mapped_type
  # Unknown enum value: surface the full mapping to help the caller debug.
  raise ValueError(
      "Unsupported enum {}. The valid map of enum to tf types is : {}"
      .format(tflite_enum_type, _MAP_TFLITE_ENUM_TO_TF_TYPES))
102
+
103
+
104
def get_tf_type_name(tf_type):
  """Converts tf.dtype (eg: tf.float32) to str (eg: "tf.float32")."""
  # A falsy dtype (e.g. None) has no printable name.
  if not tf_type:
    return None
  return "tf.{}".format(tf_type.name)
107
+
108
+
109
def get_tensor_name(tensor):
  """Returns name of the input tensor.

  Strips a ":0" output-index suffix so names match TensorFlow's naming
  scheme; any other index (":1", ":2", ...) is kept as-is.

  Args:
    tensor: tf.Tensor

  Returns:
    str

  Raises:
    ValueError: If the tensor name contains more than one colon.
  """
  full_name = tensor.name
  colon_count = full_name.count(":")
  if colon_count > 1:
    raise ValueError("Tensor name invalid. Expect 0 or 1 colon, got {0}".format(
        colon_count))

  # Drop the ':0' suffix for the first output; keep any other index.
  base, sep, suffix = full_name.partition(":")
  if sep and suffix != "0":
    return full_name
  return base
128
+
129
+
130
def get_tensors_from_tensor_names(graph, tensor_names):
  """Gets the Tensors associated with the `tensor_names` in the provided graph.

  Args:
    graph: TensorFlow Graph.
    tensor_names: List of strings that represent names of tensors in the graph.

  Returns:
    A list of Tensor objects in the same order the names are provided.

  Raises:
    ValueError:
      tensor_names contains an invalid tensor name.
  """
  # Index every tensor produced by every operation by its canonical name.
  name_to_tensor = {}
  for operation in graph.get_operations():
    for produced in operation.values():
      name_to_tensor[get_tensor_name(produced)] = produced

  # Resolve the requested names, collecting any that do not exist.
  resolved = []
  missing = []
  for tensor_name in tensor_names:
    if not isinstance(tensor_name, str):
      raise ValueError("Invalid type for a tensor name in the provided graph. "
                       "Expected type for a tensor name is 'str', instead got "
                       "type '{}' for tensor name '{}'".format(
                           type(tensor_name), tensor_name))

    candidate = name_to_tensor.get(tensor_name)
    if candidate is None:
      missing.append(tensor_name)
    else:
      resolved.append(candidate)

  # Report every unknown name at once rather than failing on the first.
  if missing:
    raise ValueError("Invalid tensors '{}' were found.".format(
        ",".join(missing)))
  return resolved
171
+
172
+
173
def set_tensor_shapes(tensors, shapes):
  """Sets Tensor shape for each tensor if the shape is defined.

  Args:
    tensors: TensorFlow tensor.Tensor.
    shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).

  Raises:
    ValueError:
      `shapes` contains an invalid tensor.
      `shapes` contains an invalid shape for a valid tensor.
  """
  if shapes:
    tensor_names_to_tensor = {
        get_tensor_name(tensor): tensor for tensor in tensors
    }
    for name, shape in shapes.items():
      if name not in tensor_names_to_tensor:
        raise ValueError("Invalid tensor \'{}\' found in tensor shapes "
                         "map.".format(name))
      # A None shape entry means "leave this tensor's shape unchanged".
      if shape is not None:
        tensor = tensor_names_to_tensor[name]
        try:
          tensor.set_shape(shape)
        except ValueError as error:
          message = ("The shape of tensor '{0}' cannot be changed from {1} to "
                     "{2}. {3}".format(name, tensor.shape, shape, str(error)))
          # Chain the underlying incompatibility explicitly (PEP 3134) so the
          # root cause is preserved in the traceback.
          raise ValueError(message) from error
202
+
203
+
204
def get_grappler_config(optimizers_list):
  """Creates a tf.compat.v1.ConfigProto for configuring Grappler.

  Args:
    optimizers_list: List of strings that represents the list of optimizers.

  Returns:
    tf.ConfigProto.
  """
  config = _config_pb2.ConfigProto()
  # Repeated-field extend appends each optimizer name in order.
  config.graph_options.rewrite_options.optimizers.extend(optimizers_list)
  return config
218
+
219
+
220
def run_graph_optimizations(graph_def,
                            input_arrays,
                            output_arrays,
                            config,
                            graph=None):
  """Apply standard TensorFlow optimizations to the graph_def.

  Args:
    graph_def: Frozen GraphDef to be optimized.
    input_arrays: List of arrays that are considered inputs of the graph.
    output_arrays: List of arrays that are considered outputs of the graph.
    config: tf.ConfigProto.
    graph: TensorFlow Graph. Required when Eager mode is enabled. (default None)

  Returns:
    A new, optimized GraphDef.
  """
  meta_graph = _export_meta_graph(graph_def=graph_def, graph=graph)

  # Describe every input/output array (name, dtype, shape) in a SignatureDef.
  signature = _meta_graph_pb2.SignatureDef()

  def _describe_arrays(tensor_map, arrays):
    # Fill one TensorInfo entry per array in the given signature map.
    for array in arrays:
      entry = tensor_map[array.name]
      entry.name = array.name
      entry.dtype = array.dtype.as_datatype_enum
      entry.tensor_shape.CopyFrom(array.shape.as_proto())

  _describe_arrays(signature.inputs, input_arrays)
  _describe_arrays(signature.outputs, output_arrays)

  meta_graph.signature_def["not_used_key"].CopyFrom(signature)

  # We need to add a collection called 'train_op' so that grappler
  # knows what the outputs are.
  fetch_collection = _meta_graph_pb2.CollectionDef()
  fetch_collection.node_list.value.extend(
      [array.name for array in input_arrays + output_arrays])
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  return tf_optimizer.OptimizeGraph(config, meta_graph)
260
+
261
+
262
def _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                 hinted_outputs_nodes):
  """Freezes the graph and converts OpHints to stub operations.

  Requires an unfrozen graph, since variables must still be convertible to
  constants along with the hinted output nodes.
  """
  if is_frozen_graph(sess):
    raise ValueError("Try to convert op hints, needs unfrozen graph.")
  output_names = [get_tensor_name(tensor) for tensor in output_tensors]
  frozen_graph_def = _convert_to_constants.convert_variables_to_constants(
      sess, graph_def, output_names + hinted_outputs_nodes)
  return convert_op_hints_to_stubs(graph_def=frozen_graph_def)
271
+
272
+
273
def freeze_graph(sess, input_tensors, output_tensors):
  """Returns a frozen GraphDef.

  Runs a Grappler pass and freezes a graph with Variables in it. Otherwise the
  existing GraphDef is returned. The Grappler pass is only run on models that
  are frozen in order to inline the functions in the graph.
  If OpHints is present, it will try to convert the OpHint graph.

  Args:
    sess: TensorFlow Session.
    input_tensors: List of input tensors.
    output_tensors: List of output tensors (only .name is used from this).

  Returns:
    Frozen GraphDef.
  """
  # Runs a Grappler pass in order to inline any functions in the graph.
  # Asides from inlining any simple function, Grappler will also try to lower
  # while loop into switch merge representation which is undesired for Ophints,
  # so we simply remove those attributes to prevent Grappler from doing so.
  graph_def = _convert_to_constants.disable_lower_using_switch_merge(
      sess.graph_def)
  inline_config = get_grappler_config(["function"])
  graph_def = run_graph_optimizations(
      graph_def, input_tensors, output_tensors, inline_config,
      graph=sess.graph)

  # If ophints are present, just convert them.
  hinted_nodes = find_all_hinted_output_nodes(sess)
  if hinted_nodes:
    return _convert_op_hints_if_present(sess, graph_def, output_tensors,
                                        hinted_nodes)

  # An already-frozen graph needs no variable conversion.
  if is_frozen_graph(sess):
    return sess.graph_def

  output_node_names = [tensor.name.split(":")[0] for tensor in output_tensors]
  return _convert_to_constants.convert_variables_to_constants(
      sess, graph_def, output_node_names)
312
+
313
+
314
def is_frozen_graph(sess):
  """Determines if the graph is frozen.

  Determines if a graph has previously been frozen by checking for any
  operations of type Variable*. If variables are found, the graph is not
  frozen.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  has_variables = any(
      op.type.startswith("Variable") or op.type.endswith("VariableOp")
      for op in sess.graph.get_operations())
  return not has_variables
330
+
331
+
332
def build_debug_info_func(original_graph):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    original_graph: The original `Graph` containing all the op stack traces.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    if not original_graph:
      return None
    # For the given nodes, gets all the op definitions in the original graph.
    # Each entry of `original_nodes` is a (func, name) pair; a falsy `func`
    # means the node lives directly in the top-level graph, otherwise `name`
    # is looked up inside the named sub-function's graph.
    useful_ops = []
    for func, name in original_nodes:
      try:
        if not func:
          useful_ops.append((func, original_graph.get_operation_by_name(name)))
        else:
          sub_func = original_graph._get_function(func)  # pylint: disable=protected-access
          if isinstance(sub_func, function.AtomicFunction):  # pylint: disable=protected-access
            useful_ops.append(
                (func, sub_func.graph.get_operation_by_name(name)))
          else:
            # Only AtomicFunction sub-graphs carry usable op info; anything
            # else (e.g. an undecorated Python function) is skipped with a
            # hint to the user.
            sys.stderr.write(
                "Use '@tf.function' or '@defun' to decorate the function.\n")
            continue
      except KeyError:
        # New node created by graph optimizer. No stack trace from source code.
        continue
    # Convert all the op definitions to stack traces in terms of GraphDebugInfo.
    return _error_interpolation.create_graph_debug_info_def(useful_ops)

  return f
369
+
370
+
371
def convert_debug_info_func(saved_debug_info):
  """Returns a method to retrieve the `GraphDebugInfo` from the original graph.

  Args:
    saved_debug_info: The `GraphDebugInfo` containing all the debug info.

  Returns:
    A function which retrieves the stack traces from the original graph and
    converts them to a `GraphDebugInfo` for a given set of nodes.
  """

  def f(original_nodes):
    """Function to create `GraphDebugInfo` for the given `original_nodes`."""
    # The debug info was captured ahead of time, so the requested node set is
    # irrelevant: always hand back the saved proto.
    del original_nodes
    return saved_debug_info

  return f
388
+
389
+
390
def get_debug_info(nodes_to_debug_info_func, converted_graph):
  """Returns the debug info for the original nodes in the `converted_graph`.

  Args:
    nodes_to_debug_info_func: The method to collect the op debug info for the
      nodes.
    converted_graph: A `GraphDef` after optimization and transformation.

  Returns:
    `GraphDebugInfo` for all the original nodes in `converted_graph`.
  """
  if not nodes_to_debug_info_func:
    return None

  # Gather (func_name, node_name) pairs for every node in the converted
  # graph, falling back to the node's own name when no original-node
  # bookkeeping is attached.
  original_nodes = set()
  for node in converted_graph.node:
    debug_info = node.experimental_debug_info
    node_names = debug_info.original_node_names
    func_names = debug_info.original_func_names
    if not node_names:
      original_nodes.add(("", node.name))
      continue
    for i, node_name in enumerate(node_names):
      # func_names may be shorter than node_names; missing entries mean the
      # node came from the top-level graph.
      func_name = func_names[i] if i < len(func_names) else ""
      original_nodes.add((func_name, node_name))

  # Convert the collected nodes to the debug info proto object.
  return nodes_to_debug_info_func(original_nodes)
419
+
420
+
421
def convert_bytes_to_c_source(data,
                              array_name,
                              max_line_width=80,
                              include_guard=None,
                              include_path=None,
                              use_tensorflow_license=False):
  """Returns strings representing a C constant array containing `data`.

  Args:
    data: Byte array that will be converted into a C constant.
    array_name: String to use as the variable name for the constant array.
    max_line_width: The longest line length, for formatting purposes.
    include_guard: Name to use for the include guard macro definition.
    include_path: Optional path to include in the source file.
    use_tensorflow_license: Whether to include the standard TensorFlow Apache2
      license in the generated files.

  Returns:
    Text that can be compiled as a C source file to link in the data as a
    literal array of values.
    Text that can be used as a C header file to reference the literal array.
  """

  # Render every byte as " 0xNN," and wrap before exceeding max_line_width.
  starting_pad = " "
  array_lines = []
  current_line = starting_pad
  for byte in bytearray(data):
    if len(current_line) + 4 > max_line_width:
      array_lines.append(current_line + "\n")
      current_line = starting_pad
    current_line += " 0x%02x," % (byte,)
  if len(current_line) > len(starting_pad):
    array_lines.append(current_line + "\n")
  array_values = "".join(array_lines)

  if include_guard is None:
    include_guard = "TENSORFLOW_LITE_UTIL_" + array_name.upper() + "_DATA_H_"

  include_line = ""
  if include_path is not None:
    include_line = "#include \"{include_path}\"\n".format(
        include_path=include_path)

  license_text = ""
  if use_tensorflow_license:
    license_text = """
/* Copyright {year} The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
""".format(year=datetime.date.today().year)

  source_template = """{license_text}
// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

{include_line}
// We need to keep the data array aligned on some architectures.
#ifdef __has_attribute
#define HAVE_ATTRIBUTE(x) __has_attribute(x)
#else
#define HAVE_ATTRIBUTE(x) 0
#endif
#if HAVE_ATTRIBUTE(aligned) || (defined(__GNUC__) && !defined(__clang__))
#define DATA_ALIGN_ATTRIBUTE __attribute__((aligned(4)))
#else
#define DATA_ALIGN_ATTRIBUTE
#endif

const unsigned char {array_name}[] DATA_ALIGN_ATTRIBUTE = {{
{array_values}}};
const int {array_name}_len = {array_length};
"""

  source_text = source_template.format(
      array_name=array_name,
      array_length=len(data),
      array_values=array_values,
      license_text=license_text,
      include_line=include_line)

  header_template = """
{license_text}

// This is a TensorFlow Lite model file that has been converted into a C data
// array using the tensorflow.lite.util.convert_bytes_to_c_source() function.
// This form is useful for compiling into a binary for devices that don't have a
// file system.

#ifndef {include_guard}
#define {include_guard}

extern const unsigned char {array_name}[];
extern const int {array_name}_len;

#endif  // {include_guard}
"""

  header_text = header_template.format(
      array_name=array_name,
      include_guard=include_guard,
      license_text=license_text)

  return source_text, header_text
539
+
540
+
541
def _convert_model_from_bytearray_to_object(model_bytearray):
  """Converts a tflite model from a bytearray into a parsable object."""
  root = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  parsed = schema_fb.ModelT.InitFromObj(root)
  # Deep-copy so the returned object owns its data independently of the
  # source flatbuffer.
  return copy.deepcopy(parsed)
547
+
548
+
549
def _convert_model_from_object_to_bytearray(model_object):
  """Converts a tflite model from a parsable object into a bytearray."""
  # Start the builder at 1 KiB; flatbuffers grows the buffer automatically
  # when needed.
  builder = flatbuffers.Builder(1024)
  packed_offset = model_object.Pack(builder)
  builder.Finish(packed_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output())
556
+
557
+
558
def get_quantize_opcode_idx(model):
  """Returns the quantize op idx."""
  # Collect every operator-code index whose builtin code resolves to QUANTIZE.
  return [
      idx for idx, opcode in enumerate(model.operatorCodes)
      if schema_util.get_builtin_code_from_operator_code(opcode)
      == schema_fb.BuiltinOperator.QUANTIZE
  ]
566
+
567
+
568
def get_dequantize_opcode_idx(model):
  """Returns the dequantize op idx.

  Note: the docstring previously said "quantize" (copy-paste from
  get_quantize_opcode_idx); this function matches DEQUANTIZE operator codes.

  Args:
    model: A tflite model object (ModelT).

  Returns:
    List of indices into `model.operatorCodes` whose builtin code is
    DEQUANTIZE.
  """
  dequant_opcode_idxs = []
  for idx, opcode in enumerate(model.operatorCodes):
    builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
    if builtin_code == schema_fb.BuiltinOperator.DEQUANTIZE:
      dequant_opcode_idxs.append(idx)
  return dequant_opcode_idxs
576
+
577
+
578
+ def _update_signature_def_tensors(tensor_maps, map_old_to_new_tensors):
579
+ """Update the tensors in the SignatureDef's TensorMaps."""
580
+ for i in range(len(tensor_maps)):
581
+ if tensor_maps[i].tensorIndex in map_old_to_new_tensors:
582
+ tensor_maps[i].tensorIndex = (
583
+ map_old_to_new_tensors[tensor_maps[i].tensorIndex])
584
+
585
+
586
def _remove_tensors_from_model(model, remove_tensors_idxs):
  """Remove tensors from model."""
  if not remove_tensors_idxs:
    return
  if len(model.subgraphs) > 1:
    # Tensor indices are per-subgraph but buffers may be shared, so deleting
    # tensors is only known to be safe for single-subgraph models.
    logging.info("Skipping the removal of dangled tensors since the model has "
                 "multiple subgraphs and tensors can be used in the different "
                 "subgraph(s)")
    return
  subgraph = model.subgraphs[0]
  tensors = subgraph.tensors
  operators = subgraph.operators

  logging.debug("Removing tensors at indices : %s", remove_tensors_idxs)
  # An optimized check to validate if "remove_tensors_idxs" (eg: [4,5,6]) is an
  # exact subset, with ordering, of "tensors" indices (eg: [0,1,2,3,4,5,6]).
  if min(remove_tensors_idxs) == len(tensors) - len(remove_tensors_idxs):
    logging.debug("Removing tensors only at the end of the tensor list")
    # Tail-only deletion: no surviving tensor index changes, so no remapping
    # of operator/signature references is needed.
    del tensors[min(remove_tensors_idxs):]
  else:
    logging.debug("Removing tensors requires updating the model")
    # Map the old tensor indices to new tensor indices.
    d_old_to_new_tensors = {}
    left_shift_by = 0
    for idx in range(len(tensors)):
      if idx in remove_tensors_idxs:
        left_shift_by += 1
      else:
        d_old_to_new_tensors[idx] = idx - left_shift_by
    logging.debug("Old to new tensors map: %s", d_old_to_new_tensors.__str__())

    # Update tensor indices referenced throughout the model; deleted indices
    # map to -1 (the schema's "no tensor" sentinel).
    def update_tensors(tensor_idxs):
      for i, ti in enumerate(tensor_idxs):
        tensor_idxs[i] = d_old_to_new_tensors.get(ti, -1)

    update_tensors(subgraph.inputs)
    update_tensors(subgraph.outputs)
    for op in operators:
      update_tensors(op.inputs)
      update_tensors(op.outputs)
    if model.signatureDefs:
      signature_def = model.signatureDefs[0]
      _update_signature_def_tensors(signature_def.inputs, d_old_to_new_tensors)
      _update_signature_def_tensors(signature_def.outputs, d_old_to_new_tensors)
    # Delete the tensors in descending index order so earlier pops do not
    # shift the positions of indices still pending deletion.
    for idx in sorted(remove_tensors_idxs, reverse=True):
      tensors.pop(idx)
    logging.debug("Removed tensors marked for deletion")
633
+
634
+
635
def _modify_model_input_type(model, inference_input_type=dtypes.float32):
  """Modify model input type."""
  if inference_input_type == dtypes.float32:
    # float32 is the default input type; nothing to rewrite.
    return

  if not model.signatureDefs:
    # No SignatureDefs: operate on the primary subgraph with no signature
    # bookkeeping (signature_index == -1).
    _modify_model_input_type_per_subgraph(model, 0, -1, inference_input_type)
    return

  for sig_idx, sig_def in enumerate(model.signatureDefs):
    _modify_model_input_type_per_subgraph(model, sig_def.subgraphIndex,
                                          sig_idx, inference_input_type)
647
+
648
+
649
def _modify_model_input_type_per_subgraph(model, subgraph_index,
                                          signature_index,
                                          inference_input_type):
  """Modify model input type per subgraph.

  Args:
    model: A tflite model object (ModelT); mutated in place.
    subgraph_index: Index of the subgraph whose inputs are rewritten.
    signature_index: Index into model.signatureDefs, or -1 when the model
      carries no SignatureDefs.
    inference_input_type: tf.DType representing the desired input type.

  Raises:
    ValueError: If the model input is not quantized or the requested type is
      incompatible with the quantized input type.
  """
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  if operators and not quant_opcode_idxs:
    # No quantize ops at all: inputs must already be non-float, otherwise the
    # model was never quantized at the input boundary.
    for input_idx in subgraph.inputs:
      input_type = _convert_tflite_enum_type_to_tf_type(tensors[input_idx].type)
      if input_type == dtypes.float32:
        raise ValueError("Model input is not dequantized.")
    # None of the inputs have float32, then they must be int16, int8, or bool.
    return

  # Validate that the model input is quantized.
  input_quant_ops = []
  for op in operators:
    # Find operators that quantize model input.
    if op.opcodeIndex in quant_opcode_idxs and op.inputs[0] in subgraph.inputs:
      float_tensor, quant_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      # If found, validate that the operator's input type is float.
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_input_type:
          # Input already has the requested type; nothing to do for this op.
          continue
        else:
          raise ValueError(
              "Initial model input type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator output is quantized and compatible
      # with the final model input type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model input is not quantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_input_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_input_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_input_type)))
      input_quant_ops.append(op)

  if len(subgraph.inputs) != len(input_quant_ops):
    # Some inputs have no corresponding quantize op; those keep their
    # original type (best-effort behavior, warn instead of failing).
    logging.warning(
        "For model inputs containing unsupported operations which cannot be "
        "quantized, the `inference_input_type` attribute will default to the "
        "original type."
    )

  # Modify model input type.
  if inference_input_type == dtypes.uint8:
    # Change quant op (float to int8) to quant op (uint8 to int8): keep the
    # op but retype/requantize its input tensor with a +128 zero-point shift.
    for op in input_quant_ops:
      int8_quantization = tensors[op.outputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.inputs[0]].quantization = uint8_quantization
      tensors[op.inputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_input_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the inputs and the quant operator: the quantized tensor becomes
    # the new model input directly.
    remove_tensors_idxs = set()
    for op in input_quant_ops:
      subgraph.inputs[subgraph.inputs == op.inputs[0]] = op.outputs[0]
      if signature_index >= 0:
        # Keep the SignatureDef pointing at the surviving tensor.
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.inputs)):
          if signature_def.inputs[i].tensorIndex == op.inputs[0]:
            signature_def.inputs[i].tensorIndex = op.outputs[0]
      remove_tensors_idxs.add(op.inputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_input_type` value {}.".format(
            get_tf_type_name(inference_input_type)))
739
+
740
+
741
def _modify_model_output_type(model, inference_output_type=dtypes.float32):
  """Modify model output type."""
  if inference_output_type == dtypes.float32:
    # float32 is the default output type; nothing to rewrite.
    return

  if not model.signatureDefs:
    # No SignatureDefs: operate on the primary subgraph with no signature
    # bookkeeping (signature_index == -1).
    _modify_model_output_type_per_subgraph(model, 0, -1, inference_output_type)
    return

  for sig_idx, sig_def in enumerate(model.signatureDefs):
    _modify_model_output_type_per_subgraph(model, sig_def.subgraphIndex,
                                           sig_idx, inference_output_type)
754
+
755
+
756
def _modify_model_output_type_per_subgraph(model, subgraph_index,
                                           signature_index,
                                           inference_output_type):
  """Modify model output type per subgraph.

  Args:
    model: A tflite model object (ModelT); mutated in place.
    subgraph_index: Index of the subgraph whose outputs are rewritten.
    signature_index: Index into model.signatureDefs, or -1 when the model
      carries no SignatureDefs.
    inference_output_type: tf.DType representing the desired output type.

  Raises:
    ValueError: If the model output is not dequantized or the requested type
      is incompatible with the quantized output type.
  """
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all dequantize operators.
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)
  if operators and not dequant_opcode_idxs:
    # No dequantize ops at all: outputs must already be non-float, otherwise
    # the model was never dequantized at the output boundary.
    for output in subgraph.outputs:
      output_type = _convert_tflite_enum_type_to_tf_type(tensors[output].type)
      if output_type == dtypes.float32:
        raise ValueError("Model output is not dequantized.")
    # None of the outputs have float32, then they must be int16, int8, or bool.
    return

  # Validate that the model output is dequantized.
  output_dequant_ops = []
  for op in operators:
    # Find operators that dequantize model output.
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # If found, validate that the operator's output type is float.
      quant_tensor, float_tensor = tensors[op.inputs[0]], tensors[op.outputs[0]]
      float_type = _convert_tflite_enum_type_to_tf_type(float_tensor.type)
      if float_type != dtypes.float32:
        if float_type == inference_output_type:
          # Output already has the requested type; nothing to do for this op.
          continue
        else:
          raise ValueError(
              "Initial model output type must be tf.float32. Expected type for "
              "tensor with name '{}' is tf.float32, instead type is {}".format(
                  float_tensor.name, get_tf_type_name(float_type)))
      # If found, validate that the operator input is quantized and compatible
      # with the final model output type.
      quant_type = _convert_tflite_enum_type_to_tf_type(quant_tensor.type)
      if quant_type not in _MAP_QUANT_TO_IO_TYPES:
        raise ValueError(
            "Initial model output is not dequantized. Expected type for "
            "tensor with name '{}' should be in {}, instead type is {}".format(
                quant_tensor.name,
                tuple(get_tf_type_name(t) for t in
                      _MAP_QUANT_TO_IO_TYPES.keys()),
                get_tf_type_name(quant_type)))
      else:
        inference_io_types = _MAP_QUANT_TO_IO_TYPES[quant_type]
        if inference_output_type not in inference_io_types:
          raise ValueError(
              "Unsupported `inference_output_type` value. Expected to be in "
              "{}, instead got {}.".format(
                  tuple(get_tf_type_name(t) for t in inference_io_types),
                  get_tf_type_name(inference_output_type)))
      output_dequant_ops.append(op)

  if len(subgraph.outputs) != len(output_dequant_ops):
    # Some outputs have no corresponding dequantize op; those keep their
    # original type (best-effort behavior, warn instead of failing).
    logging.warning(
        "For model outputs containing unsupported operations which cannot be "
        "quantized, the `inference_output_type` attribute will default to the "
        "original type."
    )

  # Modify model output type.
  if inference_output_type == dtypes.uint8:
    # Find a quantize operator.
    quant_opcode_idx = -1
    for idx, opcode in enumerate(model.operatorCodes):
      builtin_code = schema_util.get_builtin_code_from_operator_code(opcode)
      if builtin_code == schema_fb.BuiltinOperator.QUANTIZE:
        quant_opcode_idx = idx
        break
    # Create a quantize operator, if none exist.
    if quant_opcode_idx == -1:
      quant_op = schema_fb.OperatorCodeT()
      quant_op.builtinCode = schema_fb.BuiltinOperator.QUANTIZE
      quant_op.deprecatedBuiltinCode = schema_fb.BuiltinOperator.QUANTIZE
      model.operatorCodes.append(quant_op)
      quant_opcode_idx = len(model.operatorCodes) - 1
    # Change dequant op (int8 to float) to quant op (int8 to uint8): reuse the
    # op, retyping its output tensor with a +128 zero-point shift.
    for op in output_dequant_ops:
      op.opcodeIndex = quant_opcode_idx
      int8_quantization = tensors[op.inputs[0]].quantization
      uint8_quantization = schema_fb.QuantizationParametersT()
      uint8_quantization.scale = [int8_quantization.scale[0]]
      uint8_quantization.zeroPoint = [int8_quantization.zeroPoint[0] + 128]
      tensors[op.outputs[0]].quantization = uint8_quantization
      tensors[op.outputs[0]].type = schema_fb.TensorType.UINT8
  elif inference_output_type in _MAP_QUANT_TO_IO_TYPES:
    # Remove the outputs and the dequant operator: the quantized tensor
    # becomes the new model output directly.
    remove_tensors_idxs = set()
    for op in output_dequant_ops:
      subgraph.outputs[subgraph.outputs == op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        # Keep the SignatureDef pointing at the surviving tensor.
        signature_def = model.signatureDefs[signature_index]
        for i in range(len(signature_def.outputs)):
          if signature_def.outputs[i].tensorIndex == op.outputs[0]:
            signature_def.outputs[i].tensorIndex = op.inputs[0]
      remove_tensors_idxs.add(op.outputs[0])
      operators.remove(op)
    # Remove tensors marked for deletion.
    _remove_tensors_from_model(model, remove_tensors_idxs)
  else:
    raise ValueError(
        "Unsupported `inference_output_type` value {}.".format(
            get_tf_type_name(inference_output_type)))
862
+
863
+
864
def _remove_redundant_quantize_ops(model):
  """Finds back to back quantize ops and remove the first quantize op."""
  if not model.signatureDefs:
    # No SignatureDefs: clean up the primary subgraph only.
    _remove_redundant_quantize_ops_per_subgraph(model, 0, -1)
    return

  for sig_idx, sig_def in enumerate(model.signatureDefs):
    _remove_redundant_quantize_ops_per_subgraph(model, sig_def.subgraphIndex,
                                                sig_idx)
874
+
875
+
876
def _remove_redundant_quantize_ops_per_subgraph(model, subgraph_index,
                                                signature_index):
  """Remove redundant quantize ops per subgraph.

  Args:
    model: A tflite model object (ModelT); mutated in place.
    subgraph_index: Index of the subgraph to clean up.
    signature_index: Index into model.signatureDefs, or -1 when the model
      carries no SignatureDefs.
  """
  subgraph = model.subgraphs[subgraph_index]
  tensors = subgraph.tensors
  operators = subgraph.operators

  # Find all quantize operators.
  quant_opcode_idxs = get_quantize_opcode_idx(model)
  dequant_opcode_idxs = get_dequantize_opcode_idx(model)

  # Find all redundant quant tensors.
  all_quant_ops = []
  redundant_quant_tensors = {}
  output_dequant_tensors = {}
  for op in operators:
    if op.opcodeIndex in quant_opcode_idxs:
      all_quant_ops.append(op)
      input_tensor = tensors[op.inputs[0]]
      output_tensor = tensors[op.outputs[0]]
      input_type = _convert_tflite_enum_type_to_tf_type(input_tensor.type)
      output_type = _convert_tflite_enum_type_to_tf_type(output_tensor.type)
      # This is a requantize op, so write down its input tensor index.
      if input_type != dtypes.float32 and output_type != dtypes.float32:
        redundant_quant_tensors[op.inputs[0]] = op
    if (op.opcodeIndex in dequant_opcode_idxs and
        op.outputs[0] in subgraph.outputs):
      # Dequantize op feeding a model output; remembered by its input tensor
      # so a preceding quantize op can be matched against it below.
      output_dequant_tensors[op.inputs[0]] = op

  # Remove all the quant ops which produce the redundant quant tensors: the
  # requantize op absorbs the first quantize op's input directly.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in redundant_quant_tensors:
      requantize_op = redundant_quant_tensors[output_tensor_idx]
      if model.signatureDefs:
        signature_def = model.signatureDefs[0]
        for output in signature_def.outputs:
          if output.tensorIndex == op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      deleted_tensor = requantize_op.inputs[0]
      # Reset the input of the requantize op to the float input.
      requantize_op.inputs[0] = op.inputs[0]
      # Migrate other operator users to output tensor of requantize op.
      for op_user in operators:
        if deleted_tensor in op_user.inputs and op_user != requantize_op:
          for idx, input_tensor in enumerate(op_user.inputs):
            if input_tensor == deleted_tensor:
              op_user.inputs[idx] = requantize_op.outputs[0]
      operators.remove(op)

  # Remove all the quant ops which connect to the output dequant op: a
  # quantize immediately undone by a dequantize is a no-op pair, so both ops
  # are dropped and the output rewired to the quantize op's input.
  for op in all_quant_ops:
    output_tensor_idx = op.outputs[0]
    if output_tensor_idx in output_dequant_tensors:
      dequant_op = output_dequant_tensors[output_tensor_idx]
      subgraph.outputs[subgraph.outputs == dequant_op.outputs[0]] = op.inputs[0]
      if signature_index >= 0:
        signature_def = model.signatureDefs[signature_index]
        for output in signature_def.outputs:
          if output.tensorIndex == dequant_op.outputs[0]:
            output.tensorIndex = op.inputs[0]
      operators.remove(op)
      operators.remove(dequant_op)
939
+
940
+
941
def modify_model_io_type(
    model, inference_input_type=dtypes.float32,
    inference_output_type=dtypes.float32):
  """Modify the input/output type of a tflite model.

  Args:
    model: A tflite model.
    inference_input_type: tf.DType representing modified input type.
      (default tf.float32. If model input is int8 quantized, it must be in
      {tf.float32, tf.int8,tf.uint8}, else if model input is int16 quantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)
    inference_output_type: tf.DType representing modified output type.
      (default tf.float32. If model output is int8 dequantized, it must be in
      {tf.float32, tf.int8,tf.uint8}, else if model output is int16 dequantized,
      it must be in {tf.float32, tf.int16}, else it must be tf.float32)

  Returns:
    A tflite model with modified input/output type.

  Raises:
    ValueError: If `inference_input_type`/`inference_output_type` is unsupported
      or a supported integer type is specified for a model whose input/output is
      not quantized/dequantized.
    RuntimeError: If the modification was unsuccessful.
  """
  if (inference_input_type == dtypes.float32 and
      inference_output_type == dtypes.float32):
    # Both types are already the defaults; return the model unchanged.
    return model

  # Deserialize, rewrite both boundaries, clean up, and re-serialize.
  model_object = _convert_model_from_bytearray_to_object(model)
  _modify_model_input_type(model_object, inference_input_type)
  _modify_model_output_type(model_object, inference_output_type)
  _remove_redundant_quantize_ops(model_object)
  return _convert_model_from_object_to_bytearray(model_object)
979
+
980
+
981
def get_sparsity_modes(model_object):
  """Get sparsity modes used in a tflite model.

  The sparsity modes are listed in conversion_metadata.fbs file.

  Args:
    model_object: A tflite model in object form.

  Returns:
    The list of sparsity modes used in the model.
  """
  if not model_object or not model_object.metadata:
    return []

  modes = set()
  for graph in model_object.subgraphs:
    for tensor in graph.tensors:
      sparsity = tensor.sparsity
      if not sparsity:
        continue
      # Block map is the list of indexes where the block size is larger than
      # 1, so an empty block map means random (unstructured) sparsity.
      if sparsity.blockMap:
        modes.add(conversion_metadata_fb.ModelOptimizationMode.BLOCK_SPARSITY)
      else:
        modes.add(conversion_metadata_fb.ModelOptimizationMode.RANDOM_SPARSITY)

  return list(modes)
1011
+
1012
+
1013
def get_model_hash(model):
  """Calculate a 64-bit integer hash for a TensorFlow Lite model based on its structure.

  Only structural properties are hashed (operator/tensor counts, operator
  input/output indices, tensor shapes, buffer lengths) — not the buffer
  contents themselves.

  Args:
    model: A TensorFlow Lite model object.

  Returns:
    int: A 64-bit integer hash value representing the model structure.
  """
  # TODO(b/344872922): Move the hashing implementation to C++ layer since not
  # all calls to the converter come via the Python API.
  hash_value = 0

  for subgraph in model.subgraphs:
    if subgraph.operators is not None:
      hash_value = update_hash_with_primitive_value(
          hash_value, len(subgraph.operators)
      )

      for operator in subgraph.operators:
        if operator.inputs is not None:
          hash_value = update_hash_with_array(hash_value, operator.inputs)

        if operator.outputs is not None:
          hash_value = update_hash_with_array(hash_value, operator.outputs)

    if subgraph.tensors is not None:
      hash_value = update_hash_with_primitive_value(
          hash_value, len(subgraph.tensors)
      )

      for tensor in subgraph.tensors:
        if tensor.buffer is not None:
          buffer = model.buffers[tensor.buffer]
          if buffer.data is not None:
            # Hash only the length of the buffer, not its bytes.
            hash_value = update_hash_with_primitive_value(
                hash_value, len(buffer.data)
            )

        if tensor.shape is not None:
          hash_value = update_hash_with_array(hash_value, tensor.shape)

    if subgraph.inputs is not None:
      hash_value = update_hash_with_primitive_value(
          hash_value, len(subgraph.inputs)
      )

    if subgraph.outputs is not None:
      hash_value = update_hash_with_primitive_value(
          hash_value, len(subgraph.outputs)
      )

  return hash_value
1066
+
1067
+
1068
def update_hash_with_primitive_value(hash_value, value):
  """Update the hash value using a primitive value.

  Mixes `value` into `hash_value` with the golden-ratio constant plus
  shift-xor steps, computed entirely in uint64 with wraparound semantics.

  Args:
    hash_value (uint64): The current hash value.
    value: The primitive value to incorporate into the hash.

  Returns:
    int: The updated hash value.
  """
  golden = np.uint64(0x9E3779B97F4A7800)

  # Wrap both operands in single-element uint64 arrays so the shifts and
  # additions wrap around instead of raising on overflow.
  acc = np.array([np.uint64(hash_value)])
  val = np.array([np.uint64(value)])

  mixed = val + golden + np.left_shift(acc, 10) + np.right_shift(acc, 4)
  return np.bitwise_xor(acc, mixed)[0]
1098
+
1099
+
1100
def update_hash_with_array(hash_value, int_array):
  """Update the hash value using a TFLite int array.

  Folds each element of the array into the hash in order; a None array
  leaves the hash unchanged.

  Args:
    hash_value (int): The current hash value.
    int_array: A TFLite int array to incorporate into the hash.

  Returns:
    int: The updated hash value.
  """
  if int_array is None:
    return hash_value
  for element in int_array:
    hash_value = update_hash_with_primitive_value(hash_value, element)
  return hash_value
1114
+
1115
+
1116
def populate_conversion_metadata(model_object, metadata):
  """Add or update conversion metadata to a tflite model.

  Best-effort: any failure while packing or attaching the metadata is
  swallowed and the model is returned unmodified.

  Args:
    model_object: A tflite model in object form.
    metadata: The conversion metadata.

  Returns:
    A tflite model object with embedded conversion metadata.
  """
  try:
    # Serialize the metadata into its own flatbuffer and wrap it in a buffer.
    metadata_builder = flatbuffers.Builder(0)
    metadata_builder.Finish(metadata.Pack(metadata_builder))
    buffer_field = schema_fb.BufferT()
    buffer_field.data = metadata_builder.Output()

    if not model_object.metadata:
      model_object.metadata = []
    else:
      # Check if metadata has already been populated; if so, overwrite the
      # existing buffer in place and keep the existing metadata entry.
      for meta in model_object.metadata:
        if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
          model_object.buffers[meta.buffer] = buffer_field
          return model_object

    if not model_object.buffers:
      model_object.buffers = []
    model_object.buffers.append(buffer_field)
    # Creates a new metadata field pointing at the freshly appended buffer.
    metadata_field = schema_fb.MetadataT()
    metadata_field.name = CONVERSION_METADATA_FIELD_NAME
    metadata_field.buffer = len(model_object.buffers) - 1
    model_object.metadata.append(metadata_field)

    return model_object
  except Exception:  # pylint: disable=broad-except
    # Metadata is auxiliary; never fail the conversion because of it.
    return model_object
1153
+
1154
+
1155
def get_conversion_metadata(model_buffer):
  """Read conversion metadata from a tflite model.

  Args:
    model_buffer: A tflite model.

  Returns:
    The conversion metadata or None if it is not populated.
  """
  model_object = flatbuffer_utils.convert_bytearray_to_object(model_buffer)
  if not model_object or not model_object.metadata:
    return None

  # Scan the model's metadata entries for the conversion-metadata field and
  # parse its buffer as a ConversionMetadata flatbuffer.
  for meta in model_object.metadata:
    if meta.name.decode("utf-8") == CONVERSION_METADATA_FIELD_NAME:
      metadata_buf = model_object.buffers[meta.buffer].data.tobytes()
      return conversion_metadata_fb.ConversionMetadataT.InitFromObj(
          conversion_metadata_fb.ConversionMetadata.GetRootAsConversionMetadata(
              metadata_buf, 0
          )
      )

  return None
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (197 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/flatbuffer_utils.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/__pycache__/visualize.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/flatbuffer_utils.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utility functions for FlatBuffers.
16
+
17
+ All functions that are commonly used to work with FlatBuffers.
18
+
19
+ Refer to the tensorflow lite flatbuffer schema here:
20
+ tensorflow/lite/schema/schema.fbs
21
+ """
22
+
23
+ import copy
24
+ import random
25
+ import re
26
+ import struct
27
+ import sys
28
+
29
+ import flatbuffers
30
+
31
+ from tensorflow.lite.python import schema_py_generated as schema_fb
32
+ from tensorflow.lite.python import schema_util
33
+ from tensorflow.python.platform import gfile
34
+
35
+ _TFLITE_FILE_IDENTIFIER = b'TFL3'
36
+
37
+
38
def convert_bytearray_to_object(model_bytearray):
  """Parses a serialized tflite model (bytearray) into a mutable object tree."""
  root_model = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
  return schema_fb.ModelT.InitFromObj(root_model)
42
+
43
+
44
def read_model(input_tflite_file):
  """Reads a tflite model file into a python object.

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  if not gfile.Exists(input_tflite_file):
    raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
  with gfile.GFile(input_tflite_file, 'rb') as input_file_handle:
    raw_bytes = bytearray(input_file_handle.read())
  return read_model_from_bytearray(raw_bytes)
62
+
63
+
64
def read_model_from_bytearray(model_bytearray):
  """Reads a tflite model as a python object.

  Args:
    model_bytearray: TFLite model in bytearray format.

  Returns:
    A python object corresponding to the input tflite file.
  """
  model = convert_bytearray_to_object(model_bytearray)
  # Serialized models are little-endian; swap buffers into host order when
  # running on a big-endian machine.
  if sys.byteorder == 'big':
    byte_swap_tflite_model_obj(model, 'little', 'big')

  # Offset handling for models > 2GB: such models keep buffer data and large
  # custom options outside the flatbuffer, referenced by (offset, size) into
  # the raw byte stream. Inline that data and zero the offsets so the
  # returned object is fully self-contained.
  for buffer in model.buffers:
    if buffer.offset:
      buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size]
      buffer.offset = 0
      buffer.size = 0
  for subgraph in model.subgraphs:
    for op in subgraph.operators:
      if op.largeCustomOptionsOffset:
        op.customOptions = model_bytearray[
            op.largeCustomOptionsOffset : op.largeCustomOptionsOffset
            + op.largeCustomOptionsSize
        ]
        op.largeCustomOptionsOffset = 0
        op.largeCustomOptionsSize = 0

  return model
94
+
95
+
96
def read_model_with_mutable_tensors(input_tflite_file):
  """Reads a tflite model as a python object with mutable tensors.

  Similar to read_model() with the addition that the returned object has
  mutable tensors (read_model() returns an object with immutable tensors).

  NOTE: This API only works for TFLite generated with
  _experimental_use_buffer_offset=false

  Args:
    input_tflite_file: Full path name to the input tflite file

  Raises:
    RuntimeError: If input_tflite_file path is invalid.
    IOError: If input_tflite_file cannot be opened.

  Returns:
    A mutable python object corresponding to the input tflite file.
  """
  # A deep copy detaches tensors from the read-only backing byte buffer.
  immutable_model = read_model(input_tflite_file)
  return copy.deepcopy(immutable_model)
116
+
117
+
118
def convert_object_to_bytearray(model_object, extra_buffer=b''):
  """Serializes a tflite model object into an immutable bytes blob.

  Args:
    model_object: A tflite model as a python object.
    extra_buffer: Optional out-of-flatbuffer payload appended verbatim.

  Returns:
    The serialized model as bytes (with `extra_buffer` appended).
  """
  # Initial size of the builder; it grows automatically if needed.
  builder = flatbuffers.Builder(1024)
  packed_root = model_object.Pack(builder)
  builder.Finish(packed_root, file_identifier=_TFLITE_FILE_IDENTIFIER)
  return bytes(builder.Output()) + extra_buffer
127
+
128
+
129
def write_model(model_object, output_tflite_file):
  """Writes the tflite model, a python object, into the output file.

  NOTE: This API only works for TFLite generated with
  _experimental_use_buffer_offset=false

  Args:
    model_object: A tflite model as a python object
    output_tflite_file: Full path name to the output tflite file.

  Raises:
    IOError: If output_tflite_file path is invalid or cannot be opened.
  """
  # Files are always written little-endian; on big-endian hosts swap a deep
  # copy so the caller's object is left untouched.
  if sys.byteorder == 'big':
    model_object = copy.deepcopy(model_object)
    byte_swap_tflite_model_obj(model_object, 'big', 'little')
  serialized = convert_object_to_bytearray(model_object)
  with gfile.GFile(output_tflite_file, 'wb') as output_file_handle:
    output_file_handle.write(serialized)
148
+
149
+
150
def strip_strings(model):
  """Strips all nonessential strings from the model to reduce model size.

  We remove the following strings:
  (find strings by searching ":string" in the tensorflow lite flatbuffer schema)
  1. Model description
  2. SubGraph name
  3. Tensor names
  We retain OperatorCode custom_code and Metadata name.

  Args:
    model: The model from which to remove nonessential strings.
  """
  model.description = None
  # Without names the signature_def structure is useless, so clear it all.
  model.signatureDefs = None
  for graph in model.subgraphs:
    graph.name = None
    for tensor in graph.tensors:
      tensor.name = None
171
+
172
+
173
def type_to_name(tensor_type):
  """Converts a numerical enum to a readable tensor type."""
  matches = [
      name for name, value in schema_fb.TensorType.__dict__.items()
      if value == tensor_type
  ]
  return matches[0] if matches else None
179
+
180
+
181
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
  """Randomize weights in a model.

  Args:
    model: The model in which to randomize weights.
    random_seed: The input to the random number generator (default value is 0).
    buffers_to_skip: The list of buffer indices to skip. The weights in these
      buffers are left unmodified.
  """

  # The input to the random seed generator. The default value is 0.
  random.seed(random_seed)

  # Parse model buffers which store the model weights
  buffers = model.buffers
  buffer_ids = range(1, len(buffers))  # ignore index 0 as it's always None
  if buffers_to_skip is not None:
    buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip]

  # Map each buffer index to the readable type name of a tensor that reads
  # it, so float buffers can be refilled with valid float values below.
  buffer_types = {}
  for graph in model.subgraphs:
    for op in graph.operators:
      if op.inputs is None:
        break
      for input_idx in op.inputs:
        tensor = graph.tensors[input_idx]
        buffer_types[tensor.buffer] = type_to_name(tensor.type)

  for i in buffer_ids:
    buffer_i_data = buffers[i].data
    buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
    if buffer_i_size == 0:
      continue

    # Raw data buffers are of type ubyte (or uint8) whose values lie in the
    # range [0, 255]. Those ubytes (or unint8s) are the underlying
    # representation of each datatype. For example, a bias tensor of type
    # int32 appears as a buffer 4 times it's length of type ubyte (or uint8).
    # For floats, we need to generate a valid float and then pack it into
    # the raw bytes in place.
    buffer_type = buffer_types.get(i, 'INT8')
    if buffer_type.startswith('FLOAT'):
      # 'e' = half precision (FLOAT16), 'f' = single precision.
      format_code = 'e' if buffer_type == 'FLOAT16' else 'f'
      for offset in range(0, buffer_i_size, struct.calcsize(format_code)):
        value = random.uniform(-0.5, 0.5)  # See http://b/152324470#comment2
        struct.pack_into(format_code, buffer_i_data, offset, value)
    else:
      for j in range(buffer_i_size):
        buffer_i_data[j] = random.randint(0, 255)
230
+
231
+
232
def rename_custom_ops(model, map_custom_op_renames):
  """Renames custom ops so they use the same naming style as builtin ops.

  Args:
    model: The input tflite model.
    map_custom_op_renames: A mapping from old to new custom op names.
  """
  for operator_code in model.operatorCodes:
    if not operator_code.customCode:
      continue
    old_name = operator_code.customCode.decode('ascii')
    if old_name in map_custom_op_renames:
      operator_code.customCode = map_custom_op_renames[old_name].encode('ascii')
244
+
245
+
246
def opcode_to_name(model, op_code):
  """Converts a TFLite op_code to the human readable name.

  Args:
    model: The input tflite model.
    op_code: The op_code to resolve to a readable name.

  Returns:
    A string containing the human readable op name, or None if not resolvable.
  """
  operator = model.operatorCodes[op_code]
  # Newer models use builtinCode; older ones use deprecatedBuiltinCode.
  target_code = max(operator.builtinCode, operator.deprecatedBuiltinCode)
  matches = [
      name for name, value in vars(schema_fb.BuiltinOperator).items()
      if value == target_code
  ]
  return matches[0] if matches else None
262
+
263
+
264
def xxd_output_to_bytes(input_cc_file):
  """Converts xxd output C++ source file to bytes (immutable).

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    RuntimeError: If input_cc_file path is invalid.
    IOError: If input_cc_file cannot be opened.

  Returns:
    A bytes object corresponding to the input cc file array.
  """
  # Match lines containing comma-separated hex values, e.g.
  #   0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
  hex_line = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')

  accumulated = bytearray()
  with open(input_cc_file) as source:
    for line in source:
      match = hex_line.match(line)
      if match is None:
        continue
      # group(1) holds only the hex-array portion of the line; drop empty
      # tokens produced by a trailing comma.
      tokens = [token for token in match.group(1).split(',') if token]
      accumulated.extend(int(token, base=16) for token in tokens)

  return bytes(accumulated)
301
+
302
+
303
def xxd_output_to_object(input_cc_file):
  """Converts xxd output C++ source file to object.

  Args:
    input_cc_file: Full path name to the C++ source file dumped by xxd

  Raises:
    RuntimeError: If input_cc_file path is invalid.
    IOError: If input_cc_file cannot be opened.

  Returns:
    A python object corresponding to the input tflite file.
  """
  return convert_bytearray_to_object(xxd_output_to_bytes(input_cc_file))
318
+
319
+
320
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
  """Helper function for byte-swapping the buffers field.

  Re-encodes buffer.data in place, `chunksize` bytes at a time.
  """
  data = buffer.data
  swapped = bytearray()
  for start in range(0, len(data), chunksize):
    value = int.from_bytes(data[start : start + chunksize], from_endiness)
    swapped += value.to_bytes(chunksize, to_endiness)
  buffer.data = bytes(swapped)
330
+
331
+
332
def byte_swap_string_content(buffer, from_endiness, to_endiness):
  """Helper function for byte-swapping the string buffer.

  Args:
    buffer: TFLite string buffer of from_endiness format.
    from_endiness: The original endianness format of the string buffer.
    to_endiness: The destined endianness format of the string buffer.
  """
  # Layout: [count][count + 1 offsets][raw string bytes]. Each header field
  # is a 4-byte integer; only the header needs swapping, the raw string
  # bytes are copied through untouched.
  num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness)
  header_end = 4 * (num_of_strings + 2)
  string_content = bytearray(buffer.data[header_end:])
  swapped_header = b''.join(
      int.from_bytes(buffer.data[pos : pos + 4], from_endiness).to_bytes(
          4, to_endiness
      )
      for pos in range(0, (num_of_strings + 1) * 4 + 1, 4)
  )
  buffer.data = swapped_header + string_content
349
+
350
+
351
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
  """Byte swaps the buffers field in a TFLite model.

  Args:
    model: TFLite model object of from_endiness format.
    from_endiness: The original endianness format of the buffers in model.
    to_endiness: The destined endianness format of the buffers in model.
  """
  if model is None:
    return
  # Track already-swapped buffer indices in a set: several tensors may share
  # one buffer, and a set keeps the membership test in the loop below O(1)
  # (a list would make this pass quadratic on large models).
  buffer_swapped = set()
  types_of_16_bits = [
      schema_fb.TensorType.FLOAT16,
      schema_fb.TensorType.INT16,
      schema_fb.TensorType.UINT16,
  ]
  types_of_32_bits = [
      schema_fb.TensorType.FLOAT32,
      schema_fb.TensorType.INT32,
      schema_fb.TensorType.COMPLEX64,
      schema_fb.TensorType.UINT32,
  ]
  types_of_64_bits = [
      schema_fb.TensorType.INT64,
      schema_fb.TensorType.FLOAT64,
      schema_fb.TensorType.COMPLEX128,
      schema_fb.TensorType.UINT64,
  ]
  # Swap each constant buffer exactly once, chunked by its tensor data type.
  # Tensors with types outside the lists above (e.g. INT8) need no swapping.
  for subgraph in model.subgraphs:
    for tensor in subgraph.tensors:
      if (
          tensor.buffer > 0
          and tensor.buffer < len(model.buffers)
          and tensor.buffer not in buffer_swapped
          and model.buffers[tensor.buffer].data is not None
      ):
        if tensor.type == schema_fb.TensorType.STRING:
          byte_swap_string_content(
              model.buffers[tensor.buffer], from_endiness, to_endiness
          )
        elif tensor.type in types_of_16_bits:
          byte_swap_buffer_content(
              model.buffers[tensor.buffer], 2, from_endiness, to_endiness
          )
        elif tensor.type in types_of_32_bits:
          byte_swap_buffer_content(
              model.buffers[tensor.buffer], 4, from_endiness, to_endiness
          )
        elif tensor.type in types_of_64_bits:
          byte_swap_buffer_content(
              model.buffers[tensor.buffer], 8, from_endiness, to_endiness
          )
        else:
          continue
        buffer_swapped.add(tensor.buffer)
407
+
408
+
409
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
  """Generates a new model byte array after byte swapping its buffers field.

  Args:
    tflite_model: TFLite flatbuffer in a byte array.
    from_endiness: The original endianness format of the buffers in
      tflite_model.
    to_endiness: The destined endianness format of the buffers in tflite_model.

  Returns:
    TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
    format.
  """
  if tflite_model is None:
    return None
  # Deserialize, swap every constant buffer in place, then re-serialize.
  model_object = convert_bytearray_to_object(tflite_model)
  byte_swap_tflite_model_obj(model_object, from_endiness, to_endiness)
  return convert_object_to_bytearray(model_object)
432
+
433
+
434
def count_resource_variables(model):
  """Calculates the number of unique resource variables in a model.

  Args:
    model: the input tflite model, either as bytearray or object.

  Returns:
    An integer number representing the number of unique resource variables.
  """
  if not isinstance(model, schema_fb.ModelT):
    model = convert_bytearray_to_object(model)
  # VAR_HANDLE ops with the same shared name refer to one variable, so count
  # distinct shared names.
  shared_names = set()
  for subgraph in model.subgraphs:
    if subgraph.operators is None:
      continue
    for op in subgraph.operators:
      code = schema_util.get_builtin_code_from_operator_code(
          model.operatorCodes[op.opcodeIndex]
      )
      if code == schema_fb.BuiltinOperator.VAR_HANDLE:
        shared_names.add(op.builtinOptions.sharedName)
  return len(shared_names)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (206 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (216 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (223 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/__pycache__/debugger.cpython-310.pyc ADDED
Binary file (18.4 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/optimize/debugging/python/debugger.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Python TF-Lite QuantizationDebugger."""
16
+ import collections
17
+ import csv
18
+ import re
19
+ from typing import (Any, Callable, Dict, IO, Iterable, List, Mapping, Optional,
20
+ Sequence, Tuple)
21
+
22
+ import numpy as np
23
+
24
+ from tensorflow.lite.python import convert
25
+ from tensorflow.lite.python import interpreter as _interpreter
26
+ from tensorflow.lite.python.metrics import metrics as metrics_stub # type: ignore
27
+ from tensorflow.python.util import tf_export
28
+
29
+
30
+ # TODO(b/198099651): move converter implementation out of lite.py
31
+ TFLiteConverter = Any # importing tf.lite creates circular dependency
32
+
33
+ # Returns metrics based on difference of values for quantized/float ops.
34
+ _DEFAULT_LAYER_DEBUG_METRICS = {
35
+ 'num_elements': lambda diffs: diffs.size,
36
+ 'stddev': np.std,
37
+ 'mean_error': np.average,
38
+ 'max_abs_error': lambda diffs: np.max(np.abs(diffs)),
39
+ 'mean_squared_error': lambda diffs: np.average(diffs**2),
40
+ }
41
+
42
+ _NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
43
+
44
+
45
+ def _get_quant_params(
46
+ tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
47
+ """Returns first scale and zero point from tensor detail, if present."""
48
+ quant_params = tensor_detail['quantization_parameters']
49
+ if not quant_params:
50
+ return None
51
+ if quant_params['scales'] and quant_params['zero_points']:
52
+ return (quant_params['scales'][0], quant_params['zero_points'][0])
53
+ return None
54
+
55
+
56
@tf_export.tf_export('lite.experimental.QuantizationDebugOptions')
class QuantizationDebugOptions:
  """Debug options to set up a given QuantizationDebugger."""

  def __init__(self,
               layer_debug_metrics: Optional[Mapping[str,
                                                     Callable[[np.ndarray],
                                                              float]]] = None,
               model_debug_metrics: Optional[Mapping[
                   str, Callable[[Sequence[np.ndarray], Sequence[np.ndarray]],
                                 float]]] = None,
               layer_direct_compare_metrics: Optional[Mapping[str, Callable[
                   [Sequence[np.ndarray], Sequence[np.ndarray], float, int],
                   float]]] = None,
               denylisted_ops: Optional[List[str]] = None,
               denylisted_nodes: Optional[List[str]] = None,
               fully_quantize: bool = False) -> None:
    """Initializes debugger options.

    Args:
      layer_debug_metrics: a dict to specify layer debug functions
        {function_name_str: function} where the function accepts result of
        NumericVerify Op, which is value difference between float and
        dequantized op results. The function returns single scalar value.
      model_debug_metrics: a dict to specify model debug functions
        {function_name_str: function} where the function accepts outputs from
        two models, and returns single scalar value for a metric. (e.g.
        accuracy, IoU)
      layer_direct_compare_metrics: a dict to specify layer debug functions
        {function_name_str: function}. The signature is different from that of
        `layer_debug_metrics`, and this one gets passed (original float value,
        original quantized value, scale, zero point). The function's
        implementation is responsible for correctly dequantize the quantized
        value to compare. Use this one when comparing diff is not enough.
        (Note) quantized value is passed as int8, so cast to int32 is needed.
      denylisted_ops: a list of op names which is expected to be removed from
        quantization.
      denylisted_nodes: a list of op's output tensor names to be removed from
        quantization.
      fully_quantize: Bool indicating whether to fully quantize the model.
        Besides model body, the input/output will be quantized as well.
        Corresponding to mlir_quantize's fully_quantize parameter.

    Raises:
      ValueError: when there are duplicate keys
    """
    self.layer_debug_metrics = layer_debug_metrics
    self.model_debug_metrics = model_debug_metrics
    self.layer_direct_compare_metrics = layer_direct_compare_metrics

    # Metric names must be unique across all three metric dicts, since the
    # results are later merged into a single statistics mapping.
    keys = []
    for metrics in [
        layer_debug_metrics, model_debug_metrics, layer_direct_compare_metrics
    ]:
      if metrics is not None:
        keys.extend(metrics.keys())
    if len(keys) != len(set(keys)):
      raise ValueError('Provided metrics have duplicate keys.')

    self.denylisted_ops = denylisted_ops
    self.denylisted_nodes = denylisted_nodes
    self.fully_quantize = fully_quantize
119
+
120
+ @tf_export.tf_export('lite.experimental.QuantizationDebugger')
121
+ class QuantizationDebugger:
122
+ """Debugger for Quantized TensorFlow Lite debug mode models.
123
+
124
+ This can run the TensorFlow Lite converted models equipped with debug ops and
125
+ collect debug information. This debugger calculates statistics from
126
+ user-defined post-processing functions as well as default ones.
127
+ """
128
+
129
  def __init__(self,
               quant_debug_model_path: Optional[str] = None,
               quant_debug_model_content: Optional[bytes] = None,
               float_model_path: Optional[str] = None,
               float_model_content: Optional[bytes] = None,
               debug_dataset: Optional[Callable[
                   [], Iterable[Sequence[np.ndarray]]]] = None,
               debug_options: Optional[QuantizationDebugOptions] = None,
               converter: Optional[TFLiteConverter] = None) -> None:
    """Runs the TFLite debugging model with given debug options.

    Args:
      quant_debug_model_path: Path to the quantized debug TFLite model file.
      quant_debug_model_content: Content of the quantized debug TFLite model.
      float_model_path: Path to float TFLite model file.
      float_model_content: Content of the float TFLite model.
      debug_dataset: a factory function that returns dataset generator which is
        used to generate input samples (list of np.ndarray) for the model. The
        generated elements must have same types and shape as inputs to the
        model.
      debug_options: Debug options to debug the given model.
      converter: Optional, use converter instead of quantized model.

    Raises:
      ValueError: If the debugger was unable to be created.

    Attributes:
      layer_statistics: results of error metrics for each NumericVerify op
        results. in {layer_name: {metric_name: metric}} format.
      model_statistics: results of error metrics for difference between float
        and quantized models. in {metric_name: metric} format.
    """
    self._data_gen = debug_dataset
    self._debug_options = debug_options or QuantizationDebugOptions()
    self.converter = None
    self.calibrated_model = None
    self.float_model = None
    self._float_interpreter = None
    if converter is not None:
      if self._debug_options.model_debug_metrics:
        # Produce a float reference model first; restore the caller's
        # optimizations afterwards so calibration below sees them unchanged.
        old_optimizations = converter.optimizations
        self.converter = self._set_converter_options_for_float(converter)
        self.float_model = self.converter.convert()
        converter.optimizations = old_optimizations

      self.converter = self._set_converter_options_for_calibration(converter)
      self.calibrated_model = self.converter.convert()
      # Converter should be already set up with all options
      self._init_from_converter(
          self._debug_options,
          self.converter,
          self.calibrated_model,
          float_model=self.float_model)
    else:
      # No converter given: load the already-converted debug model directly.
      # Preserving all tensors is only needed for direct-compare metrics.
      self._quant_interpreter = _interpreter.Interpreter(
          quant_debug_model_path,
          quant_debug_model_content,
          experimental_preserve_all_tensors=(
              self._debug_options.layer_direct_compare_metrics is not None))
      if self._debug_options.model_debug_metrics:
        self._float_interpreter = _interpreter.Interpreter(
            float_model_path, float_model_content)
    self._initialize_stats()
192
+
193
  @property
  def options(self) -> QuantizationDebugOptions:
    # Current debug options; reassign via the setter to apply new options.
    return self._debug_options

  @options.setter
  def options(self, options: QuantizationDebugOptions) -> None:
    # Applying new options requires re-quantizing the model and rebuilding
    # statistics, which is only possible when the debugger was created from a
    # converter; otherwise only the stored options are updated.
    self._debug_options = options
    if not self.converter or not self.calibrated_model:
      return
    self._init_from_converter(
        self._debug_options,
        self.converter,
        self.calibrated_model,
        float_model=self.float_model)
    self._initialize_stats()
208
+
209
  def _initialize_stats(self):
    """Helper function initializes stats."""
    # TODO(b/177749613) : Fix the dependency on tf.lite._get_ops_details()
    # Following code is needed to get op's name from the output tensor index,
    # since NumericVerify op only provides its quantized input tensor index.
    self._defining_op = dict()
    for op_info in self._quant_interpreter._get_ops_details():  # pylint: disable=protected-access
      self._defining_op.update(
          {tensor_idx: op_info['index'] for tensor_idx in op_info['outputs']})

    self._numeric_verify_tensor_details = None
    self._numeric_verify_op_details = None
    # A debug-mode model must contain NumericVerify ops; their absence means
    # the wrong (non-instrumented) model was supplied.
    if not self._get_numeric_verify_tensor_details():
      raise ValueError('Please check if the quantized model is in debug mode')

    # User-supplied layer metrics override defaults on name collision.
    self._layer_debug_metrics = _DEFAULT_LAYER_DEBUG_METRICS.copy()
    if self._debug_options.layer_debug_metrics:
      self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)

    self.layer_statistics = None
    self.model_statistics = None

    self._metrics = metrics_stub.TFLiteMetrics()
    self._metrics.increase_counter_debugger_creation()
233
+
234
+ def _get_quantized_model(self, is_debug: bool) -> bytes:
235
+ if not self.converter:
236
+ raise ValueError('No converter found, use this function with the '
237
+ 'converter option in the constructor.')
238
+
239
+ return convert.mlir_quantize(
240
+ self.calibrated_model,
241
+ disable_per_channel=self.converter._experimental_disable_per_channel, # pylint: disable=protected-access
242
+ fully_quantize=self._debug_options.fully_quantize,
243
+ enable_numeric_verify=is_debug,
244
+ denylisted_ops=self._debug_options.denylisted_ops,
245
+ denylisted_nodes=self._debug_options.denylisted_nodes)
246
+
247
  def get_nondebug_quantized_model(self) -> bytes:
    """Returns a non-instrumented quantized model.

    Convert the quantized model with the initialized converter and
    return bytes for nondebug model. The model will not be instrumented with
    numeric verification operations.

    Returns:
      Model bytes corresponding to the model.
    Raises:
      ValueError: if converter is not passed to the debugger.
    """
    return self._get_quantized_model(is_debug=False)
260
+
261
  def get_debug_quantized_model(self) -> bytes:
    """Returns an instrumented quantized model.

    Convert the quantized model with the initialized converter and
    return bytes for model. The model will be instrumented with numeric
    verification operations and should only be used for debugging.

    Returns:
      Model bytes corresponding to the model.
    Raises:
      ValueError: if converter is not passed to the debugger.
    """
    return self._get_quantized_model(is_debug=True)
274
+
275
  def _init_from_converter(self,
                           options: QuantizationDebugOptions,
                           converter: TFLiteConverter,
                           calibrated_model: Optional[bytes] = None,
                           float_model: Optional[bytes] = None) -> None:
    """Convert the model and apply options.

    Converts the quantized model and initializes a quantized model interpreter
    with the quantized model. Returns a float model interpreter if float model
    is provided.

    Args:
      options: a QuantizationDebugOptions object.
      converter: an initialized tf.lite.TFLiteConverter.
      calibrated_model: Calibrated model bytes.
      float_model: Float model bytes.
    """
    # Always enable numeric verification here: this path builds the debug
    # (instrumented) model that the debugger runs.
    self.quant_model = convert.mlir_quantize(
        calibrated_model,
        disable_per_channel=converter._experimental_disable_per_channel,  # pylint: disable=protected-access
        fully_quantize=options.fully_quantize,
        enable_numeric_verify=True,
        denylisted_ops=options.denylisted_ops,
        denylisted_nodes=options.denylisted_nodes)
    self._quant_interpreter = _interpreter.Interpreter(
        model_content=self.quant_model)
    self._float_interpreter = None
    if float_model is not None:
      self._float_interpreter = _interpreter.Interpreter(
          model_content=float_model)
305
+
306
  def _set_converter_options_for_float(
      self, converter: TFLiteConverter) -> TFLiteConverter:
    """Verify converter options and set required experimental options."""
    # Drop any optimization flags so this converter produces a plain float
    # model (the quantized counterpart is converted separately).
    if converter.optimizations:
      converter.optimizations = []
    return converter
312
+
313
  def _set_converter_options_for_calibration(
      self, converter: TFLiteConverter) -> TFLiteConverter:
    """Verify converter options and set required experimental options."""
    # Calibration requires both an optimization target and a representative
    # dataset to derive quantization ranges from.
    if not converter.optimizations:
      raise ValueError(
          'converter object must set optimizations to lite.Optimize.DEFAULT')
    if not converter.representative_dataset:
      raise ValueError('converter object must set representative_dataset')

    # Run the MLIR quantizer but stop after calibration, so the debugger can
    # apply its own (instrumented) quantization pass afterwards.
    converter.experimental_mlir_quantizer = True
    converter._experimental_calibrate_only = True  # pylint: disable=protected-access
    return converter
325
+
326
  def run(self) -> None:
    """Runs models and gets metrics."""
    # Per-layer NumericVerify statistics are always collected; model-level
    # statistics only when the caller configured model_debug_metrics.
    self.layer_statistics = self._collect_layer_statistics()
    if self._debug_options.model_debug_metrics:
      self.model_statistics = self._collect_model_statistics()
331
+
332
  def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
    """Collects layer statistics by applying layer debug metrics.

    For all data from the given RepresentativeDataset, collect statistics per
    example by getting the NumericVerify op results in _quant_interpreter
    and calculating layer debug metrics on the results.

    Returns:
      aggregated per-layer statistics of NumericVerify results.
      {layer_name: {metric_name: metric}}
    """
    # {tensor_name: {metric_name: [per-example metric values]}}; the inner
    # lists are reduced to a single scalar per metric at the end.
    layer_statistics = collections.defaultdict(
        lambda: collections.defaultdict(list))

    initialize = True
    for tensor_data in self._data_gen():
      # Only the first example resizes inputs / allocates tensors.
      self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
      initialize = False

      # Run the model.
      self._quant_interpreter.invoke()

      # Collect the statistics of this invoke result.
      for tensor_detail in self._get_numeric_verify_tensor_details():
        tensor_name = tensor_detail['name']  # pytype: disable=unsupported-operands  # dynamic-method-lookup
        # NumericVerify output holds the float/quantized diff for one layer.
        diffs = self._quant_interpreter.get_tensor(tensor_detail['index'])  # pytype: disable=unsupported-operands  # dynamic-method-lookup
        for metric_name, metric_fn in self._layer_debug_metrics.items():
          layer_statistics[tensor_name][metric_name].append(metric_fn(diffs))

      if self._debug_options.layer_direct_compare_metrics is not None:
        for tensor_detail in self._get_numeric_verify_tensor_details():
          tensor_name = tensor_detail['name']  # pytype: disable=unsupported-operands  # dynamic-method-lookup
          # Recover the op that feeds this NumericVerify node; its inputs are
          # (quantized tensor, float tensor) in that order.
          op_idx = self._defining_op[tensor_detail['index']]  # pytype: disable=unsupported-operands  # dynamic-method-lookup
          op_detail = self._quant_interpreter._get_op_details(op_idx)  # pylint: disable=protected-access
          q_idx, f_idx = op_detail['inputs']
          quant_input_detail = self._quant_interpreter._get_tensor_details(  # pylint: disable=protected-access
              q_idx, subgraph_index=0)
          # Direct-compare metrics receive the raw float tensor, raw quantized
          # tensor, and the quantized tensor's (first) scale and zero point.
          for (metric_name, metric_fn
               ) in self._debug_options.layer_direct_compare_metrics.items():
            layer_statistics[tensor_name][metric_name].append(
                metric_fn(
                    self._quant_interpreter.get_tensor(f_idx),
                    self._quant_interpreter.get_tensor(q_idx),
                    quant_input_detail['quantization_parameters']['scales'][0],
                    quant_input_detail['quantization_parameters']['zero_points']
                    [0]))

    # Calculate final aggregated metrics for each layer.
    # nanmean ignores examples where a metric produced NaN.
    for metrics in layer_statistics.values():
      for metric_name in metrics:
        metrics[metric_name] = np.nanmean(metrics[metric_name])

    return layer_statistics
385
+
386
  def _collect_model_statistics(self) -> Dict[str, float]:
    """Collects model output metrics.

    For all data from the given RepresentativeDataset, collect all model output
    results from float model & quantized debug model, and calculate metrics
    by using model output functions. As a result, self.model_results is filled,

    where self.model_results[model_output_function_name] = `aggregated model
    output function value` (a scalar).

    Returns:
      aggregated per-model output discrepancy metrics.
      {metric_name: aggregated_metric}
    """

    # {metric_name: [per-example metric values]}; reduced with np.mean below.
    model_statistics = collections.defaultdict(list)

    initialize = True
    for tensor_data in self._data_gen():
      # Run quantized debug model and collect output results.
      self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
      self._quant_interpreter.invoke()
      quant_tensor_data = self._get_output_tensors(self._quant_interpreter)

      # Run float model if it's initialized.
      # NOTE(review): float_tensor_data stays an empty list when no float
      # interpreter exists; metric functions presumably tolerate that — verify.
      float_tensor_data = []
      if self._float_interpreter:
        self._set_input_tensors(
            self._float_interpreter, tensor_data, initialize)
        self._float_interpreter.invoke()
        float_tensor_data = self._get_output_tensors(self._float_interpreter)

      initialize = False

      # Calculate the metrics.
      for (metric_name,
           metric_fn) in self._debug_options.model_debug_metrics.items():
        model_statistics[metric_name].append(
            metric_fn(float_tensor_data, quant_tensor_data))

    # Calculate final aggregated metrics for each outputs.
    return {
        metric_name: np.mean(metric)
        for metric_name, metric in model_statistics.items()
    }
431
+
432
+ def _set_input_tensors(self, interpreter: _interpreter.Interpreter,
433
+ tensor_data: Sequence[np.ndarray],
434
+ initialize: bool) -> None:
435
+ """Sets input tensors into TFLite model Interpreter.
436
+
437
+ Args:
438
+ interpreter: a tf.lite.Interpreter object with allocated tensors.
439
+ tensor_data: a list of Numpy array data.
440
+ initialize: set to true when input is first set for the interpreter, to
441
+ set input shapes and allocate tensors.
442
+
443
+ Raises:
444
+ ValueError: when inputs can't be set, or size of provided inputs does not
445
+ match size of model inputs.
446
+ """
447
+ input_details = interpreter.get_input_details()
448
+ if len(input_details) != len(tensor_data):
449
+ raise ValueError(
450
+ 'Number of inputs provided ({}) does not match number of inputs to '
451
+ 'the model ({})'.format(len(tensor_data), len(input_details)))
452
+
453
+ if initialize:
454
+ for input_detail, tensor in zip(input_details, tensor_data):
455
+ interpreter.resize_tensor_input(input_detail['index'], tensor.shape)
456
+ interpreter.allocate_tensors()
457
+
458
+ for input_detail, tensor in zip(input_details, tensor_data):
459
+ if tensor.dtype == np.float32 and input_detail['dtype'] == np.int8:
460
+ quant_params = _get_quant_params(input_detail)
461
+ if quant_params:
462
+ scale, zero_point = quant_params
463
+ tensor = np.round((tensor / scale) + zero_point).astype(np.int8)
464
+ interpreter.set_tensor(input_detail['index'], tensor)
465
+
466
+ def _get_output_tensors(
467
+ self, interpreter: _interpreter.Interpreter) -> List[np.ndarray]:
468
+ """Returns output tensors of given TFLite model Interpreter.
469
+
470
+ Args:
471
+ interpreter: a tf.lite.Interpreter object with allocated tensors.
472
+
473
+ Returns:
474
+ a list of numpy arrays representing output tensor results.
475
+ """
476
+
477
+ outputs = []
478
+ for output_detail in interpreter.get_output_details():
479
+ tensor = interpreter.get_tensor(output_detail['index'])
480
+ if output_detail['dtype'] == np.int8:
481
+ quant_params = _get_quant_params(output_detail)
482
+ if quant_params:
483
+ scale, zero_point = quant_params
484
+ tensor = ((tensor.astype(np.float32) - zero_point) * scale).astype(
485
+ np.float32)
486
+ outputs.append(tensor)
487
+
488
+ return outputs
489
+
490
+ def _get_numeric_verify_tensor_details(self) -> List[str]:
491
+ """Returns all names of all tensors from NumericVerify op."""
492
+ # pylint: disable=protected-access
493
+ if not self._numeric_verify_tensor_details:
494
+ self._numeric_verify_tensor_details = []
495
+ self._numeric_verify_op_details = {}
496
+ for op_info in self._quant_interpreter._get_ops_details():
497
+ if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
498
+ self._numeric_verify_tensor_details.append(
499
+ self._quant_interpreter._get_tensor_details(
500
+ op_info['outputs'][0], subgraph_index=0))
501
+ tensor_name = self._numeric_verify_tensor_details[-1]['name']
502
+ self._numeric_verify_op_details[tensor_name] = op_info
503
+ # pylint: enable=protected-access
504
+ return self._numeric_verify_tensor_details
505
+
506
+ def _get_operand_name_and_index(self,
507
+ numeric_verify_name: str) -> Tuple[str, int]:
508
+ """Gets the index and name of NumericVerify Op's quantized input tensor.
509
+
510
+ Args:
511
+ numeric_verify_name: name of the NumericVerify op's output tensor. It has
512
+ format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
513
+
514
+ Returns:
515
+ Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
516
+ """
517
+ tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
518
+ float_tensor_name = tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:]
519
+ if re.match(r'\d', float_tensor_name[-1]):
520
+ float_tensor_name = float_tensor_name[:-1]
521
+
522
+ return (float_tensor_name, int(tensor_idx))
523
+
524
  def layer_statistics_dump(self, file: IO[str]) -> None:
    """Dumps layer statistics into file, in csv format.

    Args:
      file: file, or file-like object to write.
    """
    # order of `fields` is the order of fields in csv.
    fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys())
    if self._debug_options.layer_direct_compare_metrics is not None:
      fields += list(self._debug_options.layer_direct_compare_metrics.keys())
    fields += ['scale', 'zero_point', 'tensor_name']
    writer = csv.DictWriter(file, fields)
    writer.writeheader()
    if self.layer_statistics:
      for name, metrics in self.layer_statistics.items():
        data = metrics.copy()
        # `name` is the NumericVerify output tensor name; recover the original
        # tensor name from it.
        (data['tensor_name'], _) = self._get_operand_name_and_index(name)
        # First input of the NumericVerify op is the quantized tensor.
        data['tensor_idx'] = self._numeric_verify_op_details[name]['inputs'][0]
        data['op_name'] = self._quant_interpreter._get_op_details(  # pylint: disable=protected-access
            self._defining_op[data['tensor_idx']])['op_name']
        details = self._quant_interpreter._get_tensor_details(  # pylint: disable=protected-access
            data['tensor_idx'], subgraph_index=0)
        # Only the first scale/zero-point is reported (per-tensor view, even
        # when the tensor is per-channel quantized).
        data['scale'], data['zero_point'] = (
            details['quantization_parameters']['scales'][0],
            details['quantization_parameters']['zero_points'][0])
        writer.writerow(data)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/tools/visualize.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+ """This tool creates an html visualization of a TensorFlow Lite graph.
17
+
18
+ Example usage:
19
+
20
+ python visualize.py foo.tflite foo.html
21
+ """
22
+
23
+ import json
24
+ import os
25
+ import re
26
+ import sys
27
+ import numpy as np
28
+
29
+ # pylint: disable=g-import-not-at-top
30
+ if not os.path.splitext(__file__)[0].endswith(
31
+ os.path.join("tflite_runtime", "visualize")):
32
+ # This file is part of tensorflow package.
33
+ from tensorflow.lite.python import schema_py_generated as schema_fb
34
+ else:
35
+ # This file is part of tflite_runtime package.
36
+ from tflite_runtime import schema_py_generated as schema_fb
37
+
38
+ # A CSS description for making the visualizer
39
+ _CSS = """
40
+ <html>
41
+ <head>
42
+ <style>
43
+ body {font-family: sans-serif; background-color: #fa0;}
44
+ table {background-color: #eca;}
45
+ th {background-color: black; color: white;}
46
+ h1 {
47
+ background-color: ffaa00;
48
+ padding:5px;
49
+ color: black;
50
+ }
51
+
52
+ svg {
53
+ margin: 10px;
54
+ border: 2px;
55
+ border-style: solid;
56
+ border-color: black;
57
+ background: white;
58
+ }
59
+
60
+ div {
61
+ border-radius: 5px;
62
+ background-color: #fec;
63
+ padding:5px;
64
+ margin:5px;
65
+ }
66
+
67
+ .tooltip {color: blue;}
68
+ .tooltip .tooltipcontent {
69
+ visibility: hidden;
70
+ color: black;
71
+ background-color: yellow;
72
+ padding: 5px;
73
+ border-radius: 4px;
74
+ position: absolute;
75
+ z-index: 1;
76
+ }
77
+ .tooltip:hover .tooltipcontent {
78
+ visibility: visible;
79
+ }
80
+
81
+ .edges line {
82
+ stroke: #333;
83
+ }
84
+
85
+ text {
86
+ font-weight: bold;
87
+ }
88
+
89
+ .nodes text {
90
+ color: black;
91
+ pointer-events: none;
92
+ font-family: sans-serif;
93
+ font-size: 11px;
94
+ }
95
+ </style>
96
+
97
+ <script src="https://d3js.org/d3.v4.min.js"></script>
98
+
99
+ </head>
100
+ <body>
101
+ """
102
+
103
+ _D3_HTML_TEMPLATE = """
104
+ <script>
105
+ function buildGraph() {
106
+ // Build graph data
107
+ var graph = %s;
108
+
109
+ var svg = d3.select("#subgraph%d")
110
+ var width = svg.attr("width");
111
+ var height = svg.attr("height");
112
+ // Make the graph scrollable.
113
+ svg = svg.call(d3.zoom().on("zoom", function() {
114
+ svg.attr("transform", d3.event.transform);
115
+ })).append("g");
116
+
117
+
118
+ var color = d3.scaleOrdinal(d3.schemeDark2);
119
+
120
+ var simulation = d3.forceSimulation()
121
+ .force("link", d3.forceLink().id(function(d) {return d.id;}))
122
+ .force("charge", d3.forceManyBody())
123
+ .force("center", d3.forceCenter(0.5 * width, 0.5 * height));
124
+
125
+ var edge = svg.append("g").attr("class", "edges").selectAll("line")
126
+ .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
127
+
128
+ // Make the node group
129
+ var node = svg.selectAll(".nodes")
130
+ .data(graph.nodes)
131
+ .enter().append("g")
132
+ .attr("x", function(d){return d.x})
133
+ .attr("y", function(d){return d.y})
134
+ .attr("transform", function(d) {
135
+ return "translate( " + d.x + ", " + d.y + ")"
136
+ })
137
+ .attr("class", "nodes")
138
+ .call(d3.drag()
139
+ .on("start", function(d) {
140
+ if(!d3.event.active) simulation.alphaTarget(1.0).restart();
141
+ d.fx = d.x;d.fy = d.y;
142
+ })
143
+ .on("drag", function(d) {
144
+ d.fx = d3.event.x; d.fy = d3.event.y;
145
+ })
146
+ .on("end", function(d) {
147
+ if (!d3.event.active) simulation.alphaTarget(0);
148
+ d.fx = d.fy = null;
149
+ }));
150
+ // Within the group, draw a box for the node position and text
151
+ // on the side.
152
+
153
+ var node_width = 150;
154
+ var node_height = 30;
155
+
156
+ node.append("rect")
157
+ .attr("r", "5px")
158
+ .attr("width", node_width)
159
+ .attr("height", node_height)
160
+ .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
161
+ .attr("stroke", "#000000")
162
+ .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
163
+ node.append("text")
164
+ .text(function(d) { return d.name; })
165
+ .attr("x", 5)
166
+ .attr("y", 20)
167
+ .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
168
+ // Setup force parameters and update position callback
169
+
170
+
171
+ var node = svg.selectAll(".nodes")
172
+ .data(graph.nodes);
173
+
174
+ // Bind the links
175
+ var name_to_g = {}
176
+ node.each(function(data, index, nodes) {
177
+ console.log(data.id)
178
+ name_to_g[data.id] = this;
179
+ });
180
+
181
+ function proc(w, t) {
182
+ return parseInt(w.getAttribute(t));
183
+ }
184
+ edge.attr("d", function(d) {
185
+ function lerp(t, a, b) {
186
+ return (1.0-t) * a + t * b;
187
+ }
188
+ var x1 = proc(name_to_g[d.source],"x") + node_width /2;
189
+ var y1 = proc(name_to_g[d.source],"y") + node_height;
190
+ var x2 = proc(name_to_g[d.target],"x") + node_width /2;
191
+ var y2 = proc(name_to_g[d.target],"y");
192
+ var s = "M " + x1 + " " + y1
193
+ + " C " + x1 + " " + lerp(.5, y1, y2)
194
+ + " " + x2 + " " + lerp(.5, y1, y2)
195
+ + " " + x2 + " " + y2
196
+ return s;
197
+ });
198
+
199
+ }
200
+ buildGraph()
201
+ </script>
202
+ """
203
+
204
+
205
def TensorTypeToName(tensor_type):
  """Converts a numerical enum to a readable tensor type, or None if unknown."""
  # Reverse-lookup the enum value in the generated schema class.
  return next(
      (name for name, value in schema_fb.TensorType.__dict__.items()
       if value == tensor_type), None)
211
+
212
+
213
def BuiltinCodeToName(code):
  """Converts a builtin op code enum to a readable name, or None if unknown."""
  # Reverse-lookup the enum value in the generated schema class.
  return next(
      (name for name, value in schema_fb.BuiltinOperator.__dict__.items()
       if value == code), None)
219
+
220
+
221
def NameListToString(name_list):
  """Converts a list of integer char codes to the equivalent ASCII string.

  Strings pass through unchanged; None yields the empty string.
  """
  if isinstance(name_list, str):
    return name_list
  if name_list is None:
    return ""
  return "".join(chr(int(val)) for val in name_list)
231
+
232
+
233
class OpCodeMapper:
  """Maps an opcode index to a human-readable op name."""

  def __init__(self, data):
    # Precompute index -> display name; custom ops use their custom_code name.
    self.code_to_name = {}
    for idx, d in enumerate(data["operator_codes"]):
      name = BuiltinCodeToName(d["builtin_code"])
      if name == "CUSTOM":
        name = NameListToString(d["custom_code"])
      self.code_to_name[idx] = name

  def __call__(self, x):
    # Unknown indices render as "<UNKNOWN> (idx)".
    return "%s (%d)" % (self.code_to_name.get(x, "<UNKNOWN>"), x)
249
+
250
+
251
class DataSizeMapper:
  """For buffers, report the number of bytes (or "--" when absent)."""

  def __call__(self, x):
    if x is None:
      return "--"
    return "%d bytes" % len(x)
259
+
260
+
261
class TensorMapper:
  """Maps a list of tensor indices to a tooltip hoverable indicator of more."""

  def __init__(self, subgraph_data):
    self.data = subgraph_data

  def __call__(self, x):
    if x is None:
      return ""
    # Build the tooltip body: one line per tensor with index, name, type,
    # shape and shape_signature; the visible text is just repr(x).
    pieces = ["<span class='tooltip'><span class='tooltipcontent'>"]
    for i in x:
      tensor = self.data["tensors"][i]
      pieces.append(str(i) + " ")
      pieces.append(NameListToString(tensor["name"]) + " ")
      pieces.append(TensorTypeToName(tensor["type"]) + " ")
      pieces.append(repr(tensor["shape"]) if "shape" in tensor else "[]")
      pieces.append((repr(tensor["shape_signature"])
                     if "shape_signature" in tensor else "[]") + "<br>")
    pieces.append("</span>")
    pieces.append(repr(x))
    pieces.append("</span>")
    return "".join(pieces)
285
+
286
+
287
def GenerateGraph(subgraph_idx, g, opcode_mapper):
  """Produces the HTML required to have a d3 visualization of the dag.

  Args:
    subgraph_idx: index of the subgraph (targets its <svg id='subgraphN'>).
    g: subgraph dict with "operators" and "tensors" lists.
    opcode_mapper: callable mapping an opcode index to a display name.

  Returns:
    An HTML <script> fragment that renders the node/edge graph with d3.
  """

  def TensorName(idx):
    return "t%d" % idx

  def OpName(idx):
    return "o%d" % idx

  edges = []
  nodes = []
  # tensor_index -> (y, x) of its first use as an op input / op output;
  # used only for the initial node placement.
  first = {}
  second = {}
  pixel_mult = 200  # TODO(aselle): multiplier for initial placement
  width_mult = 170  # TODO(aselle): multiplier for initial placement
  for op_index, op in enumerate(g["operators"] or []):
    if op["inputs"] is not None:
      for tensor_input_position, tensor_index in enumerate(op["inputs"]):
        if tensor_index not in first:
          first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
                                 (tensor_input_position + 1) * width_mult)
        edges.append({
            "source": TensorName(tensor_index),
            "target": OpName(op_index)
        })
    if op["outputs"] is not None:
      for tensor_output_position, tensor_index in enumerate(op["outputs"]):
        if tensor_index not in second:
          second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
                                  (tensor_output_position + 1) * width_mult)
        edges.append({
            "target": TensorName(tensor_index),
            "source": OpName(op_index)
        })

    nodes.append({
        "id": OpName(op_index),
        "name": opcode_mapper(op["opcode_index"]),
        "group": 2,
        "x": pixel_mult,
        "y": (op_index + 1) * pixel_mult
    })
  for tensor_index, tensor in enumerate(g["tensors"]):
    initial_y = (
        first[tensor_index] if tensor_index in first else
        second[tensor_index] if tensor_index in second else (0, 0))

    nodes.append({
        "id": TensorName(tensor_index),
        # Bug fix: tensors here are dicts (see tensor["shape"] usage in
        # TensorMapper), so getattr(tensor, "shape", []) always returned the
        # default and shapes never displayed; read the dict key instead.
        "name": "%r (%d)" % (tensor.get("shape", []), tensor_index),
        "group": 1,
        "x": initial_y[1],
        "y": initial_y[0]
    })
  graph_str = json.dumps({"nodes": nodes, "edges": edges})

  html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
  return html
345
+
346
+
347
def GenerateTableHtml(items, keys_to_print, display_index=True):
  """Given a list of object values and keys to print, make an HTML table.

  Args:
    items: Items to print an array of dicts.
    keys_to_print: (key, display_fn). `key` is a key in the object. i.e.
      items[0][key] should exist. display_fn is the mapping function on display.
      i.e. the displayed html cell will have the string returned by
      `mapping_fn(items[0][key])`.
    display_index: add a column which is the index of each row in `items`.

  Returns:
    An html table.
  """
  # Collect fragments and join once at the end.
  pieces = ["<table><tr>\n", "<tr>\n"]
  # Header row.
  if display_index:
    pieces.append("<th>index</th>")
  for heading, _ in keys_to_print:
    pieces.append("<th>%s</th>" % heading)
  pieces.append("</tr>\n")
  # One row per item; missing keys render through the mapper as None.
  for row_index, row in enumerate(items):
    pieces.append("<tr>\n")
    if display_index:
      pieces.append("<td>%d</td>" % row_index)
    for key, mapper in keys_to_print:
      cell = row[key] if key in row else None
      if mapper is not None:
        cell = mapper(cell)
      pieces.append("<td>%s</td>\n" % cell)
    pieces.append("</tr>\n")
  pieces.append("</table>\n")
  return "".join(pieces)
383
+
384
+
385
def CamelCaseToSnakeCase(camel_case_input):
  """Converts an identifier in CamelCase to snake_case."""
  # First pass inserts '_' before Upper+lower word starts (handles acronyms
  # like HTMLParser); second pass splits lower/digit-to-Upper boundaries.
  with_boundaries = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
  return re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_boundaries).lower()
389
+
390
+
391
def FlatbufferToDict(fb, preserve_as_numpy):
  """Converts a hierarchy of FB objects into a nested dict.

  We avoid transforming big parts of the flat buffer into python arrays. This
  speeds conversion from ten minutes to a few seconds on big graphs.

  Args:
    fb: a flat buffer structure. (i.e. ModelT)
    preserve_as_numpy: true if all downstream np.arrays should be preserved.
      false if all downstream np.array should become python arrays
  Returns:
    A dictionary representing the flatbuffer rather than a flatbuffer object.
  """
  # Scalars and strings pass through unchanged.
  if isinstance(fb, (int, float, str)):
    return fb
  if hasattr(fb, "__dict__"):
    # Flatbuffer object: map each public, non-callable attribute to a
    # snake_case key. "buffers" is always kept as numpy to avoid huge lists.
    result = {}
    for attribute_name in dir(fb):
      attribute = getattr(fb, attribute_name)
      if callable(attribute) or attribute_name.startswith("_"):
        continue
      snake_name = CamelCaseToSnakeCase(attribute_name)
      preserve = True if attribute_name == "buffers" else preserve_as_numpy
      result[snake_name] = FlatbufferToDict(attribute, preserve)
    return result
  if isinstance(fb, np.ndarray):
    return fb if preserve_as_numpy else fb.tolist()
  if hasattr(fb, "__len__"):
    return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
  return fb
421
+
422
+
423
def CreateDictFromFlatbuffer(buffer_data):
  """Parses a TFLite flatbuffer into a plain nested-dict representation."""
  root = schema_fb.Model.GetRootAsModel(buffer_data, 0)
  model_object = schema_fb.ModelT.InitFromObj(root)
  # Convert to python lists (not numpy) except for the large "buffers" field.
  return FlatbufferToDict(model_object, preserve_as_numpy=False)
427
+
428
+
429
def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
  """Returns html description with the given tflite model.

  Args:
    tflite_input: TFLite flatbuffer model path or model object.
    input_is_filepath: Tells if tflite_input is a model path or a model object.

  Returns:
    Dump of the given tflite model in HTML format.

  Raises:
    RuntimeError: If the input is not valid.
  """

  # Load the model into a nested-dict representation, either from a file path
  # (.tflite/.bin flatbuffer or pre-converted .json) or from a model object.
  if input_is_filepath:
    if not os.path.exists(tflite_input):
      raise RuntimeError("Invalid filename %r" % tflite_input)
    if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
      with open(tflite_input, "rb") as file_handle:
        file_data = bytearray(file_handle.read())
      data = CreateDictFromFlatbuffer(file_data)
    elif tflite_input.endswith(".json"):
      # Fix: use a context manager so the handle is closed deterministically
      # (was json.load(open(...)), which leaked the file object).
      with open(tflite_input) as json_handle:
        data = json.load(json_handle)
    else:
      raise RuntimeError("Input file was not .tflite or .json")
  else:
    data = CreateDictFromFlatbuffer(tflite_input)
  html = ""
  html += _CSS
  # Fix: closing tag was </h2> for an <h1>, producing invalid HTML.
  html += "<h1>TensorFlow Lite Model</h1>"

  data["filename"] = tflite_input if input_is_filepath else (
      "Null (used model object)")  # Avoid special case

  toplevel_stuff = [("filename", None), ("version", None),
                    ("description", None)]

  html += "<table>\n"
  for key, mapping in toplevel_stuff:
    if not mapping:
      mapping = lambda x: x
    html += "<tr><th>%s</th><td>%s</td></tr>\n" % (key, mapping(data.get(key)))
  html += "</table>\n"

  # Spec on what keys to display
  buffer_keys_to_display = [("data", DataSizeMapper())]
  operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
                              ("custom_code", NameListToString),
                              ("version", None)]

  # Update builtin code fields: newer models store the op code in
  # `builtin_code`, older ones in `deprecated_builtin_code`; taking the max
  # resolves both layouts to the real op code.
  for d in data["operator_codes"]:
    d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"])

  for subgraph_idx, g in enumerate(data["subgraphs"]):
    # Subgraph local specs on what to display
    html += "<div class='subgraph'>"
    tensor_mapper = TensorMapper(g)
    opcode_mapper = OpCodeMapper(data)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          ("opcode_index", opcode_mapper)]
    tensor_keys_to_display = [("name", NameListToString),
                              ("type", TensorTypeToName), ("shape", None),
                              ("shape_signature", None), ("buffer", None),
                              ("quantization", None)]

    html += "<h2>Subgraph %d</h2>\n" % subgraph_idx

    # Inputs and outputs.
    html += "<h3>Inputs/Outputs</h3>\n"
    html += GenerateTableHtml([{
        "inputs": g["inputs"],
        "outputs": g["outputs"]
    }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                              display_index=False)

    # Print the tensors.
    html += "<h3>Tensors</h3>\n"
    html += GenerateTableHtml(g["tensors"], tensor_keys_to_display)

    # Print the ops.
    if g["operators"]:
      html += "<h3>Ops</h3>\n"
      html += GenerateTableHtml(g["operators"], op_keys_to_display)

    # Visual graph.
    html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
        subgraph_idx,)
    html += GenerateGraph(subgraph_idx, g, opcode_mapper)
    html += "</div>"

  # Buffers have no data, but maybe in the future they will
  html += "<h2>Buffers</h2>\n"
  html += GenerateTableHtml(data["buffers"], buffer_keys_to_display)

  # Operator codes
  html += "<h2>Operator Codes</h2>\n"
  html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)

  html += "</body></html>\n"

  return html
534
+
535
+
536
def main(argv):
  """CLI entry point: renders argv[1] (.tflite/.json) to argv[2] (.html)."""
  # Guard clause replaces try/except IndexError: both paths must be present.
  if len(argv) < 3:
    print("Usage: %s <input tflite> <output html>" % (argv[0]))
    return
  tflite_input = argv[1]
  html_output = argv[2]
  html = create_html(tflite_input)
  with open(html_output, "w") as output_file:
    output_file.write(html)
546
+
547
+
548
+ if __name__ == "__main__":
549
+ main(sys.argv)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (216 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.pyi ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
# Type stub for a pybind11-exported class used by autograph impl tests.
class TestClassDef:
    def __init__(self) -> None: ...
    def method(self) -> object: ...
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/impl/testing/pybind_for_testing.so ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dd76e74055bba4c02308da5f57791117799704b278e153aef7741edbae230b2
3
+ size 1072920
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """This module implements operators that AutoGraph overloads.
16
+
17
+ Note that "operator" is used loosely here, and includes control structures like
18
+ conditionals and loops, implemented in functional form, using for example
19
+ closures for the body.
20
+ """
21
+
22
+ # Naming conventions:
23
+ # * operator names match the name usually used for the respective Python
24
+ # idiom; examples: for_stmt, list_append
25
+ # * operator arguments match either of:
26
+ # - the corresponding Python AST attribute (e.g. the condition of an if
27
+ # statement is called test) if the operator represents an AST construct
28
+ # - the names used in the Python docs, if the operator is a function (e.g.
29
+ # list_ and x for append, see
30
+ # https://docs.python.org/3.7/tutorial/datastructures.html)
31
+ #
32
+ # All operators may accept a final argument named "opts", of a type that
33
+ # subclasses namedtuple and contains any arguments that are only required
34
+ # for some specializations of the operator.
35
+
36
+ from tensorflow.python.autograph.operators.conditional_expressions import if_exp
37
+ from tensorflow.python.autograph.operators.control_flow import for_stmt
38
+ from tensorflow.python.autograph.operators.control_flow import if_stmt
39
+ from tensorflow.python.autograph.operators.control_flow import while_stmt
40
+ from tensorflow.python.autograph.operators.data_structures import list_append
41
+ from tensorflow.python.autograph.operators.data_structures import list_pop
42
+ from tensorflow.python.autograph.operators.data_structures import list_stack
43
+ from tensorflow.python.autograph.operators.data_structures import ListPopOpts
44
+ from tensorflow.python.autograph.operators.data_structures import ListStackOpts
45
+ from tensorflow.python.autograph.operators.data_structures import new_list
46
+ from tensorflow.python.autograph.operators.exceptions import assert_stmt
47
+ from tensorflow.python.autograph.operators.logical import and_
48
+ from tensorflow.python.autograph.operators.logical import eq
49
+ from tensorflow.python.autograph.operators.logical import not_
50
+ from tensorflow.python.autograph.operators.logical import not_eq
51
+ from tensorflow.python.autograph.operators.logical import or_
52
+ from tensorflow.python.autograph.operators.py_builtins import float_
53
+ from tensorflow.python.autograph.operators.py_builtins import int_
54
+ from tensorflow.python.autograph.operators.py_builtins import len_
55
+ from tensorflow.python.autograph.operators.py_builtins import print_
56
+ from tensorflow.python.autograph.operators.py_builtins import range_
57
+ from tensorflow.python.autograph.operators.slices import get_item
58
+ from tensorflow.python.autograph.operators.slices import GetItemOpts
59
+ from tensorflow.python.autograph.operators.slices import set_item
60
+ from tensorflow.python.autograph.operators.variables import ld
61
+ from tensorflow.python.autograph.operators.variables import ldu
62
+ from tensorflow.python.autograph.operators.variables import Undefined
63
+ from tensorflow.python.autograph.operators.variables import UndefinedReturnValue
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.73 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/conditional_expressions.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/control_flow.cpython-310.pyc ADDED
Binary file (36.7 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/data_structures.cpython-310.pyc ADDED
Binary file (9.65 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (2.57 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/autograph/operators/__pycache__/logical.cpython-310.pyc ADDED
Binary file (2.87 kB). View file