Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debug_service_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debug_service_pb2_grpc.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debugger_event_metadata_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debug_service_pb2.py +43 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debug_service_pb2_grpc.py +90 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debugger_event_metadata_pb2.py +25 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_cache.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_type.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_type_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/type_dispatch.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_cache.py +103 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_type.py +720 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_type_pb2.py +32 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/type_dispatch.py +131 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__init__.py +38 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/serialization_test_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/trace_type_builder.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/util.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/custom_nest_trace_type.py +143 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/default_types.py +826 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/default_types_pb2.py +38 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization.py +100 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization_pb2.py +26 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization_test_pb2.py +30 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/trace_type_builder.py +208 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/util.py +52 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/toco_flags_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/types_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/gen_html.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/toco_conversion_log_pb2.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/gen_html.py +265 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/toco_conversion_log_pb2.py +37 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__pycache__/toco_from_protos.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/toco_from_protos.py +74 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/profiler/__init__.py +0 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (191 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (197 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debug_service_pb2.cpython-310.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debug_service_pb2_grpc.cpython-310.pyc
ADDED
|
Binary file (3.35 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/__pycache__/debugger_event_metadata_pb2.cpython-310.pyc
ADDED
|
Binary file (1.09 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debug_service_pb2.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/debug/debug_service.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
|
| 15 |
+
from tensorflow.core.profiler import tfprof_log_pb2 as tensorflow_dot_core_dot_profiler_dot_tfprof__log__pb2
|
| 16 |
+
from tensorflow.core.protobuf import debug_pb2 as tensorflow_dot_core_dot_protobuf_dot_debug__pb2
|
| 17 |
+
from tensorflow.core.util import event_pb2 as tensorflow_dot_core_dot_util_dot_event__pb2
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)tensorflow/core/debug/debug_service.proto\x12\ntensorflow\x1a&tensorflow/core/framework/tensor.proto\x1a)tensorflow/core/profiler/tfprof_log.proto\x1a$tensorflow/core/protobuf/debug.proto\x1a tensorflow/core/util/event.proto\"\xde\x02\n\nEventReply\x12I\n\x16\x64\x65\x62ug_op_state_changes\x18\x01 \x03(\x0b\x32).tensorflow.EventReply.DebugOpStateChange\x12\'\n\x06tensor\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto\x1a\xdb\x01\n\x12\x44\x65\x62ugOpStateChange\x12>\n\x05state\x18\x01 \x01(\x0e\x32/.tensorflow.EventReply.DebugOpStateChange.State\x12\x11\n\tnode_name\x18\x02 \x01(\t\x12\x13\n\x0boutput_slot\x18\x03 \x01(\x05\x12\x10\n\x08\x64\x65\x62ug_op\x18\x04 \x01(\t\"K\n\x05State\x12\x15\n\x11STATE_UNSPECIFIED\x10\x00\x12\x0c\n\x08\x44ISABLED\x10\x01\x12\r\n\tREAD_ONLY\x10\x02\x12\x0e\n\nREAD_WRITE\x10\x03\"\xa7\x03\n\rCallTraceback\x12\x35\n\tcall_type\x18\x01 \x01(\x0e\x32\".tensorflow.CallTraceback.CallType\x12\x10\n\x08\x63\x61ll_key\x18\x02 \x01(\t\x12\x30\n\x0corigin_stack\x18\x03 \x01(\x0b\x32\x1a.tensorflow.tfprof.CodeDef\x12L\n\x13origin_id_to_string\x18\x04 \x03(\x0b\x32/.tensorflow.CallTraceback.OriginIdToStringEntry\x12\x36\n\x0fgraph_traceback\x18\x05 \x01(\x0b\x32\x1d.tensorflow.tfprof.OpLogProto\x12\x15\n\rgraph_version\x18\x06 \x01(\x03\x1a\x37\n\x15OriginIdToStringEntry\x12\x0b\n\x03key\x18\x01 \x01(\x03\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"E\n\x08\x43\x61llType\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x13\n\x0fGRAPH_EXECUTION\x10\x01\x12\x13\n\x0f\x45\x41GER_EXECUTION\x10\x02\x32\xdd\x01\n\rEventListener\x12;\n\nSendEvents\x12\x11.tensorflow.Event\x1a\x16.tensorflow.EventReply(\x01\x30\x01\x12\x43\n\x0eSendTracebacks\x12\x19.tensorflow.CallTraceback\x1a\x16.tensorflow.EventReply\x12J\n\x0fSendSourceFiles\x12\x1f.tensorflow.DebuggedSourceFiles\x1a\x16.tensorflow.EventReplyb\x06proto3')
|
| 21 |
+
|
| 22 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 23 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.debug.debug_service_pb2', globals())
|
| 24 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 25 |
+
|
| 26 |
+
DESCRIPTOR._options = None
|
| 27 |
+
_CALLTRACEBACK_ORIGINIDTOSTRINGENTRY._options = None
|
| 28 |
+
_CALLTRACEBACK_ORIGINIDTOSTRINGENTRY._serialized_options = b'8\001'
|
| 29 |
+
_EVENTREPLY._serialized_start=213
|
| 30 |
+
_EVENTREPLY._serialized_end=563
|
| 31 |
+
_EVENTREPLY_DEBUGOPSTATECHANGE._serialized_start=344
|
| 32 |
+
_EVENTREPLY_DEBUGOPSTATECHANGE._serialized_end=563
|
| 33 |
+
_EVENTREPLY_DEBUGOPSTATECHANGE_STATE._serialized_start=488
|
| 34 |
+
_EVENTREPLY_DEBUGOPSTATECHANGE_STATE._serialized_end=563
|
| 35 |
+
_CALLTRACEBACK._serialized_start=566
|
| 36 |
+
_CALLTRACEBACK._serialized_end=989
|
| 37 |
+
_CALLTRACEBACK_ORIGINIDTOSTRINGENTRY._serialized_start=863
|
| 38 |
+
_CALLTRACEBACK_ORIGINIDTOSTRINGENTRY._serialized_end=918
|
| 39 |
+
_CALLTRACEBACK_CALLTYPE._serialized_start=920
|
| 40 |
+
_CALLTRACEBACK_CALLTYPE._serialized_end=989
|
| 41 |
+
_EVENTLISTENER._serialized_start=992
|
| 42 |
+
_EVENTLISTENER._serialized_end=1213
|
| 43 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debug_service_pb2_grpc.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
| 2 |
+
import grpc
|
| 3 |
+
|
| 4 |
+
from tensorflow.core.debug import debug_service_pb2 as tensorflow_dot_core_dot_debug_dot_debug__service__pb2
|
| 5 |
+
from tensorflow.core.protobuf import debug_pb2 as tensorflow_dot_core_dot_protobuf_dot_debug__pb2
|
| 6 |
+
from tensorflow.core.util import event_pb2 as tensorflow_dot_core_dot_util_dot_event__pb2
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class EventListenerStub(object):
|
| 10 |
+
"""EventListener: Receives Event protos, e.g., from debugged TensorFlow
|
| 11 |
+
runtime(s).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, channel):
|
| 15 |
+
"""Constructor.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
channel: A grpc.Channel.
|
| 19 |
+
"""
|
| 20 |
+
self.SendEvents = channel.stream_stream(
|
| 21 |
+
'/tensorflow.EventListener/SendEvents',
|
| 22 |
+
request_serializer=tensorflow_dot_core_dot_util_dot_event__pb2.Event.SerializeToString,
|
| 23 |
+
response_deserializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.FromString,
|
| 24 |
+
)
|
| 25 |
+
self.SendTracebacks = channel.unary_unary(
|
| 26 |
+
'/tensorflow.EventListener/SendTracebacks',
|
| 27 |
+
request_serializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.CallTraceback.SerializeToString,
|
| 28 |
+
response_deserializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.FromString,
|
| 29 |
+
)
|
| 30 |
+
self.SendSourceFiles = channel.unary_unary(
|
| 31 |
+
'/tensorflow.EventListener/SendSourceFiles',
|
| 32 |
+
request_serializer=tensorflow_dot_core_dot_protobuf_dot_debug__pb2.DebuggedSourceFiles.SerializeToString,
|
| 33 |
+
response_deserializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.FromString,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class EventListenerServicer(object):
|
| 38 |
+
"""EventListener: Receives Event protos, e.g., from debugged TensorFlow
|
| 39 |
+
runtime(s).
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def SendEvents(self, request_iterator, context):
|
| 43 |
+
"""Client(s) can use this RPC method to send the EventListener Event protos.
|
| 44 |
+
The Event protos can hold information such as:
|
| 45 |
+
1) intermediate tensors from a debugged graph being executed, which can
|
| 46 |
+
be sent from DebugIdentity ops configured with grpc URLs.
|
| 47 |
+
2) GraphDefs of partition graphs, which can be sent from special debug
|
| 48 |
+
ops that get executed immediately after the beginning of the graph
|
| 49 |
+
execution.
|
| 50 |
+
"""
|
| 51 |
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 52 |
+
context.set_details('Method not implemented!')
|
| 53 |
+
raise NotImplementedError('Method not implemented!')
|
| 54 |
+
|
| 55 |
+
def SendTracebacks(self, request, context):
|
| 56 |
+
"""Send the tracebacks of a TensorFlow execution call.
|
| 57 |
+
"""
|
| 58 |
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 59 |
+
context.set_details('Method not implemented!')
|
| 60 |
+
raise NotImplementedError('Method not implemented!')
|
| 61 |
+
|
| 62 |
+
def SendSourceFiles(self, request, context):
|
| 63 |
+
"""Send a collection of source code files being debugged.
|
| 64 |
+
"""
|
| 65 |
+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
| 66 |
+
context.set_details('Method not implemented!')
|
| 67 |
+
raise NotImplementedError('Method not implemented!')
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def add_EventListenerServicer_to_server(servicer, server):
|
| 71 |
+
rpc_method_handlers = {
|
| 72 |
+
'SendEvents': grpc.stream_stream_rpc_method_handler(
|
| 73 |
+
servicer.SendEvents,
|
| 74 |
+
request_deserializer=tensorflow_dot_core_dot_util_dot_event__pb2.Event.FromString,
|
| 75 |
+
response_serializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.SerializeToString,
|
| 76 |
+
),
|
| 77 |
+
'SendTracebacks': grpc.unary_unary_rpc_method_handler(
|
| 78 |
+
servicer.SendTracebacks,
|
| 79 |
+
request_deserializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.CallTraceback.FromString,
|
| 80 |
+
response_serializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.SerializeToString,
|
| 81 |
+
),
|
| 82 |
+
'SendSourceFiles': grpc.unary_unary_rpc_method_handler(
|
| 83 |
+
servicer.SendSourceFiles,
|
| 84 |
+
request_deserializer=tensorflow_dot_core_dot_protobuf_dot_debug__pb2.DebuggedSourceFiles.FromString,
|
| 85 |
+
response_serializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.SerializeToString,
|
| 86 |
+
),
|
| 87 |
+
}
|
| 88 |
+
generic_handler = grpc.method_handlers_generic_handler(
|
| 89 |
+
'tensorflow.EventListener', rpc_method_handlers)
|
| 90 |
+
server.add_generic_rpc_handlers((generic_handler,))
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/debug/debugger_event_metadata_pb2.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/debug/debugger_event_metadata.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n3tensorflow/core/debug/debugger_event_metadata.proto\x12!third_party.tensorflow.core.debug\"e\n\x15\x44\x65\x62uggerEventMetadata\x12\x0e\n\x06\x64\x65vice\x18\x01 \x01(\t\x12\x13\n\x0boutput_slot\x18\x02 \x01(\x05\x12\x12\n\nnum_chunks\x18\x03 \x01(\x05\x12\x13\n\x0b\x63hunk_index\x18\x04 \x01(\x05\x62\x06proto3')
|
| 17 |
+
|
| 18 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 19 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.debug.debugger_event_metadata_pb2', globals())
|
| 20 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 21 |
+
|
| 22 |
+
DESCRIPTOR._options = None
|
| 23 |
+
_DEBUGGEREVENTMETADATA._serialized_start=90
|
| 24 |
+
_DEBUGGEREVENTMETADATA._serialized_end=191
|
| 25 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (200 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (213 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_cache.cpython-310.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_type.cpython-310.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/function_type_pb2.cpython-310.pyc
ADDED
|
Binary file (1.98 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/__pycache__/type_dispatch.cpython-310.pyc
ADDED
|
Binary file (4.05 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_cache.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Cache to manage functions based on their FunctionType."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
from typing import Any, NamedTuple, Optional
|
| 19 |
+
|
| 20 |
+
from tensorflow.core.function.polymorphism import function_type as function_type_lib
|
| 21 |
+
from tensorflow.core.function.polymorphism import type_dispatch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FunctionContext(NamedTuple):
|
| 25 |
+
"""Contains information regarding tf.function execution context."""
|
| 26 |
+
context: Any = None
|
| 27 |
+
scope_type: Any = None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class FunctionCache:
|
| 31 |
+
"""A container for managing functions."""
|
| 32 |
+
|
| 33 |
+
__slots__ = ["_primary", "_dispatch_dict", "_garbage_collectors"]
|
| 34 |
+
|
| 35 |
+
def __init__(self):
|
| 36 |
+
# Maps (FunctionContext, FunctionType) to a function.
|
| 37 |
+
self._primary = collections.OrderedDict()
|
| 38 |
+
|
| 39 |
+
# Maps FunctionContext to a TypeDispatchTable containing FunctionTypes of
|
| 40 |
+
# that particular context.
|
| 41 |
+
self._dispatch_dict = {}
|
| 42 |
+
|
| 43 |
+
def lookup(self, function_type: function_type_lib.FunctionType,
|
| 44 |
+
context: Optional[FunctionContext] = None) -> Optional[Any]:
|
| 45 |
+
"""Looks up a function based on the context and type."""
|
| 46 |
+
context = context or FunctionContext()
|
| 47 |
+
if context in self._dispatch_dict:
|
| 48 |
+
dispatch_type = self._dispatch_dict[context].dispatch(function_type)
|
| 49 |
+
if dispatch_type:
|
| 50 |
+
return self._primary[(context, dispatch_type)]
|
| 51 |
+
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
def delete(self, function_type: function_type_lib.FunctionType,
|
| 55 |
+
context: Optional[FunctionContext] = None,
|
| 56 |
+
) -> bool:
|
| 57 |
+
"""Deletes a function given the context and type."""
|
| 58 |
+
context = context or FunctionContext()
|
| 59 |
+
if (context, function_type) not in self._primary:
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
del self._primary[(context, function_type)]
|
| 63 |
+
self._dispatch_dict[context].delete(function_type)
|
| 64 |
+
|
| 65 |
+
return True
|
| 66 |
+
|
| 67 |
+
def add(self, fn: Any, context: Optional[FunctionContext] = None) -> None:
|
| 68 |
+
"""Adds a new function using its function_type.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
fn: The function to be added to the cache.
|
| 72 |
+
context: A FunctionContext representing the current context.
|
| 73 |
+
"""
|
| 74 |
+
context = context or FunctionContext()
|
| 75 |
+
self._primary[(context, fn.function_type)] = fn
|
| 76 |
+
if context not in self._dispatch_dict:
|
| 77 |
+
self._dispatch_dict[context] = type_dispatch.TypeDispatchTable()
|
| 78 |
+
|
| 79 |
+
self._dispatch_dict[context].add_target(fn.function_type)
|
| 80 |
+
|
| 81 |
+
def generalize(
|
| 82 |
+
self, context: FunctionContext,
|
| 83 |
+
function_type: function_type_lib.FunctionType
|
| 84 |
+
) -> function_type_lib.FunctionType:
|
| 85 |
+
"""Try to generalize a FunctionType within a FunctionContext."""
|
| 86 |
+
if context in self._dispatch_dict:
|
| 87 |
+
return self._dispatch_dict[context].try_generalizing_function_type(
|
| 88 |
+
function_type)
|
| 89 |
+
else:
|
| 90 |
+
return function_type
|
| 91 |
+
|
| 92 |
+
# TODO(b/205971333): Remove this function.
|
| 93 |
+
def clear(self):
|
| 94 |
+
"""Removes all functions from the cache."""
|
| 95 |
+
self._primary.clear()
|
| 96 |
+
self._dispatch_dict.clear()
|
| 97 |
+
|
| 98 |
+
def values(self):
|
| 99 |
+
"""Returns a list of all functions held by this cache."""
|
| 100 |
+
return list(self._primary.values())
|
| 101 |
+
|
| 102 |
+
def __len__(self):
|
| 103 |
+
return len(self._primary)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_type.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Represents the types of TF functions."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import inspect
|
| 19 |
+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
|
| 20 |
+
|
| 21 |
+
from absl import logging
|
| 22 |
+
|
| 23 |
+
from tensorflow.core.function import trace_type
|
| 24 |
+
from tensorflow.core.function.polymorphism import function_type_pb2
|
| 25 |
+
from tensorflow.core.function.trace_type import serialization
|
| 26 |
+
from tensorflow.python.types import core
|
| 27 |
+
from tensorflow.python.types import trace
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Represents a defined parameter default value that is saved alongside the
|
| 31 |
+
# function's captures.
|
| 32 |
+
class CapturedDefaultValue:
|
| 33 |
+
def __repr__(self):
|
| 34 |
+
return "<captured_default_value>"
|
| 35 |
+
|
| 36 |
+
def __str__(self):
|
| 37 |
+
return "<captured_default_value>"
|
| 38 |
+
|
| 39 |
+
CAPTURED_DEFAULT_VALUE = CapturedDefaultValue()
|
| 40 |
+
|
| 41 |
+
PROTO_TO_PY_ENUM = {
|
| 42 |
+
function_type_pb2.Parameter.Kind.POSITIONAL_ONLY:
|
| 43 |
+
inspect.Parameter.POSITIONAL_ONLY,
|
| 44 |
+
function_type_pb2.Parameter.Kind.POSITIONAL_OR_KEYWORD:
|
| 45 |
+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
| 46 |
+
function_type_pb2.Parameter.Kind.VAR_POSITIONAL:
|
| 47 |
+
inspect.Parameter.VAR_POSITIONAL,
|
| 48 |
+
function_type_pb2.Parameter.Kind.KEYWORD_ONLY:
|
| 49 |
+
inspect.Parameter.KEYWORD_ONLY,
|
| 50 |
+
function_type_pb2.Parameter.Kind.VAR_KEYWORD:
|
| 51 |
+
inspect.Parameter.VAR_KEYWORD,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
PY_TO_PROTO_ENUM = {v: k for k, v in PROTO_TO_PY_ENUM.items()}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Parameter(inspect.Parameter):
|
| 58 |
+
"""Represents a parameter to a function."""
|
| 59 |
+
|
| 60 |
+
def __init__(self, name: str, kind: Any, optional: bool,
|
| 61 |
+
type_constraint: Optional[trace.TraceType]):
|
| 62 |
+
if optional and kind not in [
|
| 63 |
+
self.POSITIONAL_ONLY, self.KEYWORD_ONLY, self.POSITIONAL_OR_KEYWORD
|
| 64 |
+
]:
|
| 65 |
+
raise ValueError(
|
| 66 |
+
"Parameter " + name +
|
| 67 |
+
" is optional and its kind must be one of {POSITIONAL_ONLY, " +
|
| 68 |
+
"KEYWORD_ONLY, POSITIONAL_OR_KEYWORD}. Got: " + str(kind))
|
| 69 |
+
|
| 70 |
+
if type_constraint and kind in [self.VAR_POSITIONAL, self.VAR_KEYWORD]:
|
| 71 |
+
raise TypeError("Variable args/kwargs can not have type constraints.")
|
| 72 |
+
|
| 73 |
+
if not isinstance(type_constraint, (trace.TraceType, type(None))):
|
| 74 |
+
raise TypeError(
|
| 75 |
+
"Type constraints can only be an instance of a TraceType but got " +
|
| 76 |
+
"type_constraint=" + str(type_constraint) + " for Parameter " + name)
|
| 77 |
+
|
| 78 |
+
super().__init__(
|
| 79 |
+
name,
|
| 80 |
+
kind,
|
| 81 |
+
default=CAPTURED_DEFAULT_VALUE if optional else self.empty,
|
| 82 |
+
annotation=type_constraint
|
| 83 |
+
if type_constraint is not None else self.empty)
|
| 84 |
+
|
| 85 |
+
@classmethod
|
| 86 |
+
def from_proto(cls, proto: Any) -> "Parameter":
|
| 87 |
+
"""Generate a Parameter from the proto representation."""
|
| 88 |
+
deserialized_type_constraint = serialization.deserialize(
|
| 89 |
+
proto.type_constraint) if proto.HasField("type_constraint") else None
|
| 90 |
+
return Parameter(proto.name, PROTO_TO_PY_ENUM[proto.kind],
|
| 91 |
+
proto.is_optional, deserialized_type_constraint)
|
| 92 |
+
|
| 93 |
+
def to_proto(self) -> function_type_pb2.Parameter:
|
| 94 |
+
"""Generate a proto representation of the Parameter."""
|
| 95 |
+
serialized_type_constraint = serialization.serialize(
|
| 96 |
+
self.type_constraint) if self.type_constraint else None
|
| 97 |
+
return function_type_pb2.Parameter(
|
| 98 |
+
name=self.name,
|
| 99 |
+
kind=PY_TO_PROTO_ENUM[self.kind],
|
| 100 |
+
is_optional=self.optional,
|
| 101 |
+
type_constraint=serialized_type_constraint)
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def optional(self) -> bool:
|
| 105 |
+
"""If this parameter might not be supplied for a call."""
|
| 106 |
+
return self.default is not self.empty
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def type_constraint(self) -> Optional[trace.TraceType]:
|
| 110 |
+
"""A supertype that the parameter's type must subtype for validity."""
|
| 111 |
+
return self.annotation if self.annotation is not self.empty else None
|
| 112 |
+
|
| 113 |
+
def is_subtype_of(self, other: "Parameter") -> bool:
|
| 114 |
+
"""Returns True if self is a supertype of other Parameter."""
|
| 115 |
+
if not self.type_constraint or not other.type_constraint:
|
| 116 |
+
raise TypeError(
|
| 117 |
+
"Can not determine relationship between partially specified types.")
|
| 118 |
+
|
| 119 |
+
if ((self.name, self.kind, self.optional) !=
|
| 120 |
+
(other.name, other.kind, other.optional)):
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
return self.type_constraint.is_subtype_of(other.type_constraint)
|
| 124 |
+
|
| 125 |
+
def most_specific_common_supertype(
|
| 126 |
+
self, others: Sequence["Parameter"]) -> Optional["Parameter"]:
|
| 127 |
+
"""Returns a common supertype (if exists)."""
|
| 128 |
+
if not self.type_constraint or any(
|
| 129 |
+
not other.type_constraint for other in others):
|
| 130 |
+
raise TypeError(
|
| 131 |
+
"Can not determine relationship between partially specified types.")
|
| 132 |
+
|
| 133 |
+
for other in others:
|
| 134 |
+
if ((self.name, self.kind, self.optional) !=
|
| 135 |
+
(other.name, other.kind, other.optional)):
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
supertyped_constraint = self.type_constraint.most_specific_common_supertype(
|
| 139 |
+
[other.type_constraint for other in others])
|
| 140 |
+
if supertyped_constraint:
|
| 141 |
+
return Parameter(self.name, self.kind, self.optional,
|
| 142 |
+
supertyped_constraint)
|
| 143 |
+
else:
|
| 144 |
+
return None
|
| 145 |
+
|
| 146 |
+
def __eq__(self, other: Any) -> bool:
|
| 147 |
+
if not isinstance(other, Parameter):
|
| 148 |
+
return NotImplemented
|
| 149 |
+
|
| 150 |
+
return ((self.name, self.kind, self.optional,
|
| 151 |
+
self.type_constraint) == (other.name, other.kind, other.optional,
|
| 152 |
+
other.type_constraint))
|
| 153 |
+
|
| 154 |
+
def __hash__(self):
|
| 155 |
+
return hash((self.name, self.kind, self.optional, self.type_constraint))
|
| 156 |
+
|
| 157 |
+
def __repr__(self):
|
| 158 |
+
return ("Parameter(name=" + self.name + ", kind=" + str(self.kind) +
|
| 159 |
+
", optional=" + repr(self.optional) + ", type_constraint=" +
|
| 160 |
+
repr(self.type_constraint) + ")")
|
| 161 |
+
|
| 162 |
+
def __reduce__(self):
|
| 163 |
+
return (self.__class__, (self.name, self.kind, self.optional,
|
| 164 |
+
self.type_constraint))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class FunctionType(core.FunctionType):
|
| 168 |
+
"""Represents the type of a TensorFlow function.
|
| 169 |
+
|
| 170 |
+
FunctionType is the canonical way to represent the input/output contract of
|
| 171 |
+
all kinds of functions within the tf.function domain, including:
|
| 172 |
+
- Polymorphic Function
|
| 173 |
+
- Concrete Function
|
| 174 |
+
- Atomic Function
|
| 175 |
+
|
| 176 |
+
It provides consistent, centralized and layered logic for:
|
| 177 |
+
- Canonicalization of Python input arguments
|
| 178 |
+
- Type-based dispatch to monomorphic functions
|
| 179 |
+
- Packing/unpacking structured python values to Tensors
|
| 180 |
+
- Generation of structured placeholder values for tracing
|
| 181 |
+
|
| 182 |
+
Additionaly, it also provides:
|
| 183 |
+
- Lossless serialization
|
| 184 |
+
- Native integration with Python function signature representation
|
| 185 |
+
- Seamless migration from older representation formats
|
| 186 |
+
"""
|
| 187 |
+
|
| 188 |
+
def __init__(self,
|
| 189 |
+
parameters: Sequence[inspect.Parameter],
|
| 190 |
+
captures: Optional[collections.OrderedDict] = None,
|
| 191 |
+
**kwargs):
|
| 192 |
+
super().__init__(parameters, **kwargs)
|
| 193 |
+
self._captures = captures if captures else collections.OrderedDict()
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def parameters(self) -> Mapping[str, Any]:
|
| 197 |
+
"""Returns an ordered mapping of parameter name to specification."""
|
| 198 |
+
return super().parameters
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def captures(self) -> collections.OrderedDict:
|
| 202 |
+
"""Returns an ordered mapping of capture id to type."""
|
| 203 |
+
return self._captures
|
| 204 |
+
|
| 205 |
+
@property
|
| 206 |
+
def output(self) -> Optional[trace.TraceType]:
|
| 207 |
+
"""Return the output TraceType if specified."""
|
| 208 |
+
return (
|
| 209 |
+
self.return_annotation
|
| 210 |
+
if self.return_annotation is not self.empty
|
| 211 |
+
else None
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
@classmethod
|
| 215 |
+
def from_callable(cls,
|
| 216 |
+
obj: Callable[..., Any],
|
| 217 |
+
*,
|
| 218 |
+
follow_wrapped: bool = True) -> "FunctionType":
|
| 219 |
+
"""Generate FunctionType from a python Callable."""
|
| 220 |
+
signature = super().from_callable(obj, follow_wrapped=follow_wrapped)
|
| 221 |
+
# TODO(fmuham): Support TraceType-based annotations.
|
| 222 |
+
parameters = [
|
| 223 |
+
Parameter(p.name, p.kind, p.default is not p.empty, None)
|
| 224 |
+
for p in signature.parameters.values()
|
| 225 |
+
]
|
| 226 |
+
|
| 227 |
+
return FunctionType(parameters)
|
| 228 |
+
|
| 229 |
+
@classmethod
|
| 230 |
+
def get_default_values(cls,
|
| 231 |
+
obj: Callable[..., Any],
|
| 232 |
+
*,
|
| 233 |
+
follow_wrapped: bool = True) -> Dict[str, Any]:
|
| 234 |
+
"""Inspects and returns a dictionary of default values."""
|
| 235 |
+
signature = super().from_callable(obj, follow_wrapped=follow_wrapped)
|
| 236 |
+
default_values = {}
|
| 237 |
+
for p in signature.parameters.values():
|
| 238 |
+
if p.default is not p.empty:
|
| 239 |
+
default_values[p.name] = p.default
|
| 240 |
+
return default_values
|
| 241 |
+
|
| 242 |
+
@classmethod
|
| 243 |
+
def from_proto(cls, proto: Any) -> "FunctionType":
|
| 244 |
+
"""Generate a FunctionType from the proto representation."""
|
| 245 |
+
return FunctionType([Parameter.from_proto(p) for p in proto.parameters],
|
| 246 |
+
collections.OrderedDict([
|
| 247 |
+
(c.name,
|
| 248 |
+
serialization.deserialize(c.type_constraint))
|
| 249 |
+
for c in proto.captures
|
| 250 |
+
]))
|
| 251 |
+
|
| 252 |
+
def to_proto(self) -> Any:
|
| 253 |
+
"""Generate a proto representation from the FunctionType."""
|
| 254 |
+
return function_type_pb2.FunctionType(
|
| 255 |
+
parameters=[p.to_proto() for p in self.parameters.values()],
|
| 256 |
+
captures=[
|
| 257 |
+
function_type_pb2.Capture(
|
| 258 |
+
name=n, type_constraint=serialization.serialize(t))
|
| 259 |
+
for n, t in self.captures.items()
|
| 260 |
+
])
|
| 261 |
+
|
| 262 |
+
def bind_with_defaults(self, args, kwargs, default_values):
|
| 263 |
+
"""Returns BoundArguments with default values filled in."""
|
| 264 |
+
bound_arguments = self.bind(*args, **kwargs)
|
| 265 |
+
bound_arguments.apply_defaults()
|
| 266 |
+
|
| 267 |
+
with_default_args = collections.OrderedDict()
|
| 268 |
+
for name, value in bound_arguments.arguments.items():
|
| 269 |
+
if value is CAPTURED_DEFAULT_VALUE:
|
| 270 |
+
with_default_args[name] = default_values[name]
|
| 271 |
+
else:
|
| 272 |
+
with_default_args[name] = value
|
| 273 |
+
|
| 274 |
+
for arg_name in with_default_args:
|
| 275 |
+
constraint = self.parameters[arg_name].type_constraint
|
| 276 |
+
if constraint:
|
| 277 |
+
with_default_args[arg_name] = constraint.cast(
|
| 278 |
+
with_default_args[arg_name],
|
| 279 |
+
trace_type.InternalCastContext(allow_specs=True),
|
| 280 |
+
)
|
| 281 |
+
bound_arguments = inspect.BoundArguments(self, with_default_args)
|
| 282 |
+
return bound_arguments
|
| 283 |
+
|
| 284 |
+
def is_supertype_of(self, other: "FunctionType") -> bool:
|
| 285 |
+
"""Returns True if self is a supertype of other FunctionType."""
|
| 286 |
+
if len(self.parameters) != len(other.parameters):
|
| 287 |
+
return False
|
| 288 |
+
|
| 289 |
+
for self_param, other_param in zip(self.parameters.values(),
|
| 290 |
+
other.parameters.values()):
|
| 291 |
+
# Functions are contravariant on their parameter types.
|
| 292 |
+
if not self_param.is_subtype_of(other_param):
|
| 293 |
+
return False
|
| 294 |
+
|
| 295 |
+
# Other must have all capture names of self.
|
| 296 |
+
if not all(name in other.captures for name in self.captures):
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
# Functions are contravariant upon the capture types.
|
| 300 |
+
return all(capture_type.is_subtype_of(other.captures[name])
|
| 301 |
+
for name, capture_type in self.captures.items())
|
| 302 |
+
|
| 303 |
+
def most_specific_common_subtype(
|
| 304 |
+
self, others: Sequence["FunctionType"]) -> Optional["FunctionType"]:
|
| 305 |
+
"""Returns a common subtype (if exists)."""
|
| 306 |
+
subtyped_parameters = []
|
| 307 |
+
|
| 308 |
+
for i, parameter in enumerate(self.parameters.values()):
|
| 309 |
+
# Functions are contravariant on their parameter types.
|
| 310 |
+
subtyped_parameter = parameter.most_specific_common_supertype(
|
| 311 |
+
[list(other.parameters.values())[i] for other in others])
|
| 312 |
+
if subtyped_parameter is None:
|
| 313 |
+
return None
|
| 314 |
+
subtyped_parameters.append(subtyped_parameter)
|
| 315 |
+
|
| 316 |
+
if not all(subtyped_parameters):
|
| 317 |
+
return None
|
| 318 |
+
|
| 319 |
+
# Common subtype has superset of all captures.
|
| 320 |
+
capture_names = set(self.captures.keys())
|
| 321 |
+
for other in others:
|
| 322 |
+
capture_names = capture_names.union(other.captures.keys())
|
| 323 |
+
|
| 324 |
+
subtyped_captures = collections.OrderedDict()
|
| 325 |
+
for name in capture_names:
|
| 326 |
+
containing = [t for t in [self, *others] if name in t.captures]
|
| 327 |
+
# Pick the first type that has the capture as the base.
|
| 328 |
+
base = containing[0]
|
| 329 |
+
relevant_others = containing[1:]
|
| 330 |
+
|
| 331 |
+
# Functions are contravariant upon the capture types.
|
| 332 |
+
common_type = base.captures[name].most_specific_common_supertype(
|
| 333 |
+
[other.captures[name] for other in relevant_others]
|
| 334 |
+
)
|
| 335 |
+
if common_type is None:
|
| 336 |
+
return None
|
| 337 |
+
else:
|
| 338 |
+
subtyped_captures[name] = common_type
|
| 339 |
+
|
| 340 |
+
return FunctionType(subtyped_parameters, subtyped_captures)
|
| 341 |
+
|
| 342 |
+
def placeholder_arguments(
|
| 343 |
+
self, placeholder_context: trace.PlaceholderContext
|
| 344 |
+
) -> inspect.BoundArguments:
|
| 345 |
+
"""Returns BoundArguments of values that can be used for tracing."""
|
| 346 |
+
arguments = collections.OrderedDict()
|
| 347 |
+
for parameter in self.parameters.values():
|
| 348 |
+
if parameter.kind in {Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD}:
|
| 349 |
+
raise ValueError("Can not generate placeholder values for "
|
| 350 |
+
"variable length function type.")
|
| 351 |
+
|
| 352 |
+
if not parameter.type_constraint:
|
| 353 |
+
raise ValueError("Can not generate placeholder value for "
|
| 354 |
+
"partially defined function type.")
|
| 355 |
+
placeholder_context.update_naming_scope(parameter.name)
|
| 356 |
+
arguments[parameter.name] = parameter.type_constraint.placeholder_value(
|
| 357 |
+
placeholder_context)
|
| 358 |
+
|
| 359 |
+
return inspect.BoundArguments(self, arguments)
|
| 360 |
+
|
| 361 |
+
@property
|
| 362 |
+
def flat_inputs(self) -> List[trace.TraceType]:
|
| 363 |
+
"""Flat tensor inputs accepted by this FunctionType."""
|
| 364 |
+
if not hasattr(self, "_cached_flat_inputs"):
|
| 365 |
+
cached_flat_inputs = []
|
| 366 |
+
for p in self.parameters.values():
|
| 367 |
+
cached_flat_inputs.extend(p.type_constraint.flatten())
|
| 368 |
+
self._cached_flat_inputs = cached_flat_inputs
|
| 369 |
+
|
| 370 |
+
return self._cached_flat_inputs
|
| 371 |
+
|
| 372 |
+
def unpack_inputs(
|
| 373 |
+
self, bound_parameters: inspect.BoundArguments
|
| 374 |
+
) -> List[core.Tensor]:
|
| 375 |
+
"""Unpacks python arguments to flat tensor inputs accepted by this type."""
|
| 376 |
+
# Sort keyword-only parameters by name.
|
| 377 |
+
sorted_parameters = []
|
| 378 |
+
kwonly_parameters = []
|
| 379 |
+
for p in self.parameters.values():
|
| 380 |
+
if p.kind is Parameter.KEYWORD_ONLY:
|
| 381 |
+
kwonly_parameters.append(p)
|
| 382 |
+
else:
|
| 383 |
+
sorted_parameters.append(p)
|
| 384 |
+
sorted_parameters = sorted_parameters + sorted(
|
| 385 |
+
kwonly_parameters, key=lambda p: p.name
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
flat = []
|
| 389 |
+
for p in sorted_parameters:
|
| 390 |
+
flat.extend(
|
| 391 |
+
p.type_constraint.to_tensors(bound_parameters.arguments[p.name])
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
dealiased_inputs = []
|
| 395 |
+
ids_used = set()
|
| 396 |
+
for tensor, input_type in zip(flat, self.flat_inputs):
|
| 397 |
+
alias_id = input_type._alias_id() # pylint: disable=protected-access
|
| 398 |
+
if alias_id is None or alias_id not in ids_used:
|
| 399 |
+
dealiased_inputs.append(tensor)
|
| 400 |
+
|
| 401 |
+
if alias_id is not None:
|
| 402 |
+
ids_used.add(alias_id)
|
| 403 |
+
|
| 404 |
+
return dealiased_inputs
|
| 405 |
+
|
| 406 |
+
@property
|
| 407 |
+
def flat_captures(self) -> List[trace.TraceType]:
|
| 408 |
+
"""Flat tensor captures needed by this FunctionType."""
|
| 409 |
+
if not hasattr(self, "_cached_flat_captures"):
|
| 410 |
+
cached_flat_captures = []
|
| 411 |
+
for t in self.captures.values():
|
| 412 |
+
cached_flat_captures.extend(t.flatten())
|
| 413 |
+
self._cached_flat_captures = cached_flat_captures
|
| 414 |
+
|
| 415 |
+
return self._cached_flat_captures
|
| 416 |
+
|
| 417 |
+
def unpack_captures(self, captures) -> List[core.Tensor]:
|
| 418 |
+
"""Unpacks captures to flat tensors."""
|
| 419 |
+
flat = []
|
| 420 |
+
for v, t in zip(captures, self.captures.values()):
|
| 421 |
+
flat.extend(t.to_tensors(v))
|
| 422 |
+
if len(flat) != len(self.flat_captures):
|
| 423 |
+
raise TypeError(
|
| 424 |
+
f"Flattening captures {captures} with type {self!r} produced"
|
| 425 |
+
f" {len(flat)} tensors instead of {len(self.flat_captures)}"
|
| 426 |
+
)
|
| 427 |
+
return flat
|
| 428 |
+
|
| 429 |
+
@property
|
| 430 |
+
def flat_outputs(self) -> List[trace.TraceType]:
|
| 431 |
+
"""Flat tensor outputs returned by this FunctionType."""
|
| 432 |
+
if not hasattr(self, "_cached_flat_outputs"):
|
| 433 |
+
if self.output is not None:
|
| 434 |
+
self._cached_flat_outputs = self.output.flatten()
|
| 435 |
+
|
| 436 |
+
return self._cached_flat_outputs
|
| 437 |
+
|
| 438 |
+
def pack_output(self, flat_values: Sequence[core.Tensor]) -> Any:
|
| 439 |
+
"""Packs flat tensors to generate a value of the output type."""
|
| 440 |
+
if flat_values is None:
|
| 441 |
+
flat_values = []
|
| 442 |
+
|
| 443 |
+
if self.output is None:
|
| 444 |
+
raise ValueError("Can not pack outputs for undefined output type.")
|
| 445 |
+
else:
|
| 446 |
+
return self.output.from_tensors(iter(flat_values))
|
| 447 |
+
|
| 448 |
+
def __eq__(self, other: Any) -> bool:
|
| 449 |
+
if not isinstance(other, FunctionType):
|
| 450 |
+
return NotImplemented
|
| 451 |
+
|
| 452 |
+
return (self.parameters, self.captures) == (other.parameters,
|
| 453 |
+
other.captures)
|
| 454 |
+
|
| 455 |
+
def __hash__(self) -> int:
|
| 456 |
+
return hash((tuple(self.parameters.items()), tuple(self.captures.items())))
|
| 457 |
+
|
| 458 |
+
def __repr__(self):
|
| 459 |
+
if hasattr(self, "_cached_repr"):
|
| 460 |
+
return self._cached_repr
|
| 461 |
+
|
| 462 |
+
lines = ["Input Parameters:"]
|
| 463 |
+
for parameter in self.parameters.values():
|
| 464 |
+
lines.append(
|
| 465 |
+
f" {parameter.name} ({parameter.kind}): {parameter.type_constraint}"
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
lines.append("Output Type:")
|
| 469 |
+
lines.append(f" {self.output}")
|
| 470 |
+
|
| 471 |
+
lines.append("Captures:")
|
| 472 |
+
if self.captures:
|
| 473 |
+
for capture_id, capture_type in self.captures.items():
|
| 474 |
+
lines.append(f" {capture_id}: {capture_type}")
|
| 475 |
+
else:
|
| 476 |
+
lines.append(" None")
|
| 477 |
+
|
| 478 |
+
self._cached_repr = "\n".join(lines)
|
| 479 |
+
return self._cached_repr
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
MAX_SANITIZATION_WARNINGS = 5
|
| 483 |
+
sanitization_warnings_given = 0
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
# TODO(fmuham): In future, replace warning with exception.
|
| 487 |
+
# TODO(fmuham): Sanitize to graph node conventions.
|
| 488 |
+
def sanitize_arg_name(name: str) -> str:
|
| 489 |
+
"""Sanitizes function argument names.
|
| 490 |
+
|
| 491 |
+
Matches Python symbol naming rules.
|
| 492 |
+
|
| 493 |
+
Without sanitization, names that are not legal Python parameter names can be
|
| 494 |
+
set which makes it challenging to represent callables supporting the named
|
| 495 |
+
calling capability.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
name: The name to sanitize.
|
| 499 |
+
|
| 500 |
+
Returns:
|
| 501 |
+
A string that meets Python parameter conventions.
|
| 502 |
+
"""
|
| 503 |
+
# Replace non-alphanumeric chars with '_'
|
| 504 |
+
swapped = "".join([c if c.isalnum() else "_" for c in name])
|
| 505 |
+
result = swapped if swapped[0].isalpha() else "arg_" + swapped
|
| 506 |
+
|
| 507 |
+
global sanitization_warnings_given
|
| 508 |
+
if name != result and sanitization_warnings_given < MAX_SANITIZATION_WARNINGS:
|
| 509 |
+
logging.warning(
|
| 510 |
+
"`%s` is not a valid tf.function parameter name. Sanitizing to `%s`.",
|
| 511 |
+
name, result)
|
| 512 |
+
sanitization_warnings_given += 1
|
| 513 |
+
|
| 514 |
+
return result
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# TODO(fmuham): Consider forcing kind to be always POSITIONAL_OR_KEYWORD.
|
| 518 |
+
def _make_validated_mono_param(
|
| 519 |
+
name, value, kind, type_context, poly_type
|
| 520 |
+
) -> Parameter:
|
| 521 |
+
"""Generates and validates a parameter for Monomorphic FunctionType."""
|
| 522 |
+
mono_type = trace_type.from_value(value, type_context)
|
| 523 |
+
|
| 524 |
+
if poly_type and not mono_type.is_subtype_of(poly_type):
|
| 525 |
+
raise TypeError(f"Parameter `{name}` was expected to be of type "
|
| 526 |
+
f"{poly_type} but is {mono_type}")
|
| 527 |
+
|
| 528 |
+
return Parameter(name, kind, False, mono_type)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def canonicalize_to_monomorphic(
|
| 532 |
+
args: Tuple[Any, ...], kwargs: Dict[Any, Any], default_values: Dict[Any,
|
| 533 |
+
Any],
|
| 534 |
+
capture_types: collections.OrderedDict, polymorphic_type: FunctionType
|
| 535 |
+
) -> Tuple[FunctionType, trace_type.InternalTracingContext]:
|
| 536 |
+
"""Generates a monomorphic type out of polymorphic type for given args."""
|
| 537 |
+
poly_bound_arguments = polymorphic_type.bind(*args, **kwargs)
|
| 538 |
+
|
| 539 |
+
# Inject Default Values.
|
| 540 |
+
if default_values:
|
| 541 |
+
poly_bound_arguments.apply_defaults()
|
| 542 |
+
default_values_injected = poly_bound_arguments.arguments
|
| 543 |
+
for name, value in default_values_injected.items():
|
| 544 |
+
if value is CAPTURED_DEFAULT_VALUE:
|
| 545 |
+
default_values_injected[name] = default_values[name]
|
| 546 |
+
poly_bound_arguments = inspect.BoundArguments(
|
| 547 |
+
poly_bound_arguments.signature, default_values_injected
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
parameters = []
|
| 551 |
+
type_context = trace_type.InternalTracingContext()
|
| 552 |
+
has_var_positional = any(p.kind is Parameter.VAR_POSITIONAL
|
| 553 |
+
for p in polymorphic_type.parameters.values())
|
| 554 |
+
|
| 555 |
+
for name, arg in poly_bound_arguments.arguments.items():
|
| 556 |
+
poly_parameter = polymorphic_type.parameters[name]
|
| 557 |
+
if (has_var_positional and
|
| 558 |
+
poly_parameter.kind is Parameter.POSITIONAL_OR_KEYWORD):
|
| 559 |
+
# If there is a VAR_POSITIONAL, all POSITIONAL_OR_KEYWORD become
|
| 560 |
+
# POSITIONAL_ONLY.
|
| 561 |
+
parameters.append(
|
| 562 |
+
_make_validated_mono_param(name, arg, Parameter.POSITIONAL_ONLY,
|
| 563 |
+
type_context,
|
| 564 |
+
poly_parameter.type_constraint))
|
| 565 |
+
|
| 566 |
+
elif poly_parameter.kind is Parameter.VAR_POSITIONAL:
|
| 567 |
+
# Unbundle VAR_POSITIONAL into individual POSITIONAL_ONLY args.
|
| 568 |
+
for i, value in enumerate(arg):
|
| 569 |
+
parameters.append(
|
| 570 |
+
_make_validated_mono_param(f"{poly_parameter.name}_{i}", value,
|
| 571 |
+
Parameter.POSITIONAL_ONLY, type_context,
|
| 572 |
+
poly_parameter.type_constraint))
|
| 573 |
+
|
| 574 |
+
elif poly_parameter.kind is Parameter.VAR_KEYWORD:
|
| 575 |
+
# Unbundle VAR_KEYWORD into individual KEYWORD_ONLY args.
|
| 576 |
+
for kwarg_name in sorted(arg.keys()):
|
| 577 |
+
parameters.append(
|
| 578 |
+
_make_validated_mono_param(kwarg_name, arg[kwarg_name],
|
| 579 |
+
Parameter.KEYWORD_ONLY, type_context,
|
| 580 |
+
poly_parameter.type_constraint))
|
| 581 |
+
else:
|
| 582 |
+
parameters.append(
|
| 583 |
+
_make_validated_mono_param(name, arg, poly_parameter.kind,
|
| 584 |
+
type_context,
|
| 585 |
+
poly_parameter.type_constraint))
|
| 586 |
+
|
| 587 |
+
return FunctionType(parameters, capture_types), type_context
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# TODO(fmuham): Share code with canonicalize_to_monomorphic.
|
| 591 |
+
# TODO(fmuham): Lift unnecessary restrictions on input_signature validity.
|
| 592 |
+
def add_type_constraints(function_type: FunctionType, input_signature: Any,
|
| 593 |
+
default_values: Dict[str, Any]) -> FunctionType:
|
| 594 |
+
"""Adds type constraints to a FunctionType based on the input_signature."""
|
| 595 |
+
context = trace_type.InternalTracingContext(is_legacy_signature=True)
|
| 596 |
+
constraints = [trace_type.from_value(c, context) for c in input_signature]
|
| 597 |
+
parameters = []
|
| 598 |
+
|
| 599 |
+
has_var_pos = any(
|
| 600 |
+
p.kind is p.VAR_POSITIONAL for p in function_type.parameters.values())
|
| 601 |
+
|
| 602 |
+
for param in function_type.parameters.values():
|
| 603 |
+
# VAR_POSITIONAL does not allow POSITIONAL_OR_KEYWORD args.
|
| 604 |
+
sanitized_kind = (
|
| 605 |
+
param.POSITIONAL_ONLY if has_var_pos and
|
| 606 |
+
param.kind is param.POSITIONAL_OR_KEYWORD else param.kind)
|
| 607 |
+
|
| 608 |
+
if param.name == "self":
|
| 609 |
+
# Type constraints do not apply on them.
|
| 610 |
+
parameters.append(Parameter("self", sanitized_kind, param.optional, None))
|
| 611 |
+
|
| 612 |
+
elif param.kind is param.VAR_KEYWORD:
|
| 613 |
+
# Disabled when input_signature is specified.
|
| 614 |
+
continue
|
| 615 |
+
|
| 616 |
+
elif param.kind is param.VAR_POSITIONAL:
|
| 617 |
+
# Convert into Positional Only args based on length of constraints.
|
| 618 |
+
for i in range(len(constraints)):
|
| 619 |
+
parameters.append(
|
| 620 |
+
Parameter(param.name + "_" + str(i), Parameter.POSITIONAL_ONLY,
|
| 621 |
+
False, constraints.pop(0)))
|
| 622 |
+
|
| 623 |
+
elif (param.kind in [
|
| 624 |
+
param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY
|
| 625 |
+
]):
|
| 626 |
+
if param.kind is param.KEYWORD_ONLY and param.name not in default_values:
|
| 627 |
+
raise TypeError(
|
| 628 |
+
"Since input_signature is defined, keyword-only parameter"
|
| 629 |
+
f" `{param.name}` must have a default value"
|
| 630 |
+
)
|
| 631 |
+
|
| 632 |
+
if constraints:
|
| 633 |
+
parameters.append(
|
| 634 |
+
Parameter(param.name, sanitized_kind, param.optional,
|
| 635 |
+
constraints.pop(0)))
|
| 636 |
+
elif param.name in default_values:
|
| 637 |
+
type_constraint = trace_type.from_value(default_values[param.name])
|
| 638 |
+
parameters.append(
|
| 639 |
+
Parameter(param.name, sanitized_kind, param.optional,
|
| 640 |
+
type_constraint))
|
| 641 |
+
else:
|
| 642 |
+
raise TypeError(
|
| 643 |
+
f"input_signature missing type constraint for {param.name}")
|
| 644 |
+
|
| 645 |
+
if constraints:
|
| 646 |
+
raise TypeError(
|
| 647 |
+
f"input_signature contains {len(constraints)} extra type constraints.")
|
| 648 |
+
|
| 649 |
+
return FunctionType(parameters)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def from_structured_signature(
|
| 653 |
+
input_signature=None, output_signature=None, capture_types=None
|
| 654 |
+
) -> FunctionType:
|
| 655 |
+
"""Generates a FunctionType from legacy signature representation."""
|
| 656 |
+
if input_signature is None:
|
| 657 |
+
input_signature = ((), {})
|
| 658 |
+
|
| 659 |
+
args, kwargs = input_signature
|
| 660 |
+
parameters = []
|
| 661 |
+
|
| 662 |
+
for i, arg in enumerate(args):
|
| 663 |
+
parameters.append(
|
| 664 |
+
Parameter(
|
| 665 |
+
"arg_" + str(i),
|
| 666 |
+
Parameter.POSITIONAL_ONLY,
|
| 667 |
+
False,
|
| 668 |
+
trace_type.from_value(
|
| 669 |
+
arg, trace_type.InternalTracingContext(is_legacy_signature=True)
|
| 670 |
+
),
|
| 671 |
+
)
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
for name, kwarg in kwargs.items():
|
| 675 |
+
parameters.append(
|
| 676 |
+
Parameter(
|
| 677 |
+
sanitize_arg_name(name),
|
| 678 |
+
Parameter.KEYWORD_ONLY,
|
| 679 |
+
False,
|
| 680 |
+
trace_type.from_value(
|
| 681 |
+
kwarg,
|
| 682 |
+
trace_type.InternalTracingContext(is_legacy_signature=True),
|
| 683 |
+
),
|
| 684 |
+
)
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
return_type = trace_type.from_value(
|
| 688 |
+
output_signature,
|
| 689 |
+
trace_type.InternalTracingContext(is_legacy_signature=True),
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
return FunctionType(
|
| 693 |
+
parameters, capture_types or {}, return_annotation=return_type
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
|
| 697 |
+
def to_structured_signature(function_type: FunctionType) -> Tuple[Any, Any]:
|
| 698 |
+
"""Returns structured input and output signatures from a FunctionType."""
|
| 699 |
+
def to_signature(x_type):
|
| 700 |
+
if x_type is None:
|
| 701 |
+
raise TypeError(
|
| 702 |
+
"Can not generate structured signature if FunctionType is not fully"
|
| 703 |
+
f" specified. Received {function_type}"
|
| 704 |
+
)
|
| 705 |
+
return x_type.placeholder_value(
|
| 706 |
+
trace_type.InternalPlaceholderContext(unnest_only=True)
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
args_signature = []
|
| 710 |
+
kwargs_signature = {}
|
| 711 |
+
for p in function_type.parameters.values():
|
| 712 |
+
if p.kind == Parameter.POSITIONAL_ONLY:
|
| 713 |
+
args_signature.append(to_signature(p.type_constraint))
|
| 714 |
+
else:
|
| 715 |
+
kwargs_signature[p.name] = to_signature(p.type_constraint)
|
| 716 |
+
|
| 717 |
+
input_signature = (tuple(args_signature), kwargs_signature)
|
| 718 |
+
output_signature = to_signature(function_type.output)
|
| 719 |
+
|
| 720 |
+
return input_signature, output_signature
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/function_type_pb2.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/function/polymorphism/function_type.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from tensorflow.core.function.trace_type import serialization_pb2 as tensorflow_dot_core_dot_function_dot_trace__type_dot_serialization__pb2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n9tensorflow/core/function/polymorphism/function_type.proto\x12\x33tensorflow.core.function.polymorphism.function_type\x1a\x37tensorflow/core/function/trace_type/serialization.proto\"\xe0\x02\n\tParameter\x12\x0c\n\x04name\x18\x01 \x01(\t\x12Q\n\x04kind\x18\x02 \x01(\x0e\x32\x43.tensorflow.core.function.polymorphism.function_type.Parameter.Kind\x12\x13\n\x0bis_optional\x18\x03 \x01(\x08\x12_\n\x0ftype_constraint\x18\x04 \x01(\x0b\x32\x46.tensorflow.core.function.trace_type.serialization.SerializedTraceType\"|\n\x04Kind\x12\r\n\tUNDEFINED\x10\x00\x12\x13\n\x0fPOSITIONAL_ONLY\x10\x01\x12\x19\n\x15POSITIONAL_OR_KEYWORD\x10\x02\x12\x12\n\x0eVAR_POSITIONAL\x10\x03\x12\x10\n\x0cKEYWORD_ONLY\x10\x04\x12\x0f\n\x0bVAR_KEYWORD\x10\x05\"x\n\x07\x43\x61pture\x12\x0c\n\x04name\x18\x01 \x01(\t\x12_\n\x0ftype_constraint\x18\x02 \x01(\x0b\x32\x46.tensorflow.core.function.trace_type.serialization.SerializedTraceType\"\xb2\x01\n\x0c\x46unctionType\x12R\n\nparameters\x18\x01 \x03(\x0b\x32>.tensorflow.core.function.polymorphism.function_type.Parameter\x12N\n\x08\x63\x61ptures\x18\x02 \x03(\x0b\x32<.tensorflow.core.function.polymorphism.function_type.Capture')
|
| 18 |
+
|
| 19 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 20 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.function.polymorphism.function_type_pb2', globals())
|
| 21 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 22 |
+
|
| 23 |
+
DESCRIPTOR._options = None
|
| 24 |
+
_PARAMETER._serialized_start=172
|
| 25 |
+
_PARAMETER._serialized_end=524
|
| 26 |
+
_PARAMETER_KIND._serialized_start=400
|
| 27 |
+
_PARAMETER_KIND._serialized_end=524
|
| 28 |
+
_CAPTURE._serialized_start=526
|
| 29 |
+
_CAPTURE._serialized_end=646
|
| 30 |
+
_FUNCTIONTYPE._serialized_start=649
|
| 31 |
+
_FUNCTIONTYPE._serialized_end=827
|
| 32 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/polymorphism/type_dispatch.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Polymorphic Type Dispatch."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
from typing import Optional, Iterable
|
| 19 |
+
|
| 20 |
+
from tensorflow.core.function.polymorphism import function_type
|
| 21 |
+
|
| 22 |
+
# The maximum number of dispatch lookups to cache.
|
| 23 |
+
_MAX_DISPATCH_CACHE = 1024
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TypeDispatchTable:
|
| 27 |
+
"""Type dispatch table implementation.
|
| 28 |
+
|
| 29 |
+
A type dispatch table is a list, L, of target types. Given a request type, R,
|
| 30 |
+
the table selects a target type, T, according to the following dispatch rules:
|
| 31 |
+
1. R == T or R is supertype of T (functions are contravariant on args)
|
| 32 |
+
2. There does not exist O in L such that R is supertype of O and O is a
|
| 33 |
+
supertype of T (in other words, T is the closest to R, within list L).
|
| 34 |
+
3. If the above two rules are satisfied by multiple targets, the earliest
|
| 35 |
+
inserted one is chosen.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
"""Creates a TypeDispatchTable object."""
|
| 40 |
+
# Holds all inserted types as keys mapping to None.
|
| 41 |
+
# (Using OrderedDict as a set for determinism)
|
| 42 |
+
self._dispatch_table = collections.OrderedDict()
|
| 43 |
+
|
| 44 |
+
# LRU cache for dispatch results.
|
| 45 |
+
# Maps request types to target types (see class description).
|
| 46 |
+
# Does not contain exact matches, i.e, if cache[a] is b then a is not b.
|
| 47 |
+
self._dispatch_cache = collections.OrderedDict()
|
| 48 |
+
|
| 49 |
+
def add_target(self, target: function_type.FunctionType) -> None:
|
| 50 |
+
"""Adds a new target type."""
|
| 51 |
+
self._dispatch_table[target] = None
|
| 52 |
+
for request in self._dispatch_cache:
|
| 53 |
+
if target.is_supertype_of(self._dispatch_cache[request]):
|
| 54 |
+
self._dispatch_cache[request] = target
|
| 55 |
+
|
| 56 |
+
@property
|
| 57 |
+
def targets(self) -> Iterable[function_type.FunctionType]:
|
| 58 |
+
"""Returns an iterable to all targets in the table."""
|
| 59 |
+
return self._dispatch_table.keys()
|
| 60 |
+
|
| 61 |
+
def delete(self, target: function_type.FunctionType) -> None:
|
| 62 |
+
"""Deletes a target in the table if it exists."""
|
| 63 |
+
if target in self._dispatch_table:
|
| 64 |
+
del self._dispatch_table[target]
|
| 65 |
+
for request in list(self._dispatch_cache.keys()):
|
| 66 |
+
if self._dispatch_cache[request] == target:
|
| 67 |
+
del self._dispatch_cache[request]
|
| 68 |
+
|
| 69 |
+
# TODO(b/205971333): remove once FunctionCache 'clear' is removed.
|
| 70 |
+
def clear(self) -> None:
|
| 71 |
+
"""Deletes all targets in the table."""
|
| 72 |
+
self._dispatch_table.clear()
|
| 73 |
+
self._dispatch_cache.clear()
|
| 74 |
+
|
| 75 |
+
def dispatch(
|
| 76 |
+
self, request: function_type.FunctionType
|
| 77 |
+
) -> Optional[function_type.FunctionType]:
|
| 78 |
+
"""Returns the most specific supertype target if it exists in the table."""
|
| 79 |
+
# For known exact matches.
|
| 80 |
+
if request in self._dispatch_table:
|
| 81 |
+
return request
|
| 82 |
+
|
| 83 |
+
# For known non-exact matches.
|
| 84 |
+
# (self._dispatch cache does not contain exact matches)
|
| 85 |
+
if request in self._dispatch_cache:
|
| 86 |
+
# Move to the front of LRU cache.
|
| 87 |
+
result = self._dispatch_cache.pop(request)
|
| 88 |
+
self._dispatch_cache[request] = result
|
| 89 |
+
return result
|
| 90 |
+
|
| 91 |
+
most_specific_supertype = None
|
| 92 |
+
for other in self._dispatch_table:
|
| 93 |
+
if request.is_supertype_of(other):
|
| 94 |
+
if most_specific_supertype is None or other.is_supertype_of(
|
| 95 |
+
most_specific_supertype):
|
| 96 |
+
most_specific_supertype = other
|
| 97 |
+
|
| 98 |
+
self._cache_dispatch(request, most_specific_supertype)
|
| 99 |
+
return most_specific_supertype
|
| 100 |
+
|
| 101 |
+
def _cache_dispatch(self, request, target):
|
| 102 |
+
"""Caches the dispatch lookup result for a target."""
|
| 103 |
+
if target is not None:
|
| 104 |
+
# LRU Cache removes oldest item
|
| 105 |
+
if len(self._dispatch_cache) > _MAX_DISPATCH_CACHE:
|
| 106 |
+
self._dispatch_cache.popitem(last=False)
|
| 107 |
+
self._dispatch_cache[request] = target
|
| 108 |
+
|
| 109 |
+
def try_generalizing_function_type(
|
| 110 |
+
self, target: function_type.FunctionType) -> function_type.FunctionType:
|
| 111 |
+
"""Returns a generalized subtype of the one given.
|
| 112 |
+
|
| 113 |
+
This heuristic aims to reduce the number of future traces by computing a
|
| 114 |
+
type that represents more general function inputs.
|
| 115 |
+
|
| 116 |
+
The original "experimental_relax_shapes" heuristic identified a known type
|
| 117 |
+
which shared a common subtype with the current unknown type and then
|
| 118 |
+
traced with that common subtype. However, the notion of "common subtype"
|
| 119 |
+
was only limited to shapes. This heuristic extends that to FunctionType.
|
| 120 |
+
|
| 121 |
+
Returns `target` if a generalized subtype can not be found.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
target: The FunctionType to generalize
|
| 125 |
+
"""
|
| 126 |
+
relaxed = target
|
| 127 |
+
for other in self._dispatch_table:
|
| 128 |
+
subtype = relaxed.most_specific_common_subtype([other])
|
| 129 |
+
if subtype is not None:
|
| 130 |
+
relaxed = subtype
|
| 131 |
+
return relaxed
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__init__.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Trace-time type system for tf.function (TraceType).
|
| 16 |
+
|
| 17 |
+
Trace-time types describe things like tf.function signatures and type
|
| 18 |
+
constraints in some ops.
|
| 19 |
+
|
| 20 |
+
This module provides utilities and concrete tf.types.experimental.TraceType
|
| 21 |
+
definitions for common Python types like containers, along with a generic
|
| 22 |
+
implementation for Python objects.
|
| 23 |
+
See also: tf.types.experimental.TraceType
|
| 24 |
+
|
| 25 |
+
Other implementations of TraceType include tf.TypeSpec and its subclasses.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from tensorflow.core.function.trace_type.default_types import register_tensor_type
|
| 29 |
+
from tensorflow.core.function.trace_type.default_types import Weakref
|
| 30 |
+
from tensorflow.core.function.trace_type.serialization import deserialize
|
| 31 |
+
from tensorflow.core.function.trace_type.serialization import register_serializable
|
| 32 |
+
from tensorflow.core.function.trace_type.serialization import Serializable
|
| 33 |
+
from tensorflow.core.function.trace_type.serialization import serialize
|
| 34 |
+
from tensorflow.core.function.trace_type.serialization import SerializedTraceType
|
| 35 |
+
from tensorflow.core.function.trace_type.trace_type_builder import from_value
|
| 36 |
+
from tensorflow.core.function.trace_type.trace_type_builder import InternalCastContext
|
| 37 |
+
from tensorflow.core.function.trace_type.trace_type_builder import InternalPlaceholderContext
|
| 38 |
+
from tensorflow.core.function.trace_type.trace_type_builder import InternalTracingContext
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/serialization_test_pb2.cpython-310.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/trace_type_builder.cpython-310.pyc
ADDED
|
Binary file (7.85 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/__pycache__/util.cpython-310.pyc
ADDED
|
Binary file (1.33 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/custom_nest_trace_type.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""TraceType implementations for classes thatimplement the CustomNestProtocol."""
|
| 16 |
+
|
| 17 |
+
from typing import Any, Iterator, List as PythonList, Optional, Sequence, Tuple as PythonTuple, Type
|
| 18 |
+
|
| 19 |
+
from tensorflow.core.function.trace_type import util
|
| 20 |
+
from tensorflow.python.types import trace
|
| 21 |
+
from tensorflow.python.util import custom_nest_protocol
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class CustomNestTraceType(trace.TraceType):
|
| 25 |
+
"""Represents the TraceType of a class implmenting the CustomNestProtocol."""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
value_type: Type[Any],
|
| 30 |
+
metadata: Any,
|
| 31 |
+
components: PythonTuple[trace.TraceType],
|
| 32 |
+
):
|
| 33 |
+
if not issubclass(value_type, custom_nest_protocol.CustomNestProtocol):
|
| 34 |
+
raise ValueError(f"{value_type!r} does not implement CustomNestProtocol.")
|
| 35 |
+
self.value_type = value_type
|
| 36 |
+
self.metadata = metadata
|
| 37 |
+
self.components = components
|
| 38 |
+
|
| 39 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 40 |
+
if not self._is_same_trace_type(other):
|
| 41 |
+
return False
|
| 42 |
+
for c_self, c_other in zip(self.components, other.components): # pytype: disable=attribute-error
|
| 43 |
+
if not c_self.is_subtype_of(c_other):
|
| 44 |
+
return False
|
| 45 |
+
return True
|
| 46 |
+
|
| 47 |
+
def most_specific_common_supertype(
|
| 48 |
+
self, others: Sequence[trace.TraceType]
|
| 49 |
+
) -> Optional["CustomNestTraceType"]:
|
| 50 |
+
for other in others:
|
| 51 |
+
if not self._is_same_trace_type(other):
|
| 52 |
+
return None
|
| 53 |
+
|
| 54 |
+
others_components = [other.components for other in others] # pytype: disable=attribute-error
|
| 55 |
+
supertyped_components = tuple(
|
| 56 |
+
self_component.most_specific_common_supertype(others_component)
|
| 57 |
+
for self_component, *others_component in zip(
|
| 58 |
+
self.components, *others_components
|
| 59 |
+
)
|
| 60 |
+
)
|
| 61 |
+
return CustomNestTraceType(
|
| 62 |
+
self.value_type, self.metadata, supertyped_components
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def __eq__(self, other: trace.TraceType) -> bool:
|
| 66 |
+
return (
|
| 67 |
+
isinstance(other, CustomNestTraceType)
|
| 68 |
+
and self.value_type == other.value_type
|
| 69 |
+
and self.metadata == other.metadata
|
| 70 |
+
and self.components == other.components
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
def __hash__(self) -> int:
|
| 74 |
+
# The hash computation doesn't use self.metadata, so unhashable metadata can
|
| 75 |
+
# be used. The `self.__eq__` method is used instead to differentiate between
|
| 76 |
+
# two objects with the same components but different metadata.
|
| 77 |
+
return hash((self.value_type, self.components))
|
| 78 |
+
|
| 79 |
+
def __repr__(self) -> str:
|
| 80 |
+
return (
|
| 81 |
+
f"{self.__class__.__name__} [metadata={self.metadata!r}, "
|
| 82 |
+
f"components={self.components!r}]"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
def placeholder_value(self, placeholder_context: Any) -> Any:
|
| 86 |
+
components_placeholder_value = tuple(
|
| 87 |
+
c.placeholder_value(placeholder_context) for c in self.components
|
| 88 |
+
)
|
| 89 |
+
return self.value_type.__tf_unflatten__(
|
| 90 |
+
self.metadata, components_placeholder_value
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def to_tensors(self, value: Any) -> PythonList[Any]:
|
| 94 |
+
if not isinstance(value, self.value_type):
|
| 95 |
+
raise TypeError(f"{value!r} is not of type {self.value_type}.")
|
| 96 |
+
_, value_components = value.__tf_flatten__()
|
| 97 |
+
flattened_values = []
|
| 98 |
+
for value_comp, type_comp in zip(value_components, self.components):
|
| 99 |
+
flattened_values.extend(type_comp.to_tensors(value_comp))
|
| 100 |
+
return flattened_values
|
| 101 |
+
|
| 102 |
+
def from_tensors(self, tensors: Iterator[Any]) -> Any:
|
| 103 |
+
return self.value_type.__tf_unflatten__(
|
| 104 |
+
self.metadata, tuple(c.from_tensors(tensors) for c in self.components)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 108 |
+
flat_list = []
|
| 109 |
+
for c in self.components:
|
| 110 |
+
flat_list.extend(c.flatten())
|
| 111 |
+
return flat_list
|
| 112 |
+
|
| 113 |
+
def cast(self, value: Any, casting_context: Any) -> Any:
|
| 114 |
+
if not isinstance(value, self.value_type):
|
| 115 |
+
raise TypeError(f"[{value!r}] is not of type {self.value_type}.")
|
| 116 |
+
value_metadata, value_components = value.__tf_flatten__()
|
| 117 |
+
if self.metadata != value_metadata:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
f"Metadata mismatch: [{self.metadata!r}] != [{value_metadata!r}]."
|
| 120 |
+
)
|
| 121 |
+
if len(self.components) != len(value_components):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
f"Lengths of components mismatch: {len(self.components)} != "
|
| 124 |
+
f"{len(value_components)}."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
casted_value_components, was_casted = util.cast_and_return_whether_casted(
|
| 128 |
+
self.components, value_components, casting_context
|
| 129 |
+
)
|
| 130 |
+
if was_casted:
|
| 131 |
+
return self.value_type.__tf_unflatten__(
|
| 132 |
+
self.metadata, casted_value_components
|
| 133 |
+
)
|
| 134 |
+
else:
|
| 135 |
+
return value
|
| 136 |
+
|
| 137 |
+
def _is_same_trace_type(self, other: trace.TraceType) -> bool:
|
| 138 |
+
return (
|
| 139 |
+
isinstance(other, CustomNestTraceType)
|
| 140 |
+
and self.value_type == other.value_type
|
| 141 |
+
and self.metadata == other.metadata
|
| 142 |
+
and len(self.components) == len(other.components)
|
| 143 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/default_types.py
ADDED
|
@@ -0,0 +1,826 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""TraceType implementations for common Python types."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import math
|
| 19 |
+
import numbers
|
| 20 |
+
from typing import Any, Dict as PythonDict, Hashable, List as PythonList, Optional, Sequence, Tuple as PythonTuple, Type
|
| 21 |
+
import weakref
|
| 22 |
+
|
| 23 |
+
from tensorflow.core.function.trace_type import default_types_pb2
|
| 24 |
+
from tensorflow.core.function.trace_type import serialization
|
| 25 |
+
from tensorflow.core.function.trace_type import util
|
| 26 |
+
from tensorflow.python.types import trace
|
| 27 |
+
|
| 28 |
+
# Register the TraceType of Tensor (aka TensorSpec) to avoid cyclic dependency.
|
| 29 |
+
TENSOR = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def register_tensor_type(tensor_type):
|
| 33 |
+
global TENSOR
|
| 34 |
+
if not TENSOR:
|
| 35 |
+
TENSOR = tensor_type
|
| 36 |
+
else:
|
| 37 |
+
raise AssertionError("Tensor type is already registered.")
|
| 38 |
+
|
| 39 |
+
NanMarker = object()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def is_nan(x):
|
| 43 |
+
"""Checks if given value is a Python NaN."""
|
| 44 |
+
if not isinstance(x, numbers.Number):
|
| 45 |
+
return False
|
| 46 |
+
|
| 47 |
+
if isinstance(x, complex):
|
| 48 |
+
return math.isnan(x.real) or math.isnan(x.imag)
|
| 49 |
+
else:
|
| 50 |
+
return math.isnan(x)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Literal(trace.TraceType, serialization.Serializable):
|
| 54 |
+
"""Represents a Literal type like bool, int or string."""
|
| 55 |
+
|
| 56 |
+
def __init__(self, value: Any):
|
| 57 |
+
# We match nan values against each other even though Python doesn't.
|
| 58 |
+
if is_nan(value):
|
| 59 |
+
value = NanMarker
|
| 60 |
+
|
| 61 |
+
self.value = value
|
| 62 |
+
self._value_hash = hash(value)
|
| 63 |
+
|
| 64 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 65 |
+
return self == other
|
| 66 |
+
|
| 67 |
+
def most_specific_common_supertype(
|
| 68 |
+
self, types: Sequence[trace.TraceType]) -> Optional["Literal"]:
|
| 69 |
+
return self if all(self == other for other in types) else None
|
| 70 |
+
|
| 71 |
+
@classmethod
|
| 72 |
+
def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedLiteral]:
|
| 73 |
+
return default_types_pb2.SerializedLiteral
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def experimental_from_proto(
|
| 77 |
+
cls, proto: default_types_pb2.SerializedLiteral) -> "Literal":
|
| 78 |
+
if proto.HasField("bool_value"):
|
| 79 |
+
return Literal(proto.bool_value)
|
| 80 |
+
|
| 81 |
+
if proto.HasField("int_value"):
|
| 82 |
+
return Literal(proto.int_value)
|
| 83 |
+
|
| 84 |
+
if proto.HasField("float_value"):
|
| 85 |
+
return Literal(proto.float_value)
|
| 86 |
+
|
| 87 |
+
if proto.HasField("str_value"):
|
| 88 |
+
return Literal(proto.str_value)
|
| 89 |
+
|
| 90 |
+
if proto.HasField("none_value"):
|
| 91 |
+
return Literal(None)
|
| 92 |
+
|
| 93 |
+
raise ValueError("Malformed Literal proto can not be deserialized")
|
| 94 |
+
|
| 95 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedLiteral:
|
| 96 |
+
if isinstance(self.value, bool):
|
| 97 |
+
return default_types_pb2.SerializedLiteral(bool_value=self.value)
|
| 98 |
+
|
| 99 |
+
if isinstance(self.value, int):
|
| 100 |
+
return default_types_pb2.SerializedLiteral(int_value=self.value)
|
| 101 |
+
|
| 102 |
+
if isinstance(self.value, float):
|
| 103 |
+
return default_types_pb2.SerializedLiteral(float_value=self.value)
|
| 104 |
+
|
| 105 |
+
if isinstance(self.value, str):
|
| 106 |
+
return default_types_pb2.SerializedLiteral(str_value=self.value)
|
| 107 |
+
|
| 108 |
+
if self.value is None:
|
| 109 |
+
return default_types_pb2.SerializedLiteral(
|
| 110 |
+
none_value=default_types_pb2.SerializedLiteral.NoneValue())
|
| 111 |
+
|
| 112 |
+
raise ValueError("Can not serialize Literal of type " +
|
| 113 |
+
type(self.value).__name__)
|
| 114 |
+
|
| 115 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 116 |
+
# TODO(b/263505796): Remove this check when a range's placeholder output
|
| 117 |
+
# is expected to be a range and not a list.
|
| 118 |
+
if isinstance(self.value, range):
|
| 119 |
+
return list(self.value)
|
| 120 |
+
|
| 121 |
+
if self.value is NanMarker:
|
| 122 |
+
return float("nan")
|
| 123 |
+
|
| 124 |
+
return self.value
|
| 125 |
+
|
| 126 |
+
def cast(self, value: Any, casting_context: Any) -> Any:
|
| 127 |
+
if self.value is NanMarker and is_nan(value):
|
| 128 |
+
return value
|
| 129 |
+
|
| 130 |
+
if value == self.value:
|
| 131 |
+
return value
|
| 132 |
+
else:
|
| 133 |
+
raise ValueError(f"Can not cast {value!r} to {self!r}")
|
| 134 |
+
|
| 135 |
+
def __eq__(self, other) -> bool:
|
| 136 |
+
if not isinstance(other, trace.TraceType):
|
| 137 |
+
return NotImplemented
|
| 138 |
+
|
| 139 |
+
return isinstance(other, Literal) and self.value == other.value
|
| 140 |
+
|
| 141 |
+
def __hash__(self) -> int:
|
| 142 |
+
return self._value_hash
|
| 143 |
+
|
| 144 |
+
def __repr__(self) -> str:
|
| 145 |
+
return f"{self.__class__.__name__}[{self.value!r}]"
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Weakref(trace.TraceType):
|
| 149 |
+
"""Represents weakref of an arbitrary Python object.
|
| 150 |
+
|
| 151 |
+
When a function argument is a custom class, instead of making a copy of it
|
| 152 |
+
just for the sake of function cache, a weakref is instead kept to save memory.
|
| 153 |
+
"""
|
| 154 |
+
|
| 155 |
+
def __init__(self, ref: weakref.ReferenceType):
|
| 156 |
+
self._ref = ref
|
| 157 |
+
self._ref_hash = hash(ref)
|
| 158 |
+
|
| 159 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 160 |
+
return self == other
|
| 161 |
+
|
| 162 |
+
def most_specific_common_supertype(
|
| 163 |
+
self, types: Sequence[trace.TraceType]) -> Optional["Weakref"]:
|
| 164 |
+
return self if all(self == other for other in types) else None
|
| 165 |
+
|
| 166 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 167 |
+
return self._ref()
|
| 168 |
+
|
| 169 |
+
def cast(self, value, _):
|
| 170 |
+
if value is self._ref() or value == self._ref():
|
| 171 |
+
return value
|
| 172 |
+
|
| 173 |
+
# We unwrap objects when generating the TraceType so we allow matching now.
|
| 174 |
+
while hasattr(value, "__wrapped__"):
|
| 175 |
+
value = value.__wrapped__
|
| 176 |
+
if value is self._ref():
|
| 177 |
+
return value
|
| 178 |
+
|
| 179 |
+
raise ValueError(f"Can not cast {value!r} to {self!r}")
|
| 180 |
+
|
| 181 |
+
def __eq__(self, other):
|
| 182 |
+
if not isinstance(other, trace.TraceType):
|
| 183 |
+
return NotImplemented
|
| 184 |
+
|
| 185 |
+
if not isinstance(other, Weakref):
|
| 186 |
+
return False
|
| 187 |
+
|
| 188 |
+
if self._ref() is None or other._ref() is None:
|
| 189 |
+
return False
|
| 190 |
+
|
| 191 |
+
if self._ref() is other._ref():
|
| 192 |
+
return True
|
| 193 |
+
|
| 194 |
+
return self._ref == other._ref
|
| 195 |
+
|
| 196 |
+
def __hash__(self):
|
| 197 |
+
return self._ref_hash
|
| 198 |
+
|
| 199 |
+
def __repr__(self) -> str:
|
| 200 |
+
return f"{self.__class__.__name__}[{self._ref!r}])"
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class Tuple(trace.TraceType, serialization.Serializable):
|
| 204 |
+
"""Represents a tuple of TraceType objects."""
|
| 205 |
+
|
| 206 |
+
def __init__(self, *components: trace.TraceType):
|
| 207 |
+
self.components = components
|
| 208 |
+
|
| 209 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 210 |
+
if (not isinstance(other, Tuple) or
|
| 211 |
+
len(self.components) != len(other.components)):
|
| 212 |
+
return False
|
| 213 |
+
|
| 214 |
+
return all(
|
| 215 |
+
self_component.is_subtype_of(other_component) for self_component,
|
| 216 |
+
other_component in zip(self.components, other.components))
|
| 217 |
+
|
| 218 |
+
def most_specific_common_supertype(
|
| 219 |
+
self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
|
| 220 |
+
"""See base class."""
|
| 221 |
+
if not all(
|
| 222 |
+
isinstance(other, Tuple) and
|
| 223 |
+
len(self.components) == len(other.components) for other in others):
|
| 224 |
+
return None
|
| 225 |
+
|
| 226 |
+
supertyped_components = []
|
| 227 |
+
for i, component in enumerate(self.components):
|
| 228 |
+
supertyped_component = component.most_specific_common_supertype(
|
| 229 |
+
[other.components[i] for other in others])
|
| 230 |
+
if supertyped_component is None:
|
| 231 |
+
return None
|
| 232 |
+
supertyped_components.append(supertyped_component)
|
| 233 |
+
|
| 234 |
+
return Tuple(*supertyped_components)
|
| 235 |
+
|
| 236 |
+
@classmethod
|
| 237 |
+
def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
|
| 238 |
+
return default_types_pb2.SerializedTuple
|
| 239 |
+
|
| 240 |
+
@classmethod
|
| 241 |
+
def experimental_from_proto(
|
| 242 |
+
cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
|
| 243 |
+
return Tuple(*[serialization.deserialize(c) for c in proto.components])
|
| 244 |
+
|
| 245 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedTuple:
|
| 246 |
+
return default_types_pb2.SerializedTuple(
|
| 247 |
+
components=[serialization.serialize(c) for c in self.components])
|
| 248 |
+
|
| 249 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 250 |
+
components = [
|
| 251 |
+
component.placeholder_value(placeholder_context)
|
| 252 |
+
for component in self.components
|
| 253 |
+
]
|
| 254 |
+
return tuple(components)
|
| 255 |
+
|
| 256 |
+
def to_tensors(self, value) -> Any:
|
| 257 |
+
assert isinstance(value, tuple)
|
| 258 |
+
flattened_values = []
|
| 259 |
+
for comp_value, comp_type in zip(value, self.components):
|
| 260 |
+
flattened_values.extend(comp_type.to_tensors(comp_value))
|
| 261 |
+
return flattened_values
|
| 262 |
+
|
| 263 |
+
def from_tensors(self, tensors) -> Any:
|
| 264 |
+
return tuple(c.from_tensors(tensors) for c in self.components)
|
| 265 |
+
|
| 266 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 267 |
+
flattened_types = []
|
| 268 |
+
for component in self.components:
|
| 269 |
+
flattened_types.extend(component.flatten())
|
| 270 |
+
return flattened_types
|
| 271 |
+
|
| 272 |
+
def cast(self, value: Any, casting_context) -> Any:
|
| 273 |
+
assert isinstance(value, tuple), f"Can not cast {value!r} to tuple type."
|
| 274 |
+
assert len(value) == len(
|
| 275 |
+
self.components
|
| 276 |
+
), f"Expected {value} to have length of {len(self.components)}"
|
| 277 |
+
|
| 278 |
+
casted_values, was_casted = util.cast_and_return_whether_casted(
|
| 279 |
+
self.components, value, casting_context
|
| 280 |
+
)
|
| 281 |
+
if was_casted:
|
| 282 |
+
return tuple(casted_values)
|
| 283 |
+
else:
|
| 284 |
+
return value
|
| 285 |
+
|
| 286 |
+
def __eq__(self, other: Any) -> bool:
|
| 287 |
+
if not isinstance(other, trace.TraceType):
|
| 288 |
+
return NotImplemented
|
| 289 |
+
|
| 290 |
+
if not isinstance(other, Tuple):
|
| 291 |
+
return False
|
| 292 |
+
|
| 293 |
+
return self.components == other.components
|
| 294 |
+
|
| 295 |
+
def __hash__(self) -> int:
|
| 296 |
+
return hash(self.components)
|
| 297 |
+
|
| 298 |
+
def __repr__(self) -> str:
|
| 299 |
+
return f"Tuple[{', '.join(map(repr, self.components))}]"
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class List(trace.TraceType, serialization.Serializable):
|
| 303 |
+
"""Represents a list of TraceType objects."""
|
| 304 |
+
|
| 305 |
+
def __init__(self, *components: trace.TraceType):
|
| 306 |
+
self.components_tuple = Tuple(*components)
|
| 307 |
+
|
| 308 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 309 |
+
if not isinstance(other, List):
|
| 310 |
+
return False
|
| 311 |
+
|
| 312 |
+
return self.components_tuple.is_subtype_of(other.components_tuple)
|
| 313 |
+
|
| 314 |
+
def most_specific_common_supertype(
|
| 315 |
+
self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
|
| 316 |
+
"""See base class."""
|
| 317 |
+
if not all(isinstance(other, List) for other in others):
|
| 318 |
+
return None
|
| 319 |
+
|
| 320 |
+
supertyped_components_tuple = (
|
| 321 |
+
self.components_tuple.most_specific_common_supertype(
|
| 322 |
+
[other.components_tuple for other in others]
|
| 323 |
+
)
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
if supertyped_components_tuple is None:
|
| 327 |
+
return None
|
| 328 |
+
|
| 329 |
+
return List(*supertyped_components_tuple.components)
|
| 330 |
+
|
| 331 |
+
@classmethod
|
| 332 |
+
def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedList]:
|
| 333 |
+
return default_types_pb2.SerializedList
|
| 334 |
+
|
| 335 |
+
@classmethod
|
| 336 |
+
def experimental_from_proto(
|
| 337 |
+
cls, proto: default_types_pb2.SerializedList) -> "List":
|
| 338 |
+
return List(
|
| 339 |
+
*Tuple.experimental_from_proto(proto.components_tuple).components)
|
| 340 |
+
|
| 341 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedList:
|
| 342 |
+
return default_types_pb2.SerializedList(
|
| 343 |
+
components_tuple=self.components_tuple.experimental_as_proto())
|
| 344 |
+
|
| 345 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 346 |
+
return list(self.components_tuple.placeholder_value(placeholder_context))
|
| 347 |
+
|
| 348 |
+
def to_tensors(self, value):
|
| 349 |
+
assert isinstance(value, list)
|
| 350 |
+
return self.components_tuple.to_tensors(tuple(value))
|
| 351 |
+
|
| 352 |
+
def from_tensors(self, tensors) -> Any:
|
| 353 |
+
return list(self.components_tuple.from_tensors(tensors))
|
| 354 |
+
|
| 355 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 356 |
+
return self.components_tuple.flatten()
|
| 357 |
+
|
| 358 |
+
def cast(self, value: Any, casting_context) -> Any:
|
| 359 |
+
assert isinstance(value, list), f"Can not cast {value!r} to list type."
|
| 360 |
+
assert len(value) == len(
|
| 361 |
+
self.components_tuple.components
|
| 362 |
+
), f"Expected {value} to have length of {len(self.components_tuple)}"
|
| 363 |
+
|
| 364 |
+
casted_values, was_casted = util.cast_and_return_whether_casted(
|
| 365 |
+
self.components_tuple.components, value, casting_context
|
| 366 |
+
)
|
| 367 |
+
if was_casted:
|
| 368 |
+
return list(casted_values)
|
| 369 |
+
else:
|
| 370 |
+
return value
|
| 371 |
+
|
| 372 |
+
def __eq__(self, other: Any) -> bool:
|
| 373 |
+
if not isinstance(other, trace.TraceType):
|
| 374 |
+
return NotImplemented
|
| 375 |
+
|
| 376 |
+
if not isinstance(other, List):
|
| 377 |
+
return False
|
| 378 |
+
|
| 379 |
+
return self.components_tuple == other.components_tuple
|
| 380 |
+
|
| 381 |
+
def __hash__(self) -> int:
|
| 382 |
+
return hash(self.components_tuple)
|
| 383 |
+
|
| 384 |
+
def __repr__(self) -> str:
|
| 385 |
+
return f"List[{', '.join(map(repr, self.components_tuple.components))}]"
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class NamedTuple(trace.TraceType, serialization.Serializable):
|
| 389 |
+
"""Represents a NamedTuple of TraceType objects."""
|
| 390 |
+
|
| 391 |
+
def __init__(self,
|
| 392 |
+
type_name: str,
|
| 393 |
+
attribute_names: PythonTuple[str],
|
| 394 |
+
attributes: PythonTuple[trace.TraceType],
|
| 395 |
+
placeholder_type: Optional[Type[Any]] = None):
|
| 396 |
+
self.type_name = type_name
|
| 397 |
+
self.attribute_names = attribute_names
|
| 398 |
+
self.attributes = Tuple(*attributes)
|
| 399 |
+
self._placeholder_type = placeholder_type
|
| 400 |
+
|
| 401 |
+
@classmethod
|
| 402 |
+
def from_type_and_attributes(
|
| 403 |
+
cls, named_tuple_type: Any,
|
| 404 |
+
attributes: PythonTuple[trace.TraceType]) -> "NamedTuple":
|
| 405 |
+
return NamedTuple(named_tuple_type.__name__, named_tuple_type._fields,
|
| 406 |
+
attributes, named_tuple_type)
|
| 407 |
+
|
| 408 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 409 |
+
if not isinstance(other, NamedTuple):
|
| 410 |
+
return False
|
| 411 |
+
|
| 412 |
+
return (self.type_name == other.type_name and
|
| 413 |
+
self.attribute_names == other.attribute_names and
|
| 414 |
+
self.attributes.is_subtype_of(other.attributes))
|
| 415 |
+
|
| 416 |
+
def most_specific_common_supertype(
|
| 417 |
+
self, others: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
|
| 418 |
+
"""See base class."""
|
| 419 |
+
if not all(
|
| 420 |
+
isinstance(other, NamedTuple) and self.type_name == other.type_name and
|
| 421 |
+
self.attribute_names == other.attribute_names for other in others):
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
supertyped_attributes = self.attributes.most_specific_common_supertype(
|
| 425 |
+
[other.attributes for other in others])
|
| 426 |
+
|
| 427 |
+
if supertyped_attributes is None:
|
| 428 |
+
return None
|
| 429 |
+
|
| 430 |
+
return NamedTuple(self.type_name, self.attribute_names,
|
| 431 |
+
supertyped_attributes.components, self._placeholder_type)
|
| 432 |
+
|
| 433 |
+
@classmethod
|
| 434 |
+
def experimental_type_proto(
|
| 435 |
+
cls) -> Type[default_types_pb2.SerializedNamedTuple]:
|
| 436 |
+
return default_types_pb2.SerializedNamedTuple
|
| 437 |
+
|
| 438 |
+
@classmethod
|
| 439 |
+
def experimental_from_proto(
|
| 440 |
+
cls, proto: default_types_pb2.SerializedNamedTuple) -> "NamedTuple":
|
| 441 |
+
return NamedTuple(
|
| 442 |
+
proto.type_name, tuple(proto.attribute_names),
|
| 443 |
+
Tuple.experimental_from_proto(proto.attributes).components)
|
| 444 |
+
|
| 445 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedNamedTuple:
|
| 446 |
+
return default_types_pb2.SerializedNamedTuple(
|
| 447 |
+
type_name=self.type_name,
|
| 448 |
+
attribute_names=list(self.attribute_names),
|
| 449 |
+
attributes=self.attributes.experimental_as_proto())
|
| 450 |
+
|
| 451 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 452 |
+
if self._placeholder_type is None:
|
| 453 |
+
# We don't need to trace after serialization so it is not needed but we
|
| 454 |
+
# can generate a placeholder type using the description if ever needed.
|
| 455 |
+
raise ValueError("Can not generate placeholder value for NamedTuple with"
|
| 456 |
+
" unspecified placeholder_type. Note: placeholder_type "
|
| 457 |
+
"is lost during serialization.")
|
| 458 |
+
attribute_placeholders = [
|
| 459 |
+
attribute.placeholder_value(placeholder_context)
|
| 460 |
+
for attribute in self.attributes.components
|
| 461 |
+
]
|
| 462 |
+
return self._placeholder_type(*attribute_placeholders)
|
| 463 |
+
|
| 464 |
+
def to_tensors(self, value: Any):
|
| 465 |
+
assert util.is_namedtuple(value)
|
| 466 |
+
flattened_values = []
|
| 467 |
+
for attribute_name, attribute_type in zip(
|
| 468 |
+
self.attribute_names, self.attributes.components):
|
| 469 |
+
attribute_value = getattr(value, attribute_name)
|
| 470 |
+
flattened_values.extend(attribute_type.to_tensors(attribute_value))
|
| 471 |
+
return flattened_values
|
| 472 |
+
|
| 473 |
+
def from_tensors(self, tensors) -> Any:
|
| 474 |
+
if self._placeholder_type is None:
|
| 475 |
+
raise ValueError("Packing serialized NamedTuples is not supported.")
|
| 476 |
+
|
| 477 |
+
return self._placeholder_type(
|
| 478 |
+
*[c.from_tensors(tensors) for c in self.attributes.components]
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 482 |
+
flattened_types = []
|
| 483 |
+
|
| 484 |
+
for component in self.attributes.components:
|
| 485 |
+
flattened_types.extend(component.flatten())
|
| 486 |
+
|
| 487 |
+
return flattened_types
|
| 488 |
+
|
| 489 |
+
def cast(self, value: Any, casting_context) -> Any:
|
| 490 |
+
# Value must have same attributes with the TraceType
|
| 491 |
+
assert util.is_namedtuple(
|
| 492 |
+
value
|
| 493 |
+
), f"Cannot cast {value!r} to type {self._placeholder_type!r}."
|
| 494 |
+
value_dict = value._asdict()
|
| 495 |
+
assert set(value_dict.keys()) == set(
|
| 496 |
+
self.attribute_names
|
| 497 |
+
), f"{value!r} has different attributes with the TraceType {self!r}"
|
| 498 |
+
|
| 499 |
+
casted_values, was_casted = util.cast_and_return_whether_casted(
|
| 500 |
+
self.attributes.components,
|
| 501 |
+
[getattr(value, name) for name in self.attribute_names],
|
| 502 |
+
casting_context,
|
| 503 |
+
)
|
| 504 |
+
if was_casted:
|
| 505 |
+
return self._placeholder_type(*casted_values)
|
| 506 |
+
else:
|
| 507 |
+
return value
|
| 508 |
+
|
| 509 |
+
def __hash__(self) -> int:
|
| 510 |
+
return hash((self.type_name, self.attribute_names, self.attributes))
|
| 511 |
+
|
| 512 |
+
def __eq__(self, other: Any) -> bool:
|
| 513 |
+
if not isinstance(other, trace.TraceType):
|
| 514 |
+
return NotImplemented
|
| 515 |
+
|
| 516 |
+
if not isinstance(other, NamedTuple):
|
| 517 |
+
return False
|
| 518 |
+
|
| 519 |
+
return (self.type_name == other.type_name and
|
| 520 |
+
self.attribute_names == other.attribute_names and
|
| 521 |
+
self.attributes == other.attributes)
|
| 522 |
+
|
| 523 |
+
def __repr__(self) -> str:
|
| 524 |
+
paired = [
|
| 525 |
+
f"[{n!r}, {c!r}]"
|
| 526 |
+
for n, c in zip(self.attribute_names, self.attributes.components)
|
| 527 |
+
]
|
| 528 |
+
return f"{self.type_name}[{', '.join(paired)}]"
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
class Attrs(trace.TraceType):
|
| 532 |
+
"""Represents a class annotated by attr.s."""
|
| 533 |
+
|
| 534 |
+
def __init__(self,
|
| 535 |
+
type_name: str,
|
| 536 |
+
attribute_names: PythonTuple[str],
|
| 537 |
+
attributes: PythonTuple[trace.TraceType],
|
| 538 |
+
placeholder_type: Optional[Type[Any]] = None):
|
| 539 |
+
self.named_attributes = NamedTuple(type_name, attribute_names, attributes)
|
| 540 |
+
self._placeholder_type = placeholder_type
|
| 541 |
+
|
| 542 |
+
@classmethod
|
| 543 |
+
def from_type_and_attributes(
|
| 544 |
+
cls, attrs_type: Any,
|
| 545 |
+
attributes: PythonTuple[trace.TraceType]) -> "Attrs":
|
| 546 |
+
return Attrs(attrs_type.__name__,
|
| 547 |
+
tuple(attr.name for attr in attrs_type.__attrs_attrs__),
|
| 548 |
+
attributes, attrs_type)
|
| 549 |
+
|
| 550 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 551 |
+
if not isinstance(other, Attrs):
|
| 552 |
+
return False
|
| 553 |
+
|
| 554 |
+
return self.named_attributes.is_subtype_of(other.named_attributes)
|
| 555 |
+
|
| 556 |
+
def most_specific_common_supertype(
|
| 557 |
+
self, others: Sequence[trace.TraceType]) -> Optional["Attrs"]:
|
| 558 |
+
"""See base class."""
|
| 559 |
+
if not all(isinstance(other, Attrs) for other in others):
|
| 560 |
+
return None
|
| 561 |
+
|
| 562 |
+
supertyped_attributes = (
|
| 563 |
+
self.named_attributes.most_specific_common_supertype(
|
| 564 |
+
[other.named_attributes for other in others]
|
| 565 |
+
)
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
if supertyped_attributes is None:
|
| 569 |
+
return None
|
| 570 |
+
|
| 571 |
+
return Attrs(self.named_attributes.type_name,
|
| 572 |
+
self.named_attributes.attribute_names,
|
| 573 |
+
supertyped_attributes.attributes.components,
|
| 574 |
+
self._placeholder_type)
|
| 575 |
+
|
| 576 |
+
@classmethod
|
| 577 |
+
def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedAttrs]:
|
| 578 |
+
return default_types_pb2.SerializedAttrs
|
| 579 |
+
|
| 580 |
+
@classmethod
|
| 581 |
+
def experimental_from_proto(
|
| 582 |
+
cls, proto: default_types_pb2.SerializedAttrs) -> "Attrs":
|
| 583 |
+
return Attrs(
|
| 584 |
+
proto.named_attributes.type_name,
|
| 585 |
+
tuple(proto.named_attributes.attribute_names),
|
| 586 |
+
Tuple.experimental_from_proto(
|
| 587 |
+
proto.named_attributes.attributes).components)
|
| 588 |
+
|
| 589 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedAttrs:
|
| 590 |
+
return default_types_pb2.SerializedAttrs(
|
| 591 |
+
named_attributes=self.named_attributes.experimental_as_proto())
|
| 592 |
+
|
| 593 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 594 |
+
if self._placeholder_type is None:
|
| 595 |
+
# We don't need to trace after serialization so it is not needed but we
|
| 596 |
+
# can generate a placeholder type using the description if ever needed.
|
| 597 |
+
raise ValueError("Can not generate placeholder value for Attrs with"
|
| 598 |
+
" unspecified placeholder_type. Note: placeholder_type "
|
| 599 |
+
"is lost during serialization.")
|
| 600 |
+
attribute_placeholders = [
|
| 601 |
+
attribute.placeholder_value(placeholder_context)
|
| 602 |
+
for attribute in self.named_attributes.attributes.components
|
| 603 |
+
]
|
| 604 |
+
return self._placeholder_type(*attribute_placeholders)
|
| 605 |
+
|
| 606 |
+
def to_tensors(self, value: Any):
|
| 607 |
+
assert util.is_attrs(value)
|
| 608 |
+
flattened_values = []
|
| 609 |
+
for attribute_name, attribute_type in zip(
|
| 610 |
+
self.named_attributes.attribute_names,
|
| 611 |
+
self.named_attributes.attributes.components):
|
| 612 |
+
attribute_value = getattr(value, attribute_name)
|
| 613 |
+
flattened_values.extend(attribute_type.to_tensors(attribute_value))
|
| 614 |
+
return flattened_values
|
| 615 |
+
|
| 616 |
+
def from_tensors(self, tensors):
|
| 617 |
+
if self._placeholder_type is None:
|
| 618 |
+
raise ValueError("Packing serialized NamedTuples is not supported.")
|
| 619 |
+
|
| 620 |
+
return self._placeholder_type(
|
| 621 |
+
*[
|
| 622 |
+
c.from_tensors(tensors)
|
| 623 |
+
for c in self.named_attributes.attributes.components
|
| 624 |
+
]
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 628 |
+
flattened_types = []
|
| 629 |
+
|
| 630 |
+
for component in self.named_attributes.attributes.components:
|
| 631 |
+
flattened_types.extend(component.flatten())
|
| 632 |
+
|
| 633 |
+
return flattened_types
|
| 634 |
+
|
| 635 |
+
def cast(self, value: Any, casting_context) -> Any:
|
| 636 |
+
assert util.is_attrs(value)
|
| 637 |
+
|
| 638 |
+
attr_names = self.named_attributes.attribute_names
|
| 639 |
+
casted_values, was_casted = util.cast_and_return_whether_casted(
|
| 640 |
+
self.named_attributes.attributes.components,
|
| 641 |
+
[getattr(value, name) for name in attr_names],
|
| 642 |
+
casting_context,
|
| 643 |
+
)
|
| 644 |
+
|
| 645 |
+
if was_casted:
|
| 646 |
+
return self._placeholder_type(*casted_values)
|
| 647 |
+
else:
|
| 648 |
+
return value
|
| 649 |
+
|
| 650 |
+
def __hash__(self) -> int:
|
| 651 |
+
return hash(self.named_attributes)
|
| 652 |
+
|
| 653 |
+
def __eq__(self, other: Any) -> bool:
|
| 654 |
+
if not isinstance(other, trace.TraceType):
|
| 655 |
+
return NotImplemented
|
| 656 |
+
|
| 657 |
+
if not isinstance(other, Attrs):
|
| 658 |
+
return False
|
| 659 |
+
|
| 660 |
+
return self.named_attributes == other.named_attributes
|
| 661 |
+
|
| 662 |
+
def __repr__(self) -> str:
|
| 663 |
+
name_component_zip = zip(
|
| 664 |
+
self.named_attributes.attribute_names,
|
| 665 |
+
self.named_attributes.attributes.components,
|
| 666 |
+
)
|
| 667 |
+
paired = [f"[{n!r}, {c!r}]" for n, c in name_component_zip]
|
| 668 |
+
return f"{self.named_attributes.type_name}[{', '.join(paired)}]"
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
class Dict(trace.TraceType, serialization.Serializable):
|
| 672 |
+
"""Represents a dictionary of TraceType objects.
|
| 673 |
+
|
| 674 |
+
Attributes:
|
| 675 |
+
mapping: A mapping from keys to corresponding TraceTypes of the dict values.
|
| 676 |
+
"""
|
| 677 |
+
|
| 678 |
+
def __init__(self,
|
| 679 |
+
mapping: PythonDict[Hashable, trace.TraceType],
|
| 680 |
+
placeholder_type: Optional[Type[Any]] = None):
|
| 681 |
+
self.mapping = mapping
|
| 682 |
+
self._placeholder_type = placeholder_type
|
| 683 |
+
|
| 684 |
+
def _has_same_structure(self, other):
|
| 685 |
+
if not isinstance(other, Dict):
|
| 686 |
+
return False
|
| 687 |
+
|
| 688 |
+
return self.mapping.keys() == other.mapping.keys()
|
| 689 |
+
|
| 690 |
+
def is_subtype_of(self, other: trace.TraceType) -> bool:
|
| 691 |
+
"""See base class."""
|
| 692 |
+
if not self._has_same_structure(other):
|
| 693 |
+
return False
|
| 694 |
+
|
| 695 |
+
# We need all keys to be present because there can be logic relying on
|
| 696 |
+
# their existence or lack thereof and hence can not guarantee subtype based
|
| 697 |
+
# on a subset or superset of keys.
|
| 698 |
+
# Only the tracing code can explicitly check for key dependencies and inform
|
| 699 |
+
# that decision.
|
| 700 |
+
return all(self.mapping[key].is_subtype_of(other.mapping[key])
|
| 701 |
+
for key in self.mapping)
|
| 702 |
+
|
| 703 |
+
def most_specific_common_supertype(
|
| 704 |
+
self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
|
| 705 |
+
"""See base class."""
|
| 706 |
+
if not all(self._has_same_structure(other) for other in types):
|
| 707 |
+
return None
|
| 708 |
+
|
| 709 |
+
new_mapping = {}
|
| 710 |
+
for key in self.mapping.keys():
|
| 711 |
+
common = self.mapping[key].most_specific_common_supertype(
|
| 712 |
+
[other.mapping[key] for other in types])
|
| 713 |
+
if common is None:
|
| 714 |
+
return None
|
| 715 |
+
else:
|
| 716 |
+
new_mapping[key] = common
|
| 717 |
+
|
| 718 |
+
return Dict(new_mapping, self._placeholder_type)
|
| 719 |
+
|
| 720 |
+
@classmethod
|
| 721 |
+
def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
|
| 722 |
+
return default_types_pb2.SerializedDict
|
| 723 |
+
|
| 724 |
+
@classmethod
|
| 725 |
+
def experimental_from_proto(
|
| 726 |
+
cls, proto: default_types_pb2.SerializedDict) -> "Dict":
|
| 727 |
+
return Dict({
|
| 728 |
+
Literal.experimental_from_proto(k).value: serialization.deserialize(v)
|
| 729 |
+
for k, v in zip(proto.keys, proto.values)
|
| 730 |
+
})
|
| 731 |
+
|
| 732 |
+
def experimental_as_proto(self) -> default_types_pb2.SerializedDict:
|
| 733 |
+
return default_types_pb2.SerializedDict(
|
| 734 |
+
keys=[Literal(k).experimental_as_proto() for k in self.mapping.keys()],
|
| 735 |
+
values=[serialization.serialize(v) for v in self.mapping.values()])
|
| 736 |
+
|
| 737 |
+
def placeholder_value(self, placeholder_context) -> Any:
|
| 738 |
+
if self._placeholder_type is None:
|
| 739 |
+
raise ValueError("Can not generate placeholder value for Dict with"
|
| 740 |
+
" unspecified placeholder_type. Note: placeholder_type "
|
| 741 |
+
"is lost during serialization.")
|
| 742 |
+
attribute_placeholders = [
|
| 743 |
+
(key, value.placeholder_value(placeholder_context))
|
| 744 |
+
for key, value in self.mapping.items()
|
| 745 |
+
]
|
| 746 |
+
if self._placeholder_type is collections.defaultdict:
|
| 747 |
+
return dict(attribute_placeholders)
|
| 748 |
+
return self._placeholder_type(attribute_placeholders)
|
| 749 |
+
|
| 750 |
+
def to_tensors(self, value: Any):
|
| 751 |
+
assert isinstance(value, collections.abc.Mapping)
|
| 752 |
+
flattened_values = []
|
| 753 |
+
for key in sorted(self.mapping.keys()):
|
| 754 |
+
comp_value, comp_type = value[key], self.mapping[key]
|
| 755 |
+
flattened_values.extend(comp_type.to_tensors(comp_value))
|
| 756 |
+
return flattened_values
|
| 757 |
+
|
| 758 |
+
def from_tensors(self, tensors):
|
| 759 |
+
if self._placeholder_type is None:
|
| 760 |
+
raise ValueError("Packing serialized Dict is not supported.")
|
| 761 |
+
|
| 762 |
+
sorted_traversal = {
|
| 763 |
+
key: self.mapping[key].from_tensors(tensors)
|
| 764 |
+
for key in sorted(self.mapping)
|
| 765 |
+
}
|
| 766 |
+
|
| 767 |
+
if self._placeholder_type is collections.defaultdict:
|
| 768 |
+
return {key: sorted_traversal[key] for key in self.mapping}
|
| 769 |
+
|
| 770 |
+
return self._placeholder_type(
|
| 771 |
+
(key, sorted_traversal[key]) for key in self.mapping
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
def flatten(self) -> PythonList[trace.TraceType]:
|
| 775 |
+
flattened_types = []
|
| 776 |
+
|
| 777 |
+
for key in sorted(self.mapping.keys()):
|
| 778 |
+
flattened_types.extend(self.mapping[key].flatten())
|
| 779 |
+
|
| 780 |
+
return flattened_types
|
| 781 |
+
|
| 782 |
+
def cast(self, value: Any, casting_context) -> Any:
|
| 783 |
+
# Value must have same keys with the TraceType
|
| 784 |
+
assert isinstance(
|
| 785 |
+
value, collections.abc.Mapping
|
| 786 |
+
), f"Can not cast {value!r} to a Dict type."
|
| 787 |
+
assert set(value.keys()) == set(
|
| 788 |
+
self.mapping.keys()
|
| 789 |
+
), f"{value!r} has different keys with the TraceType {self!r}."
|
| 790 |
+
|
| 791 |
+
casted_values, was_casted = util.cast_and_return_whether_casted(
|
| 792 |
+
self.mapping.values(),
|
| 793 |
+
[value[k] for k in self.mapping.keys()],
|
| 794 |
+
casting_context,
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
if was_casted:
|
| 798 |
+
return self._placeholder_type(
|
| 799 |
+
**{k: v for k, v in zip(self.mapping.keys(), casted_values)}
|
| 800 |
+
)
|
| 801 |
+
else:
|
| 802 |
+
return value
|
| 803 |
+
|
| 804 |
+
def __eq__(self, other) -> bool:
|
| 805 |
+
if not isinstance(other, trace.TraceType):
|
| 806 |
+
return NotImplemented
|
| 807 |
+
|
| 808 |
+
if not isinstance(other, Dict):
|
| 809 |
+
return False
|
| 810 |
+
|
| 811 |
+
return self.mapping == other.mapping
|
| 812 |
+
|
| 813 |
+
def __hash__(self) -> int:
|
| 814 |
+
return hash(frozenset(self.mapping.keys()))
|
| 815 |
+
|
| 816 |
+
def __repr__(self) -> str:
|
| 817 |
+
paired = [f"[{n!r}, {t!r}]" for n, t in self.mapping.items()]
|
| 818 |
+
return f"{self.__class__.__name__}[{', '.join(paired)}]"
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
serialization.register_serializable(Literal)
|
| 822 |
+
serialization.register_serializable(Tuple)
|
| 823 |
+
serialization.register_serializable(List)
|
| 824 |
+
serialization.register_serializable(NamedTuple)
|
| 825 |
+
serialization.register_serializable(Attrs)
|
| 826 |
+
serialization.register_serializable(Dict)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/default_types_pb2.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/function/trace_type/default_types.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from tensorflow.core.function.trace_type import serialization_pb2 as tensorflow_dot_core_dot_function_dot_trace__type_dot_serialization__pb2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7tensorflow/core/function/trace_type/default_types.proto\x12\x31tensorflow.core.function.trace_type.default_types\x1a\x37tensorflow/core/function/trace_type/serialization.proto\"\xe6\x01\n\x11SerializedLiteral\x12\x14\n\nbool_value\x18\x01 \x01(\x08H\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x01H\x00\x12\x13\n\tstr_value\x18\x04 \x01(\tH\x00\x12\x64\n\nnone_value\x18\x05 \x01(\x0b\x32N.tensorflow.core.function.trace_type.default_types.SerializedLiteral.NoneValueH\x00\x1a\x0b\n\tNoneValueB\x07\n\x05value\"m\n\x0fSerializedTuple\x12Z\n\ncomponents\x18\x01 \x03(\x0b\x32\x46.tensorflow.core.function.trace_type.serialization.SerializedTraceType\"n\n\x0eSerializedList\x12\\\n\x10\x63omponents_tuple\x18\x01 \x01(\x0b\x32\x42.tensorflow.core.function.trace_type.default_types.SerializedTuple\"\x9a\x01\n\x14SerializedNamedTuple\x12\x11\n\ttype_name\x18\x01 \x01(\t\x12\x17\n\x0f\x61ttribute_names\x18\x02 \x03(\t\x12V\n\nattributes\x18\x03 \x01(\x0b\x32\x42.tensorflow.core.function.trace_type.default_types.SerializedTuple\"t\n\x0fSerializedAttrs\x12\x61\n\x10named_attributes\x18\x01 \x01(\x0b\x32G.tensorflow.core.function.trace_type.default_types.SerializedNamedTuple\"\xbc\x01\n\x0eSerializedDict\x12R\n\x04keys\x18\x01 \x03(\x0b\x32\x44.tensorflow.core.function.trace_type.default_types.SerializedLiteral\x12V\n\x06values\x18\x02 \x03(\x0b\x32\x46.tensorflow.core.function.trace_type.serialization.SerializedTraceType')
|
| 18 |
+
|
| 19 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 20 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.function.trace_type.default_types_pb2', globals())
|
| 21 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 22 |
+
|
| 23 |
+
DESCRIPTOR._options = None
|
| 24 |
+
_SERIALIZEDLITERAL._serialized_start=168
|
| 25 |
+
_SERIALIZEDLITERAL._serialized_end=398
|
| 26 |
+
_SERIALIZEDLITERAL_NONEVALUE._serialized_start=378
|
| 27 |
+
_SERIALIZEDLITERAL_NONEVALUE._serialized_end=389
|
| 28 |
+
_SERIALIZEDTUPLE._serialized_start=400
|
| 29 |
+
_SERIALIZEDTUPLE._serialized_end=509
|
| 30 |
+
_SERIALIZEDLIST._serialized_start=511
|
| 31 |
+
_SERIALIZEDLIST._serialized_end=621
|
| 32 |
+
_SERIALIZEDNAMEDTUPLE._serialized_start=624
|
| 33 |
+
_SERIALIZEDNAMEDTUPLE._serialized_end=778
|
| 34 |
+
_SERIALIZEDATTRS._serialized_start=780
|
| 35 |
+
_SERIALIZEDATTRS._serialized_end=896
|
| 36 |
+
_SERIALIZEDDICT._serialized_start=899
|
| 37 |
+
_SERIALIZEDDICT._serialized_end=1087
|
| 38 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utils for serializing and deserializing TraceTypes."""
|
| 16 |
+
|
| 17 |
+
import abc
|
| 18 |
+
from typing import Type
|
| 19 |
+
|
| 20 |
+
from google.protobuf import message
|
| 21 |
+
from tensorflow.core.function.trace_type import serialization_pb2
|
| 22 |
+
|
| 23 |
+
SerializedTraceType = serialization_pb2.SerializedTraceType
|
| 24 |
+
|
| 25 |
+
PROTO_CLASS_TO_PY_CLASS = {}
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Serializable(metaclass=abc.ABCMeta):
|
| 29 |
+
"""TraceTypes implementing this additional interface are portable."""
|
| 30 |
+
|
| 31 |
+
@classmethod
|
| 32 |
+
@abc.abstractmethod
|
| 33 |
+
def experimental_type_proto(cls) -> Type[message.Message]:
|
| 34 |
+
"""Returns the unique type of proto associated with this class."""
|
| 35 |
+
raise NotImplementedError
|
| 36 |
+
|
| 37 |
+
@classmethod
|
| 38 |
+
@abc.abstractmethod
|
| 39 |
+
def experimental_from_proto(cls, proto: message.Message) -> "Serializable":
|
| 40 |
+
"""Returns an instance based on a proto."""
|
| 41 |
+
raise NotImplementedError
|
| 42 |
+
|
| 43 |
+
@abc.abstractmethod
|
| 44 |
+
def experimental_as_proto(self) -> message.Message:
|
| 45 |
+
"""Returns a proto representing this instance."""
|
| 46 |
+
raise NotImplementedError
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def register_serializable(cls: Type[Serializable]):
|
| 50 |
+
"""Registers a Python class to support serialization.
|
| 51 |
+
|
| 52 |
+
Only register standard TF types. Custom types should NOT be registered.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
cls: Python class to register.
|
| 56 |
+
"""
|
| 57 |
+
if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS:
|
| 58 |
+
raise ValueError(
|
| 59 |
+
"Existing Python class " +
|
| 60 |
+
PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ +
|
| 61 |
+
" already has " + cls.experimental_type_proto().__name__ +
|
| 62 |
+
" as its associated proto representation. Please ensure " +
|
| 63 |
+
cls.__name__ + " has a unique proto representation.")
|
| 64 |
+
|
| 65 |
+
PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def serialize(to_serialize: Serializable) -> SerializedTraceType:
|
| 69 |
+
"""Converts Serializable to a proto SerializedTraceType."""
|
| 70 |
+
|
| 71 |
+
if not isinstance(to_serialize, Serializable):
|
| 72 |
+
raise ValueError("Can not serialize " + type(to_serialize).__name__ +
|
| 73 |
+
" since it is not Serializable. For object " +
|
| 74 |
+
str(to_serialize))
|
| 75 |
+
actual_proto = to_serialize.experimental_as_proto()
|
| 76 |
+
|
| 77 |
+
if not isinstance(actual_proto, to_serialize.experimental_type_proto()):
|
| 78 |
+
raise ValueError(
|
| 79 |
+
type(to_serialize).__name__ +
|
| 80 |
+
" returned different type of proto than specified by " +
|
| 81 |
+
"experimental_type_proto()")
|
| 82 |
+
|
| 83 |
+
serialized = SerializedTraceType()
|
| 84 |
+
serialized.representation.Pack(actual_proto)
|
| 85 |
+
return serialized
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def deserialize(proto: SerializedTraceType) -> Serializable:
|
| 89 |
+
"""Converts a proto SerializedTraceType to instance of Serializable."""
|
| 90 |
+
for proto_class in PROTO_CLASS_TO_PY_CLASS:
|
| 91 |
+
if proto.representation.Is(proto_class.DESCRIPTOR):
|
| 92 |
+
actual_proto = proto_class()
|
| 93 |
+
proto.representation.Unpack(actual_proto)
|
| 94 |
+
return PROTO_CLASS_TO_PY_CLASS[proto_class].experimental_from_proto(
|
| 95 |
+
actual_proto)
|
| 96 |
+
|
| 97 |
+
raise ValueError(
|
| 98 |
+
"Can not deserialize proto of url: ", proto.representation.type_url,
|
| 99 |
+
" since no matching Python class could be found. For value ",
|
| 100 |
+
proto.representation.value)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization_pb2.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/function/trace_type/serialization.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7tensorflow/core/function/trace_type/serialization.proto\x12\x31tensorflow.core.function.trace_type.serialization\x1a\x19google/protobuf/any.proto\"C\n\x13SerializedTraceType\x12,\n\x0erepresentation\x18\x01 \x01(\x0b\x32\x14.google.protobuf.Any')
|
| 18 |
+
|
| 19 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 20 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.function.trace_type.serialization_pb2', globals())
|
| 21 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 22 |
+
|
| 23 |
+
DESCRIPTOR._options = None
|
| 24 |
+
_SERIALIZEDTRACETYPE._serialized_start=137
|
| 25 |
+
_SERIALIZEDTRACETYPE._serialized_end=204
|
| 26 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/serialization_test_pb2.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/core/function/trace_type/serialization_test.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from tensorflow.core.function.trace_type import serialization_pb2 as tensorflow_dot_core_dot_function_dot_trace__type_dot_serialization__pb2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n<tensorflow/core/function/trace_type/serialization_test.proto\x12\x36tensorflow.core.function.trace_type.serialization_test\x1a\x37tensorflow/core/function/trace_type/serialization.proto\"5\n\x16MyCustomRepresentation\x12\r\n\x05index\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\"u\n\x19MyCompositeRepresentation\x12X\n\x08\x65lements\x18\x01 \x03(\x0b\x32\x46.tensorflow.core.function.trace_type.serialization.SerializedTraceType\"(\n\x1aMyMultiClassRepresentation\x12\n\n\x02id\x18\x01 \x01(\x05')
|
| 18 |
+
|
| 19 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 20 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.core.function.trace_type.serialization_test_pb2', globals())
|
| 21 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 22 |
+
|
| 23 |
+
DESCRIPTOR._options = None
|
| 24 |
+
_MYCUSTOMREPRESENTATION._serialized_start=177
|
| 25 |
+
_MYCUSTOMREPRESENTATION._serialized_end=230
|
| 26 |
+
_MYCOMPOSITEREPRESENTATION._serialized_start=232
|
| 27 |
+
_MYCOMPOSITEREPRESENTATION._serialized_end=349
|
| 28 |
+
_MYMULTICLASSREPRESENTATION._serialized_start=351
|
| 29 |
+
_MYMULTICLASSREPRESENTATION._serialized_end=391
|
| 30 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/trace_type_builder.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utitiles for Cache Key generation based on Function Trace Type."""
|
| 16 |
+
|
| 17 |
+
import collections.abc
|
| 18 |
+
from typing import Any, Dict, Hashable, Optional
|
| 19 |
+
import weakref
|
| 20 |
+
|
| 21 |
+
from tensorflow.core.function.trace_type import custom_nest_trace_type
|
| 22 |
+
from tensorflow.core.function.trace_type import default_types
|
| 23 |
+
from tensorflow.core.function.trace_type import util
|
| 24 |
+
from tensorflow.python.types import trace
|
| 25 |
+
from tensorflow.python.util import custom_nest_protocol
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class InternalTracingContext(trace.TracingContext):
|
| 29 |
+
"""Container for variables and flags shared across TraceType generation."""
|
| 30 |
+
|
| 31 |
+
def __init__(self, is_legacy_signature: bool = False):
|
| 32 |
+
self._global_to_local_id = {}
|
| 33 |
+
self._alias_id_to_placeholder = {}
|
| 34 |
+
self._is_legacy_signature = is_legacy_signature
|
| 35 |
+
|
| 36 |
+
def alias_global_id(self, global_id: Hashable) -> Hashable:
|
| 37 |
+
if global_id not in self._global_to_local_id:
|
| 38 |
+
self._global_to_local_id[global_id] = len(self._global_to_local_id)
|
| 39 |
+
|
| 40 |
+
return self._global_to_local_id[global_id]
|
| 41 |
+
|
| 42 |
+
def add_placeholder(self, alias_id: Hashable, variable) -> None:
|
| 43 |
+
self._alias_id_to_placeholder[alias_id] = variable
|
| 44 |
+
|
| 45 |
+
def get_placeholder_mapping(self) -> Dict[Hashable, Any]:
|
| 46 |
+
return self._alias_id_to_placeholder
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def is_legacy_signature(self) -> bool:
|
| 50 |
+
"""If the value is from a legacy signature representation.
|
| 51 |
+
|
| 52 |
+
Legacy signature representations include tf.function.input_signature and
|
| 53 |
+
ConcreteFunction.structured_input_signature.
|
| 54 |
+
"""
|
| 55 |
+
return self._is_legacy_signature
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class InternalPlaceholderContext(trace.PlaceholderContext):
|
| 59 |
+
"""Container with mappings shared across TraceTypes for placeholder values."""
|
| 60 |
+
|
| 61 |
+
def __init__(self,
|
| 62 |
+
context_graph=None,
|
| 63 |
+
placeholder_mapping=None,
|
| 64 |
+
unnest_only=False,
|
| 65 |
+
with_none_control_dependencies=False,
|
| 66 |
+
composite_device_name=None):
|
| 67 |
+
self._alias_id_to_placeholder = placeholder_mapping or {}
|
| 68 |
+
self._naming_scope = None
|
| 69 |
+
self._context_graph = context_graph
|
| 70 |
+
self._unnest_only = unnest_only
|
| 71 |
+
self._with_none_control_dependencies = with_none_control_dependencies
|
| 72 |
+
self._composite_device_name = composite_device_name
|
| 73 |
+
|
| 74 |
+
def has_placeholder(self, alias_id: Hashable) -> bool:
|
| 75 |
+
return alias_id in self._alias_id_to_placeholder
|
| 76 |
+
|
| 77 |
+
def get_placeholder(self, alias_id: Hashable) -> Hashable:
|
| 78 |
+
if not self.has_placeholder(alias_id):
|
| 79 |
+
raise KeyError(f"alias_id: {alias_id} not found in this instance of "
|
| 80 |
+
"placeholder context.")
|
| 81 |
+
return self._alias_id_to_placeholder[alias_id]
|
| 82 |
+
|
| 83 |
+
def add_placeholder(self, alias_id: Hashable, placeholder: Hashable) -> None:
|
| 84 |
+
if alias_id in self._alias_id_to_placeholder:
|
| 85 |
+
raise KeyError(f"alias id: {alias_id} is already stored in this "
|
| 86 |
+
"instance of placeholder context.")
|
| 87 |
+
self._alias_id_to_placeholder[alias_id] = placeholder
|
| 88 |
+
|
| 89 |
+
def update_naming_scope(self, naming_scope: Optional[str]) -> None:
|
| 90 |
+
self._naming_scope = naming_scope
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def naming_scope(self) -> Optional[str]:
|
| 94 |
+
return self._naming_scope
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def context_graph(self):
|
| 98 |
+
return self._context_graph
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def unnest_only(self) -> bool:
|
| 102 |
+
return self._unnest_only
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def with_none_control_dependencies(self) -> bool:
|
| 106 |
+
return self._with_none_control_dependencies
|
| 107 |
+
|
| 108 |
+
@property
|
| 109 |
+
def composite_device_name(self) -> Any:
|
| 110 |
+
return self._composite_device_name
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class InternalCastContext(trace.CastContext):
|
| 114 |
+
"""Default casting behaviors."""
|
| 115 |
+
|
| 116 |
+
def __init__(self, allow_specs=False):
|
| 117 |
+
self._allow_specs = allow_specs
|
| 118 |
+
|
| 119 |
+
@property
|
| 120 |
+
def allow_specs(self) -> bool:
|
| 121 |
+
"""Allow TypeSpecs to be casted (instead of the actual CompositeTensors)."""
|
| 122 |
+
# Public APIs like get_concrete_function allow users to pass in specs
|
| 123 |
+
# instead which need to pass through input binding etc.
|
| 124 |
+
return self._allow_specs
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def from_value(value: Any,
|
| 128 |
+
context: trace.TracingContext = None) -> trace.TraceType:
|
| 129 |
+
"""Returns a TraceType corresponding to the value based on the context.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
value: The value to generate a TraceType for.
|
| 133 |
+
context: The TracingContext to be shared during protocol calls.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
A TraceType object representing the given value.
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
if context is None:
|
| 140 |
+
context = InternalTracingContext()
|
| 141 |
+
|
| 142 |
+
if context.is_legacy_signature and isinstance(value, trace.TraceType):
|
| 143 |
+
return value
|
| 144 |
+
elif isinstance(value, trace.SupportsTracingProtocol):
|
| 145 |
+
generated_type = value.__tf_tracing_type__(context)
|
| 146 |
+
if not isinstance(generated_type, trace.TraceType):
|
| 147 |
+
raise TypeError(
|
| 148 |
+
"Expected an instance of TraceType for Tracing Protocol call to " +
|
| 149 |
+
str(value) + " but got " + str(generated_type))
|
| 150 |
+
return generated_type
|
| 151 |
+
|
| 152 |
+
# TODO(b/183107079): Allow these once they're handled properly.
|
| 153 |
+
if isinstance(value, weakref.ref):
|
| 154 |
+
raise TypeError(
|
| 155 |
+
f"weakref input {value} not supported for tf.function."
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if hasattr(value, "__wrapped__"):
|
| 159 |
+
return from_value(value.__wrapped__, context)
|
| 160 |
+
|
| 161 |
+
if isinstance(value, list):
|
| 162 |
+
return default_types.List(*(from_value(c, context) for c in value))
|
| 163 |
+
|
| 164 |
+
if isinstance(value, tuple):
|
| 165 |
+
if util.is_namedtuple(value):
|
| 166 |
+
named_tuple_type = type(value)
|
| 167 |
+
return default_types.NamedTuple.from_type_and_attributes(
|
| 168 |
+
named_tuple_type, tuple(from_value(c, context) for c in value))
|
| 169 |
+
else:
|
| 170 |
+
return default_types.Tuple(*(from_value(c, context) for c in value))
|
| 171 |
+
|
| 172 |
+
if isinstance(value, collections.abc.Mapping):
|
| 173 |
+
mapping_type = type(value)
|
| 174 |
+
return default_types.Dict(
|
| 175 |
+
{k: from_value(value[k], context) for k in value}, mapping_type)
|
| 176 |
+
|
| 177 |
+
if util.is_attrs(value):
|
| 178 |
+
return default_types.Attrs.from_type_and_attributes(
|
| 179 |
+
type(value),
|
| 180 |
+
tuple(
|
| 181 |
+
from_value(getattr(value, a.name), context)
|
| 182 |
+
for a in value.__attrs_attrs__))
|
| 183 |
+
|
| 184 |
+
if util.is_np_ndarray(value):
|
| 185 |
+
ndarray = value.__array__()
|
| 186 |
+
return default_types.TENSOR(ndarray.shape, ndarray.dtype)
|
| 187 |
+
|
| 188 |
+
if isinstance(value, custom_nest_protocol.CustomNestProtocol):
|
| 189 |
+
metadata, components = value.__tf_flatten__()
|
| 190 |
+
return custom_nest_trace_type.CustomNestTraceType(
|
| 191 |
+
type(value), metadata, tuple(from_value(c, context) for c in components)
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
ref = weakref.ref(value)
|
| 196 |
+
if ref is None:
|
| 197 |
+
raise TypeError(
|
| 198 |
+
f"Deleted objects are not valid tf.function arguments, Got {value!r}")
|
| 199 |
+
else:
|
| 200 |
+
return default_types.Weakref(ref)
|
| 201 |
+
except TypeError:
|
| 202 |
+
try:
|
| 203 |
+
return default_types.Literal(value)
|
| 204 |
+
except:
|
| 205 |
+
raise TypeError( # pylint: disable=raise-missing-from
|
| 206 |
+
f"Could not generate a generic TraceType for {value!r}."
|
| 207 |
+
f"Please verify that it is immutable/hashable. Otheriwse, consider "
|
| 208 |
+
f"implementing the Tracing Protocol for it.")
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/function/trace_type/util.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utilities for the trace_type module."""
|
| 16 |
+
|
| 17 |
+
from typing import Any, List, Tuple
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# TODO(b/225045380): Depend on the abstracted `leaf` lib from 'nest'.
|
| 23 |
+
def is_namedtuple(obj):
|
| 24 |
+
return hasattr(obj, "_fields") and all(
|
| 25 |
+
isinstance(field, str) for field in obj._fields)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# TODO(b/225045380): Depend on the abstracted `leaf` lib from 'nest'.
|
| 29 |
+
def is_attrs(obj):
|
| 30 |
+
return hasattr(type(obj), "__attrs_attrs__")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# TODO(b/225045380): Depend on the abstracted `leaf` lib from 'nest'.
|
| 34 |
+
def is_np_ndarray(value):
|
| 35 |
+
return hasattr(value, "__array__") and not (
|
| 36 |
+
# For legacy reasons we do not automatically promote Numpy strings.
|
| 37 |
+
isinstance(value, np.str_)
|
| 38 |
+
# NumPy dtypes have __array__ as unbound methods.
|
| 39 |
+
or isinstance(value, type))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def cast_and_return_whether_casted(
|
| 43 |
+
trace_types, values, context
|
| 44 |
+
) -> Tuple[List[Any], bool]:
|
| 45 |
+
did_cast = False
|
| 46 |
+
casted_values = []
|
| 47 |
+
for t, v in zip(trace_types, values):
|
| 48 |
+
casted_v = t.cast(v, context)
|
| 49 |
+
casted_values.append(casted_v)
|
| 50 |
+
if casted_v is not v:
|
| 51 |
+
did_cast = True
|
| 52 |
+
return casted_values, did_cast
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (196 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/toco_flags_pb2.cpython-310.pyc
ADDED
|
Binary file (4.68 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/__pycache__/types_pb2.cpython-310.pyc
ADDED
|
Binary file (1.19 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (204 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/gen_html.cpython-310.pyc
ADDED
|
Binary file (8.53 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/__pycache__/toco_conversion_log_pb2.cpython-310.pyc
ADDED
|
Binary file (1.79 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/gen_html.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""A utility class to generate the report HTML based on a common template."""
|
| 16 |
+
|
| 17 |
+
import io
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
from tensorflow.lite.toco.logging import toco_conversion_log_pb2 as _toco_conversion_log_pb2
|
| 21 |
+
from tensorflow.python.lib.io import file_io as _file_io
|
| 22 |
+
from tensorflow.python.platform import resource_loader as _resource_loader
|
| 23 |
+
|
| 24 |
+
html_escape_table = {
|
| 25 |
+
"&": "&",
|
| 26 |
+
'"': """,
|
| 27 |
+
"'": "'",
|
| 28 |
+
">": ">",
|
| 29 |
+
"<": "<",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def html_escape(text):
|
| 34 |
+
return "".join(html_escape_table.get(c, c) for c in text)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_input_type_from_signature(op_signature):
|
| 38 |
+
"""Parses op_signature and returns a string denoting the input tensor type.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
op_signature: a string specifying the signature of a particular operator.
|
| 42 |
+
The signature of an operator contains the input tensor's shape and type,
|
| 43 |
+
output tensor's shape and type, operator's name and its version. It has
|
| 44 |
+
the following schema:
|
| 45 |
+
INPUT:input_1_shape::input_1_type::input_2_shape::input_2_type::..
|
| 46 |
+
::OUTPUT:output_1_shape::output_1_type::output_2_shape::output_2_type::
|
| 47 |
+
..::NAME:operator_name ::VERSION:operator_version
|
| 48 |
+
An example of an operator signature is:
|
| 49 |
+
INPUT:[1,73,73,160]::float::[64,1,1,160]::float::[64]::float::
|
| 50 |
+
OUTPUT:[1,73,73,64]::float::NAME:Conv::VERSION:1
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
A string denoting the input tensors' type. In the form of shape/type
|
| 54 |
+
separated
|
| 55 |
+
by comma. For example:
|
| 56 |
+
shape:[1,73,73,160],type:float,shape:[64,1,1,160],type:float,shape:[64],
|
| 57 |
+
type:float
|
| 58 |
+
"""
|
| 59 |
+
start = op_signature.find(":")
|
| 60 |
+
end = op_signature.find("::OUTPUT")
|
| 61 |
+
inputs = op_signature[start + 1:end]
|
| 62 |
+
lst = inputs.split("::")
|
| 63 |
+
out_str = ""
|
| 64 |
+
for i in range(len(lst)):
|
| 65 |
+
if i % 2 == 0:
|
| 66 |
+
out_str += "shape:"
|
| 67 |
+
else:
|
| 68 |
+
out_str += "type:"
|
| 69 |
+
out_str += lst[i]
|
| 70 |
+
out_str += ","
|
| 71 |
+
return out_str[:-1]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_operator_type(op_name, conversion_log):
|
| 75 |
+
if op_name in conversion_log.built_in_ops:
|
| 76 |
+
return "BUILT-IN"
|
| 77 |
+
elif op_name in conversion_log.custom_ops:
|
| 78 |
+
return "CUSTOM OP"
|
| 79 |
+
else:
|
| 80 |
+
return "SELECT OP"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class HTMLGenerator:
|
| 84 |
+
"""Utility class to generate an HTML report."""
|
| 85 |
+
|
| 86 |
+
def __init__(self, html_template_path, export_report_path):
|
| 87 |
+
"""Reads the HTML template content.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
html_template_path: A string, path to the template HTML file.
|
| 91 |
+
export_report_path: A string, path to the generated HTML report. This path
|
| 92 |
+
should point to a '.html' file with date and time in its name.
|
| 93 |
+
e.g. 2019-01-01-10:05.toco_report.html.
|
| 94 |
+
|
| 95 |
+
Raises:
|
| 96 |
+
IOError: File doesn't exist.
|
| 97 |
+
"""
|
| 98 |
+
# Load the template HTML.
|
| 99 |
+
if not _file_io.file_exists(html_template_path):
|
| 100 |
+
raise IOError("File '{0}' does not exist.".format(html_template_path))
|
| 101 |
+
with _file_io.FileIO(html_template_path, "r") as f:
|
| 102 |
+
self.html_template = f.read()
|
| 103 |
+
|
| 104 |
+
_file_io.recursive_create_dir(os.path.dirname(export_report_path))
|
| 105 |
+
self.export_report_path = export_report_path
|
| 106 |
+
|
| 107 |
+
def generate(self,
|
| 108 |
+
toco_conversion_log_before,
|
| 109 |
+
toco_conversion_log_after,
|
| 110 |
+
post_training_quant_enabled,
|
| 111 |
+
dot_before,
|
| 112 |
+
dot_after,
|
| 113 |
+
toco_err_log="",
|
| 114 |
+
tflite_graph_path=""):
|
| 115 |
+
"""Generates the HTML report and writes it to local directory.
|
| 116 |
+
|
| 117 |
+
This function uses the fields in `toco_conversion_log_before` and
|
| 118 |
+
`toco_conversion_log_after` to populate the HTML content. Certain markers
|
| 119 |
+
(placeholders) in the HTML template are then substituted with the fields
|
| 120 |
+
from the protos. Once finished it will write the HTML file to the specified
|
| 121 |
+
local file path.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
toco_conversion_log_before: A `TocoConversionLog` protobuf generated
|
| 125 |
+
before the model is converted by TOCO.
|
| 126 |
+
toco_conversion_log_after: A `TocoConversionLog` protobuf generated after
|
| 127 |
+
the model is converted by TOCO.
|
| 128 |
+
post_training_quant_enabled: A boolean, whether post-training quantization
|
| 129 |
+
is enabled.
|
| 130 |
+
dot_before: A string, the dot representation of the model
|
| 131 |
+
before the conversion.
|
| 132 |
+
dot_after: A string, the dot representation of the model after
|
| 133 |
+
the conversion.
|
| 134 |
+
toco_err_log: A string, the logs emitted by TOCO during conversion. Caller
|
| 135 |
+
need to ensure that this string is properly anonymized (any kind of
|
| 136 |
+
user data should be eliminated).
|
| 137 |
+
tflite_graph_path: A string, the filepath to the converted TFLite model.
|
| 138 |
+
|
| 139 |
+
Raises:
|
| 140 |
+
RuntimeError: When error occurs while generating the template.
|
| 141 |
+
"""
|
| 142 |
+
html_dict = {}
|
| 143 |
+
html_dict["<!--CONVERSION_STATUS-->"] = (
|
| 144 |
+
r'<span class="label label-danger">Fail</span>'
|
| 145 |
+
) if toco_err_log else r'<span class="label label-success">Success</span>'
|
| 146 |
+
html_dict["<!--TOTAL_OPS_BEFORE_CONVERT-->"] = str(
|
| 147 |
+
toco_conversion_log_before.model_size)
|
| 148 |
+
html_dict["<!--TOTAL_OPS_AFTER_CONVERT-->"] = str(
|
| 149 |
+
toco_conversion_log_after.model_size)
|
| 150 |
+
html_dict["<!--BUILT_IN_OPS_COUNT-->"] = str(
|
| 151 |
+
sum(toco_conversion_log_after.built_in_ops.values()))
|
| 152 |
+
html_dict["<!--SELECT_OPS_COUNT-->"] = str(
|
| 153 |
+
sum(toco_conversion_log_after.select_ops.values()))
|
| 154 |
+
html_dict["<!--CUSTOM_OPS_COUNT-->"] = str(
|
| 155 |
+
sum(toco_conversion_log_after.custom_ops.values()))
|
| 156 |
+
html_dict["<!--POST_TRAINING_QUANT_ENABLED-->"] = (
|
| 157 |
+
"is" if post_training_quant_enabled else "isn't")
|
| 158 |
+
|
| 159 |
+
pre_op_profile = ""
|
| 160 |
+
post_op_profile = ""
|
| 161 |
+
|
| 162 |
+
# Generate pre-conversion op profiles as a list of HTML table rows.
|
| 163 |
+
for i in range(len(toco_conversion_log_before.op_list)):
|
| 164 |
+
# Append operator name column.
|
| 165 |
+
pre_op_profile += "<tr><td>" + toco_conversion_log_before.op_list[
|
| 166 |
+
i] + "</td>"
|
| 167 |
+
# Append input type column.
|
| 168 |
+
if i < len(toco_conversion_log_before.op_signatures):
|
| 169 |
+
pre_op_profile += "<td>" + get_input_type_from_signature(
|
| 170 |
+
toco_conversion_log_before.op_signatures[i]) + "</td></tr>"
|
| 171 |
+
else:
|
| 172 |
+
pre_op_profile += "<td></td></tr>"
|
| 173 |
+
|
| 174 |
+
# Generate post-conversion op profiles as a list of HTML table rows.
|
| 175 |
+
for op in toco_conversion_log_after.op_list:
|
| 176 |
+
supported_type = get_operator_type(op, toco_conversion_log_after)
|
| 177 |
+
post_op_profile += ("<tr><td>" + op + "</td><td>" + supported_type +
|
| 178 |
+
"</td></tr>")
|
| 179 |
+
|
| 180 |
+
html_dict["<!--REPEAT_TABLE1_ROWS-->"] = pre_op_profile
|
| 181 |
+
html_dict["<!--REPEAT_TABLE2_ROWS-->"] = post_op_profile
|
| 182 |
+
html_dict["<!--DOT_BEFORE_CONVERT-->"] = dot_before
|
| 183 |
+
html_dict["<!--DOT_AFTER_CONVERT-->"] = dot_after
|
| 184 |
+
if toco_err_log:
|
| 185 |
+
html_dict["<!--TOCO_INFO_LOG-->"] = html_escape(toco_err_log)
|
| 186 |
+
else:
|
| 187 |
+
success_info = ("TFLite graph conversion successful. You can preview the "
|
| 188 |
+
"converted model at: ") + tflite_graph_path
|
| 189 |
+
html_dict["<!--TOCO_INFO_LOG-->"] = html_escape(success_info)
|
| 190 |
+
|
| 191 |
+
# Replace each marker (as keys of html_dict) with the actual text (as values
|
| 192 |
+
# of html_dict) in the HTML template string.
|
| 193 |
+
template = self.html_template
|
| 194 |
+
for marker in html_dict:
|
| 195 |
+
template = template.replace(marker, html_dict[marker], 1)
|
| 196 |
+
# Check that the marker text is replaced.
|
| 197 |
+
if template.find(marker) != -1:
|
| 198 |
+
raise RuntimeError("Could not populate marker text %r" % marker)
|
| 199 |
+
|
| 200 |
+
with _file_io.FileIO(self.export_report_path, "w") as f:
|
| 201 |
+
f.write(template)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def gen_conversion_log_html(conversion_log_dir, quantization_enabled,
|
| 205 |
+
tflite_graph_path):
|
| 206 |
+
"""Generates an HTML report about the conversion process.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
conversion_log_dir: A string specifying the file directory of the conversion
|
| 210 |
+
logs. It's required that before calling this function, the
|
| 211 |
+
`conversion_log_dir`
|
| 212 |
+
already contains the following files: `toco_log_before.pb`,
|
| 213 |
+
`toco_log_after.pb`, `toco_tf_graph.dot`,
|
| 214 |
+
`toco_tflite_graph.dot`.
|
| 215 |
+
quantization_enabled: A boolean, passed from the tflite converter to
|
| 216 |
+
indicate whether post-training quantization is enabled during conversion.
|
| 217 |
+
tflite_graph_path: A string, the filepath to the converted TFLite model.
|
| 218 |
+
|
| 219 |
+
Raises:
|
| 220 |
+
IOError: When any of the required files doesn't exist.
|
| 221 |
+
"""
|
| 222 |
+
template_filename = _resource_loader.get_path_to_datafile("template.html")
|
| 223 |
+
if not os.path.exists(template_filename):
|
| 224 |
+
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
|
| 225 |
+
template_filename))
|
| 226 |
+
|
| 227 |
+
toco_log_before_path = os.path.join(conversion_log_dir, "toco_log_before.pb")
|
| 228 |
+
toco_log_after_path = os.path.join(conversion_log_dir, "toco_log_after.pb")
|
| 229 |
+
dot_before_path = os.path.join(conversion_log_dir, "toco_tf_graph.dot")
|
| 230 |
+
dot_after_path = os.path.join(conversion_log_dir, "toco_tflite_graph.dot")
|
| 231 |
+
if not os.path.exists(toco_log_before_path):
|
| 232 |
+
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
|
| 233 |
+
toco_log_before_path))
|
| 234 |
+
if not os.path.exists(toco_log_after_path):
|
| 235 |
+
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
|
| 236 |
+
toco_log_after_path))
|
| 237 |
+
if not os.path.exists(dot_before_path):
|
| 238 |
+
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
|
| 239 |
+
dot_before_path))
|
| 240 |
+
if not os.path.exists(dot_after_path):
|
| 241 |
+
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
|
| 242 |
+
dot_after_path))
|
| 243 |
+
|
| 244 |
+
html_generator = HTMLGenerator(
|
| 245 |
+
template_filename,
|
| 246 |
+
os.path.join(conversion_log_dir, "toco_conversion_summary.html"))
|
| 247 |
+
|
| 248 |
+
# Parse the generated `TocoConversionLog`.
|
| 249 |
+
toco_conversion_log_before = _toco_conversion_log_pb2.TocoConversionLog()
|
| 250 |
+
toco_conversion_log_after = _toco_conversion_log_pb2.TocoConversionLog()
|
| 251 |
+
with open(toco_log_before_path, "rb") as f:
|
| 252 |
+
toco_conversion_log_before.ParseFromString(f.read())
|
| 253 |
+
with open(toco_log_after_path, "rb") as f:
|
| 254 |
+
toco_conversion_log_after.ParseFromString(f.read())
|
| 255 |
+
|
| 256 |
+
# Read the dot file before/after the conversion.
|
| 257 |
+
with io.open(dot_before_path, "r", encoding="utf-8") as f:
|
| 258 |
+
dot_before = f.read().rstrip()
|
| 259 |
+
with io.open(dot_after_path, "r", encoding="utf-8") as f:
|
| 260 |
+
dot_after = f.read().rstrip()
|
| 261 |
+
|
| 262 |
+
html_generator.generate(toco_conversion_log_before, toco_conversion_log_after,
|
| 263 |
+
quantization_enabled, dot_before, dot_after,
|
| 264 |
+
toco_conversion_log_after.toco_err_logs,
|
| 265 |
+
tflite_graph_path)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/logging/toco_conversion_log_pb2.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
| 3 |
+
# source: tensorflow/lite/toco/logging/toco_conversion_log.proto
|
| 4 |
+
"""Generated protocol buffer code."""
|
| 5 |
+
from google.protobuf.internal import builder as _builder
|
| 6 |
+
from google.protobuf import descriptor as _descriptor
|
| 7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
| 8 |
+
from google.protobuf import symbol_database as _symbol_database
|
| 9 |
+
# @@protoc_insertion_point(imports)
|
| 10 |
+
|
| 11 |
+
_sym_db = _symbol_database.Default()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n6tensorflow/lite/toco/logging/toco_conversion_log.proto\x12\x04toco\"\xc9\x04\n\x11TocoConversionLog\x12\x0f\n\x07op_list\x18\x01 \x03(\t\x12=\n\x0c\x62uilt_in_ops\x18\x02 \x03(\x0b\x32\'.toco.TocoConversionLog.BuiltInOpsEntry\x12:\n\ncustom_ops\x18\x03 \x03(\x0b\x32&.toco.TocoConversionLog.CustomOpsEntry\x12:\n\nselect_ops\x18\x04 \x03(\x0b\x32&.toco.TocoConversionLog.SelectOpsEntry\x12\x15\n\rop_signatures\x18\x05 \x03(\t\x12\x1a\n\x12input_tensor_types\x18\x06 \x03(\t\x12\x1b\n\x13output_tensor_types\x18\x07 \x03(\t\x12\x19\n\x11log_generation_ts\x18\x08 \x01(\x03\x12\x12\n\nmodel_size\x18\t \x01(\x05\x12\x17\n\x0ftf_lite_version\x18\n \x01(\t\x12\x12\n\nos_version\x18\x0b \x01(\t\x12\x12\n\nmodel_hash\x18\x0c \x01(\t\x12\x15\n\rtoco_err_logs\x18\r \x01(\t\x1a\x31\n\x0f\x42uiltInOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\x1a\x30\n\x0e\x43ustomOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\x1a\x30\n\x0eSelectOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01')
|
| 17 |
+
|
| 18 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
| 19 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.toco.logging.toco_conversion_log_pb2', globals())
|
| 20 |
+
if _descriptor._USE_C_DESCRIPTORS == False:
|
| 21 |
+
|
| 22 |
+
DESCRIPTOR._options = None
|
| 23 |
+
_TOCOCONVERSIONLOG_BUILTINOPSENTRY._options = None
|
| 24 |
+
_TOCOCONVERSIONLOG_BUILTINOPSENTRY._serialized_options = b'8\001'
|
| 25 |
+
_TOCOCONVERSIONLOG_CUSTOMOPSENTRY._options = None
|
| 26 |
+
_TOCOCONVERSIONLOG_CUSTOMOPSENTRY._serialized_options = b'8\001'
|
| 27 |
+
_TOCOCONVERSIONLOG_SELECTOPSENTRY._options = None
|
| 28 |
+
_TOCOCONVERSIONLOG_SELECTOPSENTRY._serialized_options = b'8\001'
|
| 29 |
+
_TOCOCONVERSIONLOG._serialized_start=65
|
| 30 |
+
_TOCOCONVERSIONLOG._serialized_end=650
|
| 31 |
+
_TOCOCONVERSIONLOG_BUILTINOPSENTRY._serialized_start=501
|
| 32 |
+
_TOCOCONVERSIONLOG_BUILTINOPSENTRY._serialized_end=550
|
| 33 |
+
_TOCOCONVERSIONLOG_CUSTOMOPSENTRY._serialized_start=552
|
| 34 |
+
_TOCOCONVERSIONLOG_CUSTOMOPSENTRY._serialized_end=600
|
| 35 |
+
_TOCOCONVERSIONLOG_SELECTOPSENTRY._serialized_start=602
|
| 36 |
+
_TOCOCONVERSIONLOG_SELECTOPSENTRY._serialized_end=650
|
| 37 |
+
# @@protoc_insertion_point(module_scope)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (203 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/__pycache__/toco_from_protos.cpython-310.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/toco/python/toco_from_protos.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Python console command to invoke TOCO from serialized protos."""
|
| 16 |
+
import argparse
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
# We need to import pywrap_tensorflow prior to the toco wrapper.
|
| 20 |
+
# pylint: disable=invalid-import-order,g-bad-import-order
|
| 21 |
+
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
|
| 22 |
+
from tensorflow.python import _pywrap_toco_api
|
| 23 |
+
from absl import app
|
| 24 |
+
|
| 25 |
+
FLAGS = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def execute(unused_args):
|
| 29 |
+
"""Runs the converter."""
|
| 30 |
+
with open(FLAGS.model_proto_file, "rb") as model_file:
|
| 31 |
+
model_str = model_file.read()
|
| 32 |
+
|
| 33 |
+
with open(FLAGS.toco_proto_file, "rb") as toco_file:
|
| 34 |
+
toco_str = toco_file.read()
|
| 35 |
+
|
| 36 |
+
with open(FLAGS.model_input_file, "rb") as input_file:
|
| 37 |
+
input_str = input_file.read()
|
| 38 |
+
|
| 39 |
+
output_str = _pywrap_toco_api.TocoConvert(
|
| 40 |
+
model_str,
|
| 41 |
+
toco_str,
|
| 42 |
+
input_str,
|
| 43 |
+
False, # extended_return
|
| 44 |
+
)
|
| 45 |
+
open(FLAGS.model_output_file, "wb").write(output_str)
|
| 46 |
+
sys.exit(0)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def main():
|
| 50 |
+
global FLAGS
|
| 51 |
+
parser = argparse.ArgumentParser(
|
| 52 |
+
description="Invoke toco using protos as input.")
|
| 53 |
+
parser.add_argument(
|
| 54 |
+
"model_proto_file",
|
| 55 |
+
type=str,
|
| 56 |
+
help="File containing serialized proto that describes the model.")
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"toco_proto_file",
|
| 59 |
+
type=str,
|
| 60 |
+
help="File containing serialized proto describing how TOCO should run.")
|
| 61 |
+
parser.add_argument(
|
| 62 |
+
"model_input_file", type=str, help="Input model is read from this file.")
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"model_output_file",
|
| 65 |
+
type=str,
|
| 66 |
+
help="Result of applying TOCO conversion is written here.")
|
| 67 |
+
|
| 68 |
+
FLAGS, unparsed = parser.parse_known_args()
|
| 69 |
+
|
| 70 |
+
app.run(main=execute, argv=[sys.argv[0]] + unparsed)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
|
| 74 |
+
main()
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (190 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/tsl/profiler/__init__.py
ADDED
|
File without changes
|