diff --git a/.gitattributes b/.gitattributes index 479a9020a7d49dc7b0115df94c8bf5ddbdb44607..b3d46e01c005ca592001dd904f0d4d9470b34ee5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -204,3 +204,5 @@ SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/pyth SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/python/analyzer_wrapper/_pywrap_analyzer_wrapper.so filter=lfs diff=lfs merge=lfs -text SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/lite/experimental/microfrontend/python/ops/_audio_microfrontend_op.so filter=lfs diff=lfs merge=lfs -text SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so filter=lfs diff=lfs merge=lfs -text +SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/_xla_ops.so filter=lfs diff=lfs merge=lfs -text +SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo_extension.so filter=lfs diff=lfs merge=lfs -text diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33b3c474cc33dda150466858661a21c85383f0a2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..661631baa952f6917f3f8c555c6485727cd7cf89 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/converter_flags_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/converter_flags_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e781e26acc7ed0dcdb94b20bca2e802deb3d1fb Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/converter_flags_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/model_flags_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/model_flags_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6afa5d1b27abcdee72bedea10ef88295e62c8da8 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/model_flags_pb2.cpython-310.pyc differ 
diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/types_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/types_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acb29ee43445e92ba5e8a3692a68b76b6c9cea35 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/__pycache__/types_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/converter_flags_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/converter_flags_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a9e3dcc285638b442e63aa3652b37f357b98c48 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/converter_flags_pb2.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: tensorflow/compiler/mlir/lite/converter_flags.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from tensorflow.compiler.mlir.lite.debug import debug_options_pb2 as tensorflow_dot_compiler_dot_mlir_dot_lite_dot_debug_dot_debug__options__pb2 +from tensorflow.compiler.mlir.lite import types_pb2 as tensorflow_dot_compiler_dot_mlir_dot_lite_dot_types__pb2 +from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__config__pb2 +from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__options__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n3tensorflow/compiler/mlir/lite/converter_flags.proto\x12\x06tflite\x1a\x37tensorflow/compiler/mlir/lite/debug/debug_options.proto\x1a)tensorflow/compiler/mlir/lite/types.proto\x1aItensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto\x1aJtensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto\"\xd3\x14\n\x0e\x43onverterFlags\x12(\n\x0cinput_format\x18\x01 \x01(\x0e\x32\x12.tflite.FileFormat\x12)\n\routput_format\x18\x02 \x01(\x0e\x32\x12.tflite.FileFormat\x12\x30\n\x14inference_input_type\x18\x0b \x01(\x0e\x32\x12.tflite.IODataType\x12*\n\x0einference_type\x18\x04 \x01(\x0e\x32\x12.tflite.IODataType\x12\x1a\n\x12\x64\x65\x66\x61ult_ranges_min\x18\x05 \x01(\x02\x12\x1a\n\x12\x64\x65\x66\x61ult_ranges_max\x18\x06 \x01(\x02\x12 \n\x18\x64\x65\x66\x61ult_int16_ranges_min\x18\x0f \x01(\x02\x12 
\n\x18\x64\x65\x66\x61ult_int16_ranges_max\x18\x10 \x01(\x02\x12\x17\n\x0f\x64rop_fake_quant\x18\x07 \x01(\x08\x12!\n\x19reorder_across_fake_quant\x18\x08 \x01(\x08\x12\x18\n\x10\x61llow_custom_ops\x18\n \x01(\x08\x12\x1f\n\x17\x64rop_control_dependency\x18\x0c \x01(\x08\x12+\n#debug_disable_recurrent_cell_fusion\x18\r \x01(\x08\x12%\n\x1dpropagate_fake_quant_num_bits\x18\x0e \x01(\x08\x12\x35\n-allow_nudging_weights_to_use_fast_gemm_kernel\x18\x11 \x01(\x08\x12\'\n\x1b\x64\x65\x64upe_array_min_size_bytes\x18\x12 \x01(\x03:\x02\x36\x34\x12&\n\x18split_tflite_lstm_inputs\x18\x13 \x01(\x08:\x04true\x12\x1f\n\x10quantize_weights\x18\x14 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x11\x64ump_graphviz_dir\x18\x18 \x01(\t\x12#\n\x1b\x64ump_graphviz_include_video\x18\x19 \x01(\x08\x12%\n\x16post_training_quantize\x18\x1a \x01(\x08:\x05\x66\x61lse\x12#\n\x14\x65nable_select_tf_ops\x18\x1b \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x66orce_select_tf_ops\x18\x1c \x01(\x08:\x05\x66\x61lse\x12\"\n\x13quantize_to_float16\x18\x1d \x01(\x08:\x05\x66\x61lse\x12#\n\x15\x61llow_dynamic_tensors\x18\x1e \x01(\x08:\x04true\x12\x1e\n\x16\x63onversion_summary_dir\x18\x1f \x01(\t\x12\x19\n\rcustom_opdefs\x18 \x03(\tB\x02\x18\x01\x12\x1a\n\x12select_user_tf_ops\x18! 
\x03(\t\x12.\n enable_tflite_resource_variables\x18\" \x01(\x08:\x04true\x12!\n\x12unfold_batchmatmul\x18# \x01(\x08:\x05\x66\x61lse\x12#\n\x15lower_tensor_list_ops\x18$ \x01(\x08:\x04true\x12-\n\x11\x61\x63\x63umulation_type\x18% \x01(\x0e\x32\x12.tflite.IODataType\x12\x1d\n\x0e\x61llow_bfloat16\x18& \x01(\x08:\x05\x66\x61lse\x12\x1f\n\x17\x61llow_all_select_tf_ops\x18\' \x01(\x08\x12*\n\x1bunfold_large_splat_constant\x18( \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x12supported_backends\x18) \x03(\t\x12\x39\n*default_to_single_batch_in_tensor_list_ops\x18* \x01(\x08:\x05\x66\x61lse\x12/\n disable_per_channel_quantization\x18+ \x01(\x08:\x05\x66\x61lse\x12\x32\n#enable_mlir_dynamic_range_quantizer\x18, \x01(\x08:\x05\x66\x61lse\x12\x1c\n\x14tf_quantization_mode\x18- \x01(\t\x12)\n\x1a\x64isable_infer_tensor_range\x18. \x01(\x08:\x05\x66\x61lse\x12&\n\x17use_fake_quant_num_bits\x18/ \x01(\x08:\x05\x66\x61lse\x12*\n\x1b\x65nable_dynamic_update_slice\x18\x30 \x01(\x08:\x05\x66\x61lse\x12!\n\x12preserve_assert_op\x18\x31 \x01(\x08:\x05\x66\x61lse\x12*\n\x1bguarantee_all_funcs_one_use\x18\x32 \x01(\x08:\x05\x66\x61lse\x12#\n\x14\x63onvert_to_stablehlo\x18\x33 \x01(\x08:\x05\x66\x61lse\x12\x30\n!enable_mlir_variable_quantization\x18\x34 \x01(\x08:\x05\x66\x61lse\x12&\n\x17\x64isable_fuse_mul_and_fc\x18\x35 \x01(\x08:\x05\x66\x61lse\x12M\n\x14quantization_options\x18\x36 \x01(\x0b\x32+.stablehlo.quantization.QuantizationOptionsB\x02\x18\x01\x12.\n\x1b\x65nable_hlo_to_tf_conversion\x18\x37 \x01(\x08:\x05\x66\x61lseB\x02\x18\x01\x12\x39\n\rdebug_options\x18\x38 \x01(\x0b\x32\".tensorflow.converter.DebugOptions\x12 \n\x11use_buffer_offset\x18\x39 \x01(\x08:\x05\x66\x61lse\x12.\n\x1flegalize_custom_tensor_list_ops\x18: \x01(\x08:\x05\x66\x61lse\x12$\n\x15reduce_type_precision\x18; \x01(\x08:\x05\x66\x61lse\x12!\n\x13qdq_conversion_mode\x18< \x01(\t:\x04NONE\x12G\n\x13quantization_config\x18= 
\x01(\x0b\x32*.stablehlo.quantization.QuantizationConfig\x12@\n1disable_per_channel_quantization_for_dense_layers\x18> \x01(\x08:\x05\x66\x61lse\x12/\n enable_composite_direct_lowering\x18? \x01(\x08:\x05\x66\x61lse\x12R\n\x16model_origin_framework\x18@ \x01(\x0e\x32+.tflite.ConverterFlags.ModelOriginFramework:\x05UNSET\x12\x32\n#canonicalizing_inf_as_min_max_float\x18\x41 \x01(\x08:\x05\x66\x61lse\x12\'\n\x18serialize_debug_metadata\x18\x42 \x01(\x08:\x05\x66\x61lse\"R\n\x14ModelOriginFramework\x12\t\n\x05UNSET\x10\x00\x12\x0e\n\nTENSORFLOW\x10\x01\x12\t\n\x05KERAS\x10\x02\x12\x07\n\x03JAX\x10\x03\x12\x0b\n\x07PYTORCH\x10\x04*\\\n\nFileFormat\x12\x17\n\x13\x46ILE_FORMAT_UNKNOWN\x10\x00\x12\x17\n\x13TENSORFLOW_GRAPHDEF\x10\x01\x12\n\n\x06TFLITE\x10\x02\x12\x10\n\x0cGRAPHVIZ_DOT\x10\x03') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.converter_flags_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _CONVERTERFLAGS.fields_by_name['custom_opdefs']._options = None + _CONVERTERFLAGS.fields_by_name['custom_opdefs']._serialized_options = b'\030\001' + _CONVERTERFLAGS.fields_by_name['quantization_options']._options = None + _CONVERTERFLAGS.fields_by_name['quantization_options']._serialized_options = b'\030\001' + _CONVERTERFLAGS.fields_by_name['enable_hlo_to_tf_conversion']._options = None + _CONVERTERFLAGS.fields_by_name['enable_hlo_to_tf_conversion']._serialized_options = b'\030\001' + _FILEFORMAT._serialized_start=2960 + _FILEFORMAT._serialized_end=3052 + _CONVERTERFLAGS._serialized_start=315 + _CONVERTERFLAGS._serialized_end=2958 + _CONVERTERFLAGS_MODELORIGINFRAMEWORK._serialized_start=2876 + _CONVERTERFLAGS_MODELORIGINFRAMEWORK._serialized_end=2958 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__init__.py 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8876d657405a017f1943d8a8c2a690ebfce5557 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/debug_options_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/debug_options_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e16a44d1bea70d0eebdaef6a2b7e9471ca4ba88a Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/__pycache__/debug_options_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/debug_options_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/debug_options_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..4553a75fe3fc7e2d810bb77b4e51a2c4d2fc6046 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/debug/debug_options_pb2.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: tensorflow/compiler/mlir/lite/debug/debug_options.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7tensorflow/compiler/mlir/lite/debug/debug_options.proto\x12\x14tensorflow.converter\"\x84\x02\n\x0c\x44\x65\x62ugOptions\x12\x15\n\x0bir_dump_dir\x18\x01 \x01(\t:\x00\x12\x1e\n\x12ir_dump_pass_regex\x18\x02 \x01(\t:\x02.*\x12\x1e\n\x12ir_dump_func_regex\x18\x03 \x01(\t:\x02.*\x12\x1c\n\renable_timing\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x0fprint_ir_before\x18\x05 \x01(\t:\x00\x12\x18\n\x0eprint_ir_after\x18\x06 \x01(\t:\x00\x12#\n\x15print_ir_module_scope\x18\x07 \x01(\x08:\x04true\x12%\n\x1d\x65lide_elementsattrs_if_larger\x18\x08 \x01(\x03') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.debug.debug_options_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _DEBUGOPTIONS._serialized_start=82 + _DEBUGOPTIONS._serialized_end=342 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/__init__.cpython-310.pyc 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d4c1dc3c1538e5721dc50aec802366c5480b62 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/converter_error_data_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/converter_error_data_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7322babb78485fd933066e4e4157126b5fa2613d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/__pycache__/converter_error_data_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/converter_error_data_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/converter_error_data_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..b26fab0b58c2e60b15ec063d23af9c9a91d1bd32 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/metrics/converter_error_data_pb2.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: tensorflow/compiler/mlir/lite/metrics/converter_error_data.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n@tensorflow/compiler/mlir/lite/metrics/converter_error_data.proto\x12\x0etflite.metrics\"\xdc\x06\n\x12\x43onverterErrorData\x12\x11\n\tcomponent\x18\x01 \x01(\t\x12\x14\n\x0csubcomponent\x18\x02 \x01(\t\x12@\n\nerror_code\x18\x03 \x01(\x0e\x32,.tflite.metrics.ConverterErrorData.ErrorCode\x12\x15\n\rerror_message\x18\x04 \x01(\t\x12=\n\x08operator\x18\x05 \x01(\x0b\x32+.tflite.metrics.ConverterErrorData.Operator\x12=\n\x08location\x18\x06 \x01(\x0b\x32+.tflite.metrics.ConverterErrorData.Location\x1a\x18\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x01(\t\x1a\x39\n\x07\x46ileLoc\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x0c\n\x04line\x18\x02 \x01(\r\x12\x0e\n\x06\x63olumn\x18\x03 \x01(\r\x1aU\n\tSourceLoc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\x06source\x18\x02 \x01(\x0b\x32*.tflite.metrics.ConverterErrorData.FileLoc\x1a\x85\x01\n\x08Location\x12=\n\x04type\x18\x01 \x01(\x0e\x32/.tflite.metrics.ConverterErrorData.LocationType\x12:\n\x04\x63\x61ll\x18\x02 \x03(\x0b\x32,.tflite.metrics.ConverterErrorData.SourceLoc\"\xc5\x01\n\tErrorCode\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x18\n\x14\x45RROR_NEEDS_FLEX_OPS\x10\x01\x12\x1a\n\x16\x45RROR_NEEDS_CUSTOM_OPS\x10\x02\x12%\n!ERROR_UNSUPPORTED_CONTROL_FLOW_V1\x10\x03\x12/\n+ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR\x10\x04\x12\x1d\n\x18\x45RROR_GPU_NOT_COMPATIBLE\x10\xc8\x01\"J\n\x0cLocationType\x12\x0e\n\nUNKNOWNLOC\x10\x00\x12\x0b\n\x07NAMELOC\x10\x01\x12\x0f\n\x0b\x43\x41LLSITELOC\x10\x02\x12\x0c\n\x08\x46USEDLOC\x10\x03') + 
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.metrics.converter_error_data_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _CONVERTERERRORDATA._serialized_start=85 + _CONVERTERERRORDATA._serialized_end=945 + _CONVERTERERRORDATA_OPERATOR._serialized_start=363 + _CONVERTERERRORDATA_OPERATOR._serialized_end=387 + _CONVERTERERRORDATA_FILELOC._serialized_start=389 + _CONVERTERERRORDATA_FILELOC._serialized_end=446 + _CONVERTERERRORDATA_SOURCELOC._serialized_start=448 + _CONVERTERERRORDATA_SOURCELOC._serialized_end=533 + _CONVERTERERRORDATA_LOCATION._serialized_start=536 + _CONVERTERERRORDATA_LOCATION._serialized_end=669 + _CONVERTERERRORDATA_ERRORCODE._serialized_start=672 + _CONVERTERERRORDATA_ERRORCODE._serialized_end=869 + _CONVERTERERRORDATA_LOCATIONTYPE._serialized_start=871 + _CONVERTERERRORDATA_LOCATIONTYPE._serialized_end=945 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/model_flags_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/model_flags_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..18e547c15c7d4fe46b6c0ce9f4e917f0e42accd8 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/model_flags_pb2.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: tensorflow/compiler/mlir/lite/model_flags.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from tensorflow.compiler.mlir.lite import types_pb2 as tensorflow_dot_compiler_dot_mlir_dot_lite_dot_types__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n/tensorflow/compiler/mlir/lite/model_flags.proto\x12\x06tflite\x1a)tensorflow/compiler/mlir/lite/types.proto\"5\n\x0fInputArrayShape\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x05\x12\x14\n\x0cunknown_rank\x18\x03 \x01(\x08\"\x93\x01\n\nInputArray\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x05shape\x18\x06 \x01(\x0b\x32\x17.tflite.InputArrayShape\x12\x12\n\nmean_value\x18\x03 \x01(\x02\x12\x14\n\tstd_value\x18\x04 \x01(\x02:\x01\x31\x12%\n\tdata_type\x18\x05 \x01(\x0e\x32\x12.tflite.IODataType\"t\n\x08RnnState\x12\x13\n\x0bstate_array\x18\x01 \x01(\t\x12\x1e\n\x16\x62\x61\x63k_edge_source_array\x18\x02 \x01(\t\x12\x13\n\x0b\x64iscardable\x18\x05 \x01(\x08\x12\x0c\n\x04size\x18\x03 \x01(\x05\x12\x10\n\x08num_dims\x18\x04 \x01(\x05\"\xf5\x01\n\x0f\x41rraysExtraInfo\x12.\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x1d.tflite.ArraysExtraInfo.Entry\x1a\xb1\x01\n\x05\x45ntry\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0bname_regexp\x18\x07 \x01(\t\x12\x0b\n\x03min\x18\x02 \x01(\x01\x12\x0b\n\x03max\x18\x03 \x01(\x01\x12%\n\tdata_type\x18\x04 \x01(\x0e\x32\x12.tflite.IODataType\x12&\n\x05shape\x18\x05 \x01(\x0b\x32\x17.tflite.InputArrayShape\x12\x1c\n\x14\x63onstant_float_value\x18\x06 \x01(\x02\"\xd0\x05\n\nModelFlags\x12(\n\x0cinput_arrays\x18\x01 \x03(\x0b\x32\x12.tflite.InputArray\x12\x15\n\routput_arrays\x18\x02 \x03(\t\x12\x1d\n\x15\x63ontrol_output_arrays\x18\x18 
\x03(\t\x12\x16\n\x0evariable_batch\x18\n \x01(\x08\x12$\n\nrnn_states\x18\x0c \x03(\x0b\x32\x10.tflite.RnnState\x12\x33\n\x0cmodel_checks\x18\x0e \x03(\x0b\x32\x1d.tflite.ModelFlags.ModelCheck\x12 \n\x18\x61llow_nonexistent_arrays\x18\x10 \x01(\x08\x12\x1d\n\x15\x61llow_nonascii_arrays\x18\x11 \x01(\x08\x12\x32\n\x11\x61rrays_extra_info\x18\x12 \x01(\x0b\x32\x17.tflite.ArraysExtraInfo\x12(\n\x1a\x63hange_concat_input_ranges\x18\x13 \x01(\x08:\x04true\x12\x17\n\x0fsaved_model_dir\x18\x14 \x01(\t\x12\x1b\n\x13saved_model_version\x18\x15 \x01(\x05\x12\x18\n\x10saved_model_tags\x18\x16 \x03(\t\x12\"\n\x1asaved_model_exported_names\x18\x17 \x03(\t\x12\x16\n\x0euse_hlo_import\x18\x19 \x01(\x08\x12\x35\n\rhlo_file_type\x18\x1a \x01(\x0e\x32\x1e.tflite.ModelFlags.HloFileType\x1aT\n\nModelCheck\x12\x18\n\ncount_type\x18\x01 \x01(\t:\x04None\x12\x15\n\tcount_min\x18\x02 \x01(\x05:\x02-1\x12\x15\n\tcount_max\x18\x03 \x01(\x05:\x02-1\"7\n\x0bHloFileType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08HLO_TEXT\x10\x01\x12\r\n\tHLO_PROTO\x10\x02') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.model_flags_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _INPUTARRAYSHAPE._serialized_start=102 + _INPUTARRAYSHAPE._serialized_end=155 + _INPUTARRAY._serialized_start=158 + _INPUTARRAY._serialized_end=305 + _RNNSTATE._serialized_start=307 + _RNNSTATE._serialized_end=423 + _ARRAYSEXTRAINFO._serialized_start=426 + _ARRAYSEXTRAINFO._serialized_end=671 + _ARRAYSEXTRAINFO_ENTRY._serialized_start=494 + _ARRAYSEXTRAINFO_ENTRY._serialized_end=671 + _MODELFLAGS._serialized_start=674 + _MODELFLAGS._serialized_end=1394 + _MODELFLAGS_MODELCHECK._serialized_start=1253 + _MODELFLAGS_MODELCHECK._serialized_end=1337 + _MODELFLAGS_HLOFILETYPE._serialized_start=1339 + _MODELFLAGS_HLOFILETYPE._serialized_end=1394 +# @@protoc_insertion_point(module_scope) diff 
--git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6484b4f6a47b276bc4b2f704821a7307c57fa868 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/wrap_converter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/wrap_converter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b677cf2217345c45336924a45d578a910c64811 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/__pycache__/wrap_converter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi new file mode 100644 index 0000000000000000000000000000000000000000..cdb1e881b7dc9fc2c070d2e7e657eb4219c8927e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi @@ 
-0,0 +1,21 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +def Convert(model_flags_proto_txt_raw: object, toco_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., enable_mlir_converter: bool = ..., quantization_py_function_library = ...) -> object: ... +def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ... +def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... +def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... +def RegisterCustomOpdefs(custom_opdefs_txt_raw: object) -> object: ... +def RetrieveCollectedErrors() -> list: ... 
diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.so b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.so new file mode 100644 index 0000000000000000000000000000000000000000..2e06e1cbd8d641ebf149c816ea084167fc6c885a Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.so differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/wrap_converter.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/wrap_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..1c198f062388fcf325d6177939aee2d9e2e660e7 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/python/wrap_converter.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps TFLite Converter interface with python lazy loader.""" +# We need to import pywrap_tensorflow prior to the converter wrapper. 
+# pylint: disable=invalid-import-order,g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +from tensorflow.compiler.mlir.lite.python import _pywrap_converter_api +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib + + +def wrapped_convert( + model_flags_str, + toco_flags_str, + input_data_str, + debug_info_str, + enable_mlir_converter, +): + """Wraps TocoConvert with lazy loader.""" + return _pywrap_converter_api.Convert( + model_flags_str, + toco_flags_str, + input_data_str, + False, # extended_return + debug_info_str, + enable_mlir_converter, + py_function_lib.PyFunctionLibrary(), + ) + + +def wrapped_experimental_mlir_quantize( + input_data_str, + disable_per_channel, + fully_quantize, + inference_type, + input_data_type, + output_data_type, + enable_numeric_verify, + enable_whole_model_verify, + denylisted_ops, + denylisted_nodes, + enable_variable_quantization, + disable_per_channel_for_dense_layers, + debug_options_str, +): + """Wraps experimental mlir quantize model.""" + return _pywrap_converter_api.ExperimentalMlirQuantizeModel( + input_data_str, + disable_per_channel, + fully_quantize, + inference_type, + input_data_type, + output_data_type, + enable_numeric_verify, + enable_whole_model_verify, + denylisted_ops, + denylisted_nodes, + enable_variable_quantization, + disable_per_channel_for_dense_layers, + debug_options_str, + ) + + +def wrapped_experimental_mlir_sparsify(input_data_str): + """Wraps experimental mlir sparsify model.""" + return _pywrap_converter_api.ExperimentalMlirSparsifyModel(input_data_str) + + +def wrapped_register_custom_opdefs(custom_opdefs_list): + """Wraps RegisterCustomOpdefs with lazy loader.""" + return _pywrap_converter_api.RegisterCustomOpdefs(custom_opdefs_list) + + +def wrapped_retrieve_collected_errors(): + """Wraps RetrieveCollectedErrors with lazy loader.""" + return _pywrap_converter_api.RetrieveCollectedErrors() + + +def 
wrapped_flat_buffer_file_to_mlir(model, input_is_filepath): + """Wraps FlatBufferFileToMlir with lazy loader.""" + return _pywrap_converter_api.FlatBufferToMlir(model, input_is_filepath) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/types_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/types_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..4d842383abddea2fcf589b9e06fea1b17d069ae5 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/lite/types_pb2.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/lite/types.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)tensorflow/compiler/mlir/lite/types.proto\x12\x06tflite*\xb3\x02\n\nIODataType\x12\x18\n\x14IO_DATA_TYPE_UNKNOWN\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\x13\n\x0fQUANTIZED_UINT8\x10\x02\x12\t\n\x05INT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06STRING\x10\x05\x12\x13\n\x0fQUANTIZED_INT16\x10\x06\x12\x08\n\x04\x42OOL\x10\x07\x12\r\n\tCOMPLEX64\x10\x08\x12\x12\n\x0eQUANTIZED_INT8\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT64\x10\x0b\x12\x0e\n\nCOMPLEX128\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\x0c\n\x08RESOURCE\x10\x0e\x12\x0b\n\x07VARIANT\x10\x0f\x12\n\n\x06UINT32\x10\x10\x12\t\n\x05UINT8\x10\x11\x12\x08\n\x04INT8\x10\x12\x12\t\n\x05INT16\x10\x13\x12\n\n\x06UINT16\x10\x14') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) 
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.types_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _IODATATYPE._serialized_start=54 + _IODATATYPE._serialized_end=361 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f124ca1043555ec13a8ee634b7d64563a95d8c73 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4dc33e4d6f6dbd48c0a6ebd3fb3379642f23a1f Binary files 
/dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_config_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_config_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7084607c0cdd89ba82f882990bd9450f38ef4951 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_config_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_options_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_options_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f30a41cc88e0e6458c17243d9781cb6b4a1e064 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/__pycache__/quantization_options_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5dfa0ef57b6cd6e6c39c8a50e2572c9476a92d --- /dev/null +++ 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_config_pb2.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nItensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto\x12\x16stablehlo.quantization\"\x1c\n\x0cTfRecordFile\x12\x0c\n\x04path\x18\x01 \x01(\t\"\x8e\x01\n\x1bRepresentativeDatasetConfig\x12\x39\n\ttf_record\x18\x01 \x01(\x0b\x32$.stablehlo.quantization.TfRecordFileH\x00\x12\x1a\n\rsignature_key\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x06\n\x04\x66ileB\x10\n\x0e_signature_key\"\xc3\x01\n\x14StaticRangePtqPreset\x12T\n\x17representative_datasets\x18\x01 \x03(\x0b\x32\x33.stablehlo.quantization.RepresentativeDatasetConfig\x12/\n#enable_per_channel_quantized_weight\x18\x02 \x01(\x08\x42\x02\x18\x01\x12$\n\x1c\x65nable_full_int_quantization\x18\x03 \x01(\x08\"\x15\n\x13WeightOnlyPtqPreset\"\"\n\x12TfSavedModelConfig\x12\x0c\n\x04tags\x18\x01 \x03(\t\"v\n\x0ePipelineConfig\x12#\n\x16unpack_quantized_types\x18\x01 \x01(\x08H\x00\x88\x01\x01\x12$\n\x1cmerge_fusion_with_dequantize\x18\x02 \x01(\x08\x42\x19\n\x17_unpack_quantized_types\"\x1f\n\x0fQuantizableUnit\x12\x0c\n\x04name\x18\x01 \x01(\t\"\x87\x01\n\x12QuantizationResult\x12\x41\n\x10quantizable_unit\x18\x01 \x01(\x0b\x32\'.stablehlo.quantization.QuantizableUnit\x12.\n\x06method\x18\x02 
\x01(\x0b\x32\x1e.stablehlo.quantization.Method\"R\n\x13QuantizationResults\x12;\n\x07results\x18\x01 \x03(\x0b\x32*.stablehlo.quantization.QuantizationResult\":\n\x12QuantizedDimension\x12\x16\n\tdimension\x18\x01 \x01(\x05H\x00\x88\x01\x01\x42\x0c\n\n_dimension\"\x0b\n\tPerTensor\"\x97\x01\n\rQuantizedType\x12\x45\n\x0f\x64imension_specs\x18\x01 \x01(\x0b\x32*.stablehlo.quantization.QuantizedDimensionH\x00\x12\x37\n\nper_tensor\x18\x02 \x01(\x0b\x32!.stablehlo.quantization.PerTensorH\x00\x42\x06\n\x04type\"\x10\n\x0eNoQuantization\"\xd3\x01\n\x0eStaticRangePtq\x12^\n\x15input_quantized_types\x18\x01 \x03(\x0b\x32?.stablehlo.quantization.StaticRangePtq.InputQuantizedTypesEntry\x1a\x61\n\x18InputQuantizedTypesEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x34\n\x05value\x18\x02 \x01(\x0b\x32%.stablehlo.quantization.QuantizedType:\x02\x38\x01\"\xd1\x01\n\rWeightOnlyPtq\x12]\n\x15input_quantized_types\x18\x01 \x03(\x0b\x32>.stablehlo.quantization.WeightOnlyPtq.InputQuantizedTypesEntry\x1a\x61\n\x18InputQuantizedTypesEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x34\n\x05value\x18\x02 \x01(\x0b\x32%.stablehlo.quantization.QuantizedType:\x02\x38\x01\"(\n\x17\x46unctionNameMatcherSpec\x12\r\n\x05regex\x18\x01 \x01(\t\"U\n\x0bMatcherSpec\x12\x46\n\rfunction_name\x18\x01 \x01(\x0b\x32/.stablehlo.quantization.FunctionNameMatcherSpec\"\xdb\x01\n\x06Method\x12\x41\n\x0fno_quantization\x18\x01 \x01(\x0b\x32&.stablehlo.quantization.NoQuantizationH\x00\x12\x42\n\x10static_range_ptq\x18\x02 \x01(\x0b\x32&.stablehlo.quantization.StaticRangePtqH\x00\x12@\n\x0fweight_only_ptq\x18\x03 \x01(\x0b\x32%.stablehlo.quantization.WeightOnlyPtqH\x00\x42\x08\n\x06method\"x\n\x10QuantizationSpec\x12\x34\n\x07matcher\x18\x01 \x01(\x0b\x32#.stablehlo.quantization.MatcherSpec\x12.\n\x06method\x18\x02 \x01(\x0b\x32\x1e.stablehlo.quantization.Method\"L\n\x11QuantizationSpecs\x12\x37\n\x05specs\x18\x01 
\x03(\x0b\x32(.stablehlo.quantization.QuantizationSpec\"\xaa\x02\n\x0e\x44\x65\x62uggerConfig\x12J\n\rdebugger_type\x18\x01 \x01(\x0e\x32\x33.stablehlo.quantization.DebuggerConfig.DebuggerType\x12#\n\x1bunquantized_dump_model_path\x18\x02 \x01(\t\x12\x14\n\x0clog_dir_path\x18\x03 \x01(\t\"\x90\x01\n\x0c\x44\x65\x62uggerType\x12\x1d\n\x19\x44\x45\x42UGGER_TYPE_UNSPECIFIED\x10\x00\x12\x1d\n\x19\x44\x45\x42UGGER_TYPE_WHOLE_MODEL\x10\x01\x12\x1f\n\x1b\x44\x45\x42UGGER_TYPE_INT_PER_LAYER\x10\x02\x12!\n\x1d\x44\x45\x42UGGER_TYPE_FLOAT_PER_LAYER\x10\x03\"\x8e\x06\n\x12\x43\x61librationOptions\x12X\n\x12\x63\x61libration_method\x18\x01 \x01(\x0e\x32<.stablehlo.quantization.CalibrationOptions.CalibrationMethod\x12`\n\x16\x63\x61libration_parameters\x18\x02 \x01(\x0b\x32@.stablehlo.quantization.CalibrationOptions.CalibrationParameters\x12T\n\x17representative_datasets\x18\x03 \x03(\x0b\x32\x33.stablehlo.quantization.RepresentativeDatasetConfig\x12\x1c\n\x14\x63\x61libration_data_dir\x18\x04 \x01(\t\x12)\n!force_regenerate_calibration_data\x18\x05 \x01(\x08\x1aY\n\x15\x43\x61librationParameters\x12\x10\n\x08num_bins\x18\x01 \x01(\x05\x12\x16\n\x0emin_percentile\x18\x02 \x01(\x02\x12\x16\n\x0emax_percentile\x18\x03 \x01(\x02\"\xc1\x02\n\x11\x43\x61librationMethod\x12\"\n\x1e\x43\x41LIBRATION_METHOD_UNSPECIFIED\x10\x00\x12\x1e\n\x1a\x43\x41LIBRATION_METHOD_MIN_MAX\x10\x01\x12&\n\"CALIBRATION_METHOD_AVERAGE_MIN_MAX\x10\x02\x12+\n\'CALIBRATION_METHOD_HISTOGRAM_PERCENTILE\x10\x03\x12/\n+CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE\x10\x04\x12\x32\n.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY\x10\x05\x12.\n*CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC\x10\x06\"\xbb\x04\n\x12QuantizationConfig\x12O\n\x17static_range_ptq_preset\x18\x01 \x01(\x0b\x32,.stablehlo.quantization.StaticRangePtqPresetH\x00\x12M\n\x16weight_only_ptq_preset\x18\x07 \x01(\x0b\x32+.stablehlo.quantization.WeightOnlyPtqPresetH\x00\x12\x42\n\x0etf_saved_model\x18\x02 
\x01(\x0b\x32*.stablehlo.quantization.TfSavedModelConfig\x12?\n\x0fpipeline_config\x18\x03 \x01(\x0b\x32&.stablehlo.quantization.PipelineConfig\x12\x38\n\x05specs\x18\x04 \x01(\x0b\x32).stablehlo.quantization.QuantizationSpecs\x12?\n\x0f\x64\x65\x62ugger_config\x18\x05 \x01(\x0b\x32&.stablehlo.quantization.DebuggerConfig\x12G\n\x13\x63\x61libration_options\x18\x06 \x01(\x0b\x32*.stablehlo.quantization.CalibrationOptions\x12\x1d\n\x10report_file_path\x18\x08 \x01(\tH\x01\x88\x01\x01\x42\x08\n\x06presetB\x13\n\x11_report_file_pathB\x03\xf8\x01\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.stablehlo.quantization_config_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\370\001\001' + _STATICRANGEPTQPRESET.fields_by_name['enable_per_channel_quantized_weight']._options = None + _STATICRANGEPTQPRESET.fields_by_name['enable_per_channel_quantized_weight']._serialized_options = b'\030\001' + _STATICRANGEPTQ_INPUTQUANTIZEDTYPESENTRY._options = None + _STATICRANGEPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_options = b'8\001' + _WEIGHTONLYPTQ_INPUTQUANTIZEDTYPESENTRY._options = None + _WEIGHTONLYPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_options = b'8\001' + _TFRECORDFILE._serialized_start=101 + _TFRECORDFILE._serialized_end=129 + _REPRESENTATIVEDATASETCONFIG._serialized_start=132 + _REPRESENTATIVEDATASETCONFIG._serialized_end=274 + _STATICRANGEPTQPRESET._serialized_start=277 + _STATICRANGEPTQPRESET._serialized_end=472 + _WEIGHTONLYPTQPRESET._serialized_start=474 + _WEIGHTONLYPTQPRESET._serialized_end=495 + _TFSAVEDMODELCONFIG._serialized_start=497 + _TFSAVEDMODELCONFIG._serialized_end=531 + _PIPELINECONFIG._serialized_start=533 + _PIPELINECONFIG._serialized_end=651 + _QUANTIZABLEUNIT._serialized_start=653 + _QUANTIZABLEUNIT._serialized_end=684 + 
_QUANTIZATIONRESULT._serialized_start=687 + _QUANTIZATIONRESULT._serialized_end=822 + _QUANTIZATIONRESULTS._serialized_start=824 + _QUANTIZATIONRESULTS._serialized_end=906 + _QUANTIZEDDIMENSION._serialized_start=908 + _QUANTIZEDDIMENSION._serialized_end=966 + _PERTENSOR._serialized_start=968 + _PERTENSOR._serialized_end=979 + _QUANTIZEDTYPE._serialized_start=982 + _QUANTIZEDTYPE._serialized_end=1133 + _NOQUANTIZATION._serialized_start=1135 + _NOQUANTIZATION._serialized_end=1151 + _STATICRANGEPTQ._serialized_start=1154 + _STATICRANGEPTQ._serialized_end=1365 + _STATICRANGEPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_start=1268 + _STATICRANGEPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_end=1365 + _WEIGHTONLYPTQ._serialized_start=1368 + _WEIGHTONLYPTQ._serialized_end=1577 + _WEIGHTONLYPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_start=1268 + _WEIGHTONLYPTQ_INPUTQUANTIZEDTYPESENTRY._serialized_end=1365 + _FUNCTIONNAMEMATCHERSPEC._serialized_start=1579 + _FUNCTIONNAMEMATCHERSPEC._serialized_end=1619 + _MATCHERSPEC._serialized_start=1621 + _MATCHERSPEC._serialized_end=1706 + _METHOD._serialized_start=1709 + _METHOD._serialized_end=1928 + _QUANTIZATIONSPEC._serialized_start=1930 + _QUANTIZATIONSPEC._serialized_end=2050 + _QUANTIZATIONSPECS._serialized_start=2052 + _QUANTIZATIONSPECS._serialized_end=2128 + _DEBUGGERCONFIG._serialized_start=2131 + _DEBUGGERCONFIG._serialized_end=2429 + _DEBUGGERCONFIG_DEBUGGERTYPE._serialized_start=2285 + _DEBUGGERCONFIG_DEBUGGERTYPE._serialized_end=2429 + _CALIBRATIONOPTIONS._serialized_start=2432 + _CALIBRATIONOPTIONS._serialized_end=3214 + _CALIBRATIONOPTIONS_CALIBRATIONPARAMETERS._serialized_start=2801 + _CALIBRATIONOPTIONS_CALIBRATIONPARAMETERS._serialized_end=2890 + _CALIBRATIONOPTIONS_CALIBRATIONMETHOD._serialized_start=2893 + _CALIBRATIONOPTIONS_CALIBRATIONMETHOD._serialized_end=3214 + _QUANTIZATIONCONFIG._serialized_start=3217 + _QUANTIZATIONCONFIG._serialized_end=3788 +# @@protoc_insertion_point(module_scope) diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..e108e96ca0887457e073825d5d1b58ab60e8f3b6 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/stablehlo/quantization_options_pb2.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nJtensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto\x12\x16stablehlo.quantization\"^\n\x13QuantizationOptions\x12G\n\x13quantization_method\x18\x01 \x01(\x0b\x32*.stablehlo.quantization.QuantizationMethod\"\xdb\x01\n\x12QuantizationMethod\x12V\n\x1apreset_quantization_method\x18\x01 \x01(\x0b\x32\x30.stablehlo.quantization.PresetQuantizationMethodH\x00\x12V\n\x1a\x63ustom_quantization_method\x18\x02 \x01(\x0b\x32\x30.stablehlo.quantization.CustomQuantizationMethodH\x00\x42\x15\n\x13quantization_method\"\x92\x02\n\x18PresetQuantizationMethod\x12T\n\rpreset_method\x18\x01 
\x01(\x0e\x32=.stablehlo.quantization.PresetQuantizationMethod.PresetMethod\"\x9f\x01\n\x0cPresetMethod\x12\x16\n\x12METHOD_UNSPECIFIED\x10\x00\x12\x0f\n\x0bWEIGHT_ONLY\x10\x01\x12,\n(POST_TRAINING_QUANTIZATION_DYNAMIC_RANGE\x10\x02\x12\x0b\n\x07\x46LOAT16\x10\x03\x12+\n\'POST_TRAINING_QUANTIZATION_STATIC_RANGE\x10\x04\"r\n\x18\x43ustomQuantizationMethod\x12V\n\x1bquantization_component_spec\x18\x01 \x03(\x0b\x32\x31.stablehlo.quantization.QuantizationComponentSpec\"\xc5\x05\n\x19QuantizationComponentSpec\x12g\n\x16quantization_component\x18\x01 \x01(\x0e\x32G.stablehlo.quantization.QuantizationComponentSpec.QuantizationComponent\x12M\n\tbit_width\x18\x02 \x01(\x0e\x32:.stablehlo.quantization.QuantizationComponentSpec.BitWidth\x12K\n\x08\x62it_type\x18\x03 \x01(\x0e\x32\x39.stablehlo.quantization.QuantizationComponentSpec.BitType\x12\x1b\n\x13\x65nable_narrow_range\x18\x04 \x01(\x08\x12\'\n\x1f\x65nable_per_channel_quantization\x18\x05 \x01(\x08\x12\x18\n\x10\x65nable_symmetric\x18\x06 \x01(\x08\"v\n\x15QuantizationComponent\x12\x19\n\x15\x43OMPONENT_UNSPECIFIED\x10\x00\x12\x18\n\x14\x43OMPONENT_ACTIVATION\x10\x01\x12\x14\n\x10\x43OMPONENT_WEIGHT\x10\x02\x12\x12\n\x0e\x43OMPONENT_BIAS\x10\x03\"k\n\x08\x42itWidth\x12\x19\n\x15\x42IT_WIDTH_UNSPECIFIED\x10\x00\x12\x0f\n\x0b\x42IT_WIDTH_4\x10\x01\x12\x0f\n\x0b\x42IT_WIDTH_8\x10\x02\x12\x10\n\x0c\x42IT_WIDTH_16\x10\x03\x12\x10\n\x0c\x42IT_WIDTH_32\x10\x04\"^\n\x07\x42itType\x12\x18\n\x14\x42IT_TYPE_UNSPECIFIED\x10\x00\x12\x10\n\x0c\x42IT_TYPE_INT\x10\x01\x12\x12\n\x0e\x42IT_TYPE_FLOAT\x10\x02\x12\x13\n\x0f\x42IT_TYPE_BFLOAT\x10\x03\x42\x03\xf8\x01\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.stablehlo.quantization_options_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\370\001\001' + 
_QUANTIZATIONOPTIONS._serialized_start=102 + _QUANTIZATIONOPTIONS._serialized_end=196 + _QUANTIZATIONMETHOD._serialized_start=199 + _QUANTIZATIONMETHOD._serialized_end=418 + _PRESETQUANTIZATIONMETHOD._serialized_start=421 + _PRESETQUANTIZATIONMETHOD._serialized_end=695 + _PRESETQUANTIZATIONMETHOD_PRESETMETHOD._serialized_start=536 + _PRESETQUANTIZATIONMETHOD_PRESETMETHOD._serialized_end=695 + _CUSTOMQUANTIZATIONMETHOD._serialized_start=697 + _CUSTOMQUANTIZATIONMETHOD._serialized_end=811 + _QUANTIZATIONCOMPONENTSPEC._serialized_start=814 + _QUANTIZATIONCOMPONENTSPEC._serialized_end=1523 + _QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT._serialized_start=1200 + _QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT._serialized_end=1318 + _QUANTIZATIONCOMPONENTSPEC_BITWIDTH._serialized_start=1320 + _QUANTIZATIONCOMPONENTSPEC_BITWIDTH._serialized_end=1427 + _QUANTIZATIONCOMPONENTSPEC_BITTYPE._serialized_start=1429 + _QUANTIZATIONCOMPONENTSPEC_BITTYPE._serialized_end=1523 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46dda5f302e05da2b074ed13afe6d6eccb2a2e61 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/exported_model_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/exported_model_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cd18aca9e323e8c45f8d349d80d477a4025b903 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/exported_model_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/quantization_options_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/quantization_options_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..200f67f8161aca3a8c93fc7085f8a09ce0b7ad0b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/__pycache__/quantization_options_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/__init__.cpython-310.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..bb343210fca6451c46791f2c37b4a4774538a6e1 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_algorithm.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_algorithm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c931bce8137569541a2889b3d04bb0e6d67cbd5 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_algorithm.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_statistics_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_statistics_pb2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a07231ed7e0e46ec145f4f1b73b1aaeb5f32429e Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/__pycache__/calibration_statistics_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_algorithm.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_algorithm.py new file mode 100644 index 
0000000000000000000000000000000000000000..23ec80c65b4d7d60af5199bf4f439cb1fca5a5f9 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_algorithm.py @@ -0,0 +1,395 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines CalibrationAlgorithm for calculating min and max values calculated by calibration method.""" +import abc +import itertools +import logging + +import numpy as np + +from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stats_pb2 + + +_CalibrationMethod = ( + stablehlo_quant_config_pb2.CalibrationOptions.CalibrationMethod +) +_REGISTRY = {} + + +def _implements(calib_method: _CalibrationMethod): + def decorator(cls): + assert calib_method not in _REGISTRY + _REGISTRY[calib_method] = cls + return cls + + return decorator + + +class _CalibrationAlgorithmBase(abc.ABC): + """Abstract base class for calibration algorithm.""" + + def __init__( + self, + statistics: calib_stats_pb2.CalibrationStatistics, + calib_opts: stablehlo_quant_config_pb2.CalibrationOptions, + ): + self._statistics = statistics + self._calib_opts = calib_opts + + 
@abc.abstractmethod + def get_min_max_value(self) -> tuple[float, float]: + pass + + +class _HistogramCalibrationAlgorithmBase(_CalibrationAlgorithmBase): + """Base class for histogram calibrators.""" + + def __init__( + self, + statistics: calib_stats_pb2.CalibrationStatistics, + calib_opts: stablehlo_quant_config_pb2.CalibrationOptions, + ): + """Builds histogram using statistics.histogram_statistics. + + lower_bound hist_mid + v v + |=========|=========|=========|=========|=========| + bin width + + Args: + statistics: Collected calibration statistics. + calib_opts: Calibration options used for calculating min and max. + """ + super().__init__(statistics, calib_opts) + hist_stats = statistics.histogram_statistics + self._bin_width = hist_stats.bin_width + self._lower_bound = hist_stats.lower_bound + self._hist_freq = np.array(hist_stats.hist_freq) + self._num_bins = len(self._hist_freq) + self._num_bits = 8 + # i-th bin has a range [bins[i], bins[i + 1]). + # bins[i] = lower_bound + i * bin_width + # bins[i + 1] = lower_bound + (i + 1) * bin_width + # So hist_mids[i] = (lower_bound + bin_width / 2) + bin_width * i + first_mid = self._lower_bound + self._bin_width / 2 + last_mid = first_mid + (self._num_bins - 1) * self._bin_width + self._hist_mids = np.linspace(first_mid, last_mid, self._num_bins) + + def _get_dequantized_hist_mids_after_quantize( + self, quant_min: float, quant_max: float + ) -> np.ndarray: + """Quantizes and dequantizes hist_mids using quant_min and quant_max. + + Quantization converts the range of numbers from [quant_min, quant_max] to + [0, 2^num_bits - 1]. Values less than quant_min are converted to 0, and + values greater than quant_max are converted to 2^num_bits - 1. + + The histogram represents the distribution of the data, and our goal is to + find the quant_min and quant_max that best describe this distribution. To do + this, we quantize hist_mids using quant_min and quant_max and dequantize + them again. 
Then the difference between hist_mids and dequantized hist_mids + equates to quantization error when using quant_min and quant_max. + + + Args: + quant_min: The minimum real value that can be represented by a quantized + value. + quant_max: The maximum real value that can be represented by a quantized + value. + + Returns: + dequantized hist_mids after quantizing by quant_min and quant_max + """ + maxbound = 2**self._num_bits - 1 + minbound = 0 + scale = (quant_max - quant_min) / maxbound + zero_point = -quant_min / scale + + # Limit the range of zero_point and scale in case (quant_max - quant_min) + # is unusually small. + if abs(zero_point) > 9e9: + zero_point = 9e9 + if abs(scale) < 1e-9: + scale = 1e-9 + + zero_point = round(zero_point) + quantized_hist_mids = np.clip( + np.round(self._hist_mids / scale) + zero_point, minbound, maxbound + ) + dequantized_hist_mids = scale * (quantized_hist_mids - zero_point) + return dequantized_hist_mids + + def _get_weighted_mean_squared_error( + self, quant_min, quant_max + ) -> tuple[float, float, float]: + """Gets mean squared error between hist_mids and dequantized hist_mids. + + Quantization converts the range of numbers from [quant_min, quant_max] to + [0, 2^num_bits - 1]. Values less than quant_min are converted to 0, and + values greater than quant_max are converted to 2^num_bits - 1. + + Args: + quant_min: The minimum real value that can be represented by a quantized + value. + quant_max: The maximum real value that can be represented by a quantized + value. + + Returns: + (error, quant_min, quant_max): Tuple of weighted mean squared error. 
+ error = (hist_mids - dequantized_hist_mids)**2 * hist_freq + """ + dequantized_hist_mids = self._get_dequantized_hist_mids_after_quantize( + quant_min, quant_max + ) + squared_error = (self._hist_mids - dequantized_hist_mids) ** 2 + weighted_error = np.sum(squared_error * self._hist_freq) + return (weighted_error, quant_min, quant_max) + + def _get_min_max_value_by_expanding_range( + self, start_idx: int + ) -> tuple[float, float]: + """Starting from start_idx, expand left and right alternately to find the min value of mse loss. + + Args: + start_idx: Index to start quantization. + + Returns: + (min_value, max_value): Min and max calculated. + """ + # Tuple of (mse_error, quant_min, quant_max). + mse_min = (float('inf'), float('inf'), float('inf')) + left, right = start_idx, start_idx + + # If this value is true, it moves left, otherwise it moves right. + move_left = True + while not (left == 0 and right == self._num_bins - 1): + # Decrease left if right can't be moved or move_left is true. + if (move_left and left > 0) or (right == self._num_bins - 1): + left = max(left - 1, 0) + # Else increase right. + else: + right = min(right + 1, self._num_bins - 1) + # Toogle the move_left. + move_left = not move_left + quant_min, quant_max = self._hist_mids[left], self._hist_mids[right] + mse_tuple = self._get_weighted_mean_squared_error(quant_min, quant_max) + mse_min = min(mse_tuple, mse_min) + # Extract (quant_min, quant_max) from (mse_error, quant_min, quant_max). + min_value, max_value = mse_min[1], mse_min[2] + return min_value, max_value + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX) +class _MinMax(_CalibrationAlgorithmBase): + """MinMaxCalibrationAlgorithm for calculating min and max values of calibration result. + + MinMax calibration calculates the global min and global max values. 
+ + global min = min of given sample inputs + global max = max of given sample inputs + """ + + def get_min_max_value(self) -> tuple[float, float]: + """Calculates the global min and max values. + + Returns: + (min_value, max_value): Min and max calculated using MinMax + """ + return ( + self._statistics.min_max_statistics.global_min, + self._statistics.min_max_statistics.global_max, + ) + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX) +class _AverageMinMax(_CalibrationAlgorithmBase): + """AverageMinMaxCalibrationAlgorithm for calculating min and max values of calibration result. + + AverageMinMax calibration calculates the average of min and max values. + average of min = sum of min values / number of samples + average of max = sum of max values / number of samples + """ + + def get_min_max_value(self) -> tuple[float, float]: + """Calculates the average of min and max values. + + Returns: + (min_value, max_value): Min and max calculated using AverageMinMax + + Raises: + ValueError: num_samples is 0. + """ + average_min_max_statistics = self._statistics.average_min_max_statistics + # num_samples is guaranteed to be larger than 0 because + # get_statistics_from_calibrator throws an exception if num_samples == 0. + num_samples = average_min_max_statistics.num_samples + if num_samples == 0: + raise ValueError( + 'num_samples must not be 0 when calibration method is' + f' AverageMinMax: {self._calib_opts}' + ) + min_value, max_value = ( + average_min_max_statistics.min_sum / num_samples, + average_min_max_statistics.max_sum / num_samples, + ) + + return min_value, max_value + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE) +class _HistogramPercentile(_HistogramCalibrationAlgorithmBase): + """HistogramPercentile for calculating min and max values of calibration result.""" + + def get_min_max_value(self) -> tuple[float, float]: + """Calculates min and max from statistics using calibration options. 
+ + A "percentile" is a statistical concept that represents the value below + which a given percentage of data falls in a dataset. It involves sorting the + data from smallest to largest and then finding the value at a specified + percentage position. For example, the 0.01 percentile represents the value + in a given data set that corresponds to the lowest 0.01% of the data. + + HistogramPercentile calibration uses min_percentile and max_percentile to + find min and max. + + min_percentile and max_percentile must be in range [0, 100]. + min_percentile is 0.001 by default. + max_percentile is 99.999 by default. + + Returns: + (min_value, max_value): Min and max calculated using HistogramPercentile + """ + total_freq = sum(self._hist_freq) + # hist_freq_cumsum is dividing cumulative sum of hist_freq by total_freq + # hist_freq_cumsum's value is in range [0, 1] by its definition + hist_freq_cumsum = np.cumsum(self._hist_freq) / total_freq + + # min_percentile and max_percentile are converted from [0, 100] to [0, 1]. + min_quantile, max_quantile = ( + self._calib_opts.calibration_parameters.min_percentile / 100.0, + self._calib_opts.calibration_parameters.max_percentile / 100.0, + ) + + # Get index of min/max quantile. + min_quantile_idx, max_quantile_idx = ( + np.searchsorted(hist_freq_cumsum, min_quantile, side='right'), + np.searchsorted(hist_freq_cumsum, max_quantile, side='left'), + ) + + # Get value of min/max quantile index. + min_value, max_value = ( + self._hist_mids[min_quantile_idx], + self._hist_mids[max_quantile_idx], + ) + + return min_value, max_value + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE) +class _HistogramMseBruteforce(_HistogramCalibrationAlgorithmBase): + """HistogramMseBruteforce for calculating min and max values of calibration result.""" + + def get_min_max_value(self) -> tuple[float, float]: + """Finds the optimal quant_min and quant_max by testing all possible cases. 
+ + It guarantees optimal quant_min and quant_max for the representative + dataset, but not for the test dataset. + + Returns: + (min_value, max_value): Min and max calculated using + HistogramMseBruteforce. + """ + if self._num_bins > 512: + logging.warning( + 'num_bins=%d is too large. The HISTOGRAM_MSE_BRUTEFORCE method tests' + ' all histogram mid value pairs, so it may take a long time.', + self._num_bins, + ) + # Tuple of (mse_error, quant_min, quant_max). + mse_min = (float('inf'), float('inf'), float('inf')) + + # Calculate the error for all hist_mid pairs. + for left, right in itertools.combinations(range(self._num_bins), 2): + quant_min, quant_max = self._hist_mids[left], self._hist_mids[right] + mse_tuple = self._get_weighted_mean_squared_error(quant_min, quant_max) + mse_min = min(mse_tuple, mse_min) + min_value, max_value = mse_min[1], mse_min[2] + + return min_value, max_value + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY) +class _HistogramMseMaxFrequency(_HistogramCalibrationAlgorithmBase): + """HistogramMseMaxFrequency for calculating min and max values of calibration result.""" + + def get_min_max_value(self) -> tuple[float, float]: + """Finds min and max starting from the index of the max frequency. + + The HistogramMseMaxFrequency method starts from the bin with the highest + frequency and expands the range to both sides. This performs well when data + is well spread on both sides of the max frequency. + + Returns: + (min_value, max_value): Min and max calculated using method to expand the + range based on max frequency. + """ + # Find the index of max frequency. 
+ freq_max_idx = np.argmax(self._hist_freq) + return self._get_min_max_value_by_expanding_range(freq_max_idx) + + +@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC) +class _HistogramMseSymmetric(_HistogramCalibrationAlgorithmBase): + """HistogramMseSymmetric for calculating min and max values of calibration result.""" + + def get_min_max_value(self) -> tuple[float, float]: + """Finds min and max starting from the center index. + + The HistogramMseSymmetric method starts from the center bin and expands the + range to both sides. This works better when the data is well-centered. + + Returns: + (min_value, max_value): Min and max calculated using the method starting + from center and expanding. + """ + + # This function is currently only called in this method, but will be used in + # other methods in the future. + return self._get_min_max_value_by_expanding_range(self._num_bins // 2) + + +def get_min_max_value( + statistics: calib_stats_pb2.CalibrationStatistics, + calib_opts: stablehlo_quant_config_pb2.CalibrationOptions, +) -> tuple[float, float]: + """Calculates min and max from statistics using calibration options. + + Args: + statistics: Collected calibration statistics. + calib_opts: Calibration options used for calculating min and max. + + Returns: + (min_value, max_value): Min and max calculated using calib_opts. + + Raises: + ValueError: Unsupported calibration method is given. 
+ """ + calib_method = calib_opts.calibration_method + if calib_method not in _REGISTRY: + raise ValueError(f'Unsupported calibration method: {calib_method}') + + calibration_algorithm = _REGISTRY[calib_method](statistics, calib_opts) + return calibration_algorithm.get_min_max_value() diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..96926c33e2ff61ca8f30aba8ae61915496ffecf1 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nXtensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto\x12\x15tensorflow.calibrator\"\x9c\x04\n\x15\x43\x61librationStatistics\x12Y\n\x12min_max_statistics\x18\x01 \x01(\x0b\x32=.tensorflow.calibrator.CalibrationStatistics.MinMaxStatistics\x12h\n\x1a\x61verage_min_max_statistics\x18\x02 \x01(\x0b\x32\x44.tensorflow.calibrator.CalibrationStatistics.AverageMinMaxStatistics\x12^\n\x14histogram_statistics\x18\x03 
\x01(\x0b\x32@.tensorflow.calibrator.CalibrationStatistics.HistogramStatistics\x1a:\n\x10MinMaxStatistics\x12\x12\n\nglobal_min\x18\x01 \x01(\x02\x12\x12\n\nglobal_max\x18\x02 \x01(\x02\x1aP\n\x17\x41verageMinMaxStatistics\x12\x0f\n\x07min_sum\x18\x01 \x01(\x02\x12\x0f\n\x07max_sum\x18\x02 \x01(\x02\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x1aP\n\x13HistogramStatistics\x12\x11\n\tbin_width\x18\x01 \x01(\x02\x12\x13\n\x0blower_bound\x18\x02 \x01(\x02\x12\x11\n\thist_freq\x18\x03 \x03(\x02\"\xd0\x01\n\x18\x43\x61librationStatisticsMap\x12S\n\nstatistics\x18\x01 \x03(\x0b\x32?.tensorflow.calibrator.CalibrationStatisticsMap.StatisticsEntry\x1a_\n\x0fStatisticsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12;\n\x05value\x18\x02 \x01(\x0b\x32,.tensorflow.calibrator.CalibrationStatistics:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.calibrator.calibration_statistics_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\370\001\001' + _CALIBRATIONSTATISTICSMAP_STATISTICSENTRY._options = None + _CALIBRATIONSTATISTICSMAP_STATISTICSENTRY._serialized_options = b'8\001' + _CALIBRATIONSTATISTICS._serialized_start=116 + _CALIBRATIONSTATISTICS._serialized_end=656 + _CALIBRATIONSTATISTICS_MINMAXSTATISTICS._serialized_start=434 + _CALIBRATIONSTATISTICS_MINMAXSTATISTICS._serialized_end=492 + _CALIBRATIONSTATISTICS_AVERAGEMINMAXSTATISTICS._serialized_start=494 + _CALIBRATIONSTATISTICS_AVERAGEMINMAXSTATISTICS._serialized_end=574 + _CALIBRATIONSTATISTICS_HISTOGRAMSTATISTICS._serialized_start=576 + _CALIBRATIONSTATISTICS_HISTOGRAMSTATISTICS._serialized_end=656 + _CALIBRATIONSTATISTICSMAP._serialized_start=659 + _CALIBRATIONSTATISTICSMAP._serialized_end=867 + _CALIBRATIONSTATISTICSMAP_STATISTICSENTRY._serialized_start=772 + 
_CALIBRATIONSTATISTICSMAP_STATISTICSENTRY._serialized_end=867 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/exported_model_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/exported_model_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..45d3f9c0b26523fa356814d7ca5c35c77bcf71f8 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/exported_model_pb2.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from tensorflow.core.framework import graph_pb2 as tensorflow_dot_core_dot_framework_dot_graph__pb2 +from tensorflow.core.protobuf import meta_graph_pb2 as tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2 +from tensorflow.core.protobuf import saver_pb2 as tensorflow_dot_core_dot_protobuf_dot_saver__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nEtensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto\x12\x17tensorflow.quantization\x1a%tensorflow/core/framework/graph.proto\x1a)tensorflow/core/protobuf/meta_graph.proto\x1a$tensorflow/core/protobuf/saver.proto\"\xbe\x03\n\rExportedModel\x12\'\n\tgraph_def\x18\x01 \x01(\x0b\x32\x14.tensorflow.GraphDef\x12\x16\n\x0einit_node_name\x18\x02 \x01(\t\x12\x16\n\x0e\x63heckpoint_dir\x18\x05 
\x01(\t\x12U\n\x10\x66unction_aliases\x18\x06 \x03(\x0b\x32;.tensorflow.quantization.ExportedModel.FunctionAliasesEntry\x12\x31\n\x0f\x61sset_file_defs\x18\x08 \x03(\x0b\x32\x18.tensorflow.AssetFileDef\x12\'\n\tsaver_def\x18\n \x01(\x0b\x32\x14.tensorflow.SaverDef\x1a\x36\n\x14\x46unctionAliasesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x07\x10\x08J\x04\x08\t\x10\nR\x15variable_shared_namesR\x11restore_node_nameR\x0esave_node_nameR\x17\x66ile_prefix_tensor_nameB\x03\xf8\x01\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.exported_model_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\370\001\001' + _EXPORTEDMODEL_FUNCTIONALIASESENTRY._options = None + _EXPORTEDMODEL_FUNCTIONALIASESENTRY._serialized_options = b'8\001' + _EXPORTEDMODEL._serialized_start=219 + _EXPORTEDMODEL._serialized_end=665 + _EXPORTEDMODEL_FUNCTIONALIASESENTRY._serialized_start=504 + _EXPORTEDMODEL_FUNCTIONALIASESENTRY._serialized_end=558 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ff93e25652793890197d8f5695ce438dd2079b2e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_quantize_model.pyi @@ -0,0 +1,72 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import Any + +from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd + +# LINT.IfChange(quantize_qat_model) +def quantize_qat_model( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... # Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_ptq_dynamic_range) +def quantize_ptq_dynamic_range( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... # Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_weight_only) +def quantize_weight_only( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options_serialized: bytes, + *, + signature_def_map_serialized: dict[str, bytes], + py_function_library: py_function_lib.PyFunctionLibrary, +) -> Any: ... 
# Status + +# LINT.ThenChange() + +# LINT.IfChange(quantize_ptq_static_range) +def quantize_ptq_static_range( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options_serialized: bytes, + *, + signature_keys: list[str], + signature_def_map_serialized: dict[str, bytes], + py_function_library: py_function_lib.PyFunctionLibrary, + # Value type: RepresentativeDatasetFile. + representative_dataset_file_map_serialized: dict[str, bytes], +) -> Any: ... # Status + +# LINT.ThenChange() diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f7dec2d2a5dee753ccfc79d3ffd35f9cda92a9e3 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.py @@ -0,0 +1,926 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Defines TF Quantization API from SavedModel to SavedModel.""" + +import tempfile +from typing import Mapping, Optional + +from absl import logging + +from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2 +from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib +from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model +from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset +from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import load as saved_model_load +from tensorflow.python.saved_model import loader_impl as saved_model_loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.trackable import autotrackable +from tensorflow.python.util import tf_export + +# Type aliases for quant_opts_pb2 messages. 
+_QuantizationOptions = tf_export.tf_export( + 'quantization.experimental.QuantizationOptions' +)(quant_opts_pb2.QuantizationOptions) + +_QuantizationMethod = tf_export.tf_export( + 'quantization.experimental.QuantizationMethod' +)(quant_opts_pb2.QuantizationMethod) + +_QuantizationComponentSpec = tf_export.tf_export( + 'quantization.experimental.QuantizationComponentSpec' +)(quant_opts_pb2.QuantizationComponentSpec) + +_UnitWiseQuantizationSpec = tf_export.tf_export( + 'quantization.experimental.UnitWiseQuantizationSpec' +)(quant_opts_pb2.UnitWiseQuantizationSpec) + +_PresetMethod = _QuantizationMethod.PresetMethod +_CalibrationMethod = ( + stablehlo_quant_config_pb2.CalibrationOptions.CalibrationMethod +) + +_QuantizationComponent = _QuantizationComponentSpec.QuantizationComponent +_TensorType = _QuantizationComponentSpec.TensorType + +_RepresentativeDatasetFile = quant_opts_pb2.RepresentativeDatasetFile + +# Mapping of signature def key -> SignatureDef. +_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] + +# Default minimum number of elements in the weights for them to be quantized +# during dynamic range quantization (DRQ) and weight-only quantization. +_DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS = 1024 + + +def _is_qat_saved_model(saved_model_path: str): + """Checks if the SavedModel is QAT-enabled by looking for 'FakeQuant' ops.""" + saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path) + for meta_graph in saved_model_proto.meta_graphs: + if any( + node.op.startswith('FakeQuant') for node in meta_graph.graph_def.node + ): + return True + for function in meta_graph.graph_def.library.function: + if any(node.op.startswith('FakeQuant') for node in function.node_def): + return True + return False + + +def _serialize_signature_def_map( + signature_def_map: _SignatureDefMap, +) -> dict[str, bytes]: + """Serializes SignatureDef values in `signature_def_map`. + + Args: + signature_def_map: Signature key -> SignatureDef mapping. 
+ + Returns: + Signature def map where the values (`SignatureDef`) are serialized. + """ + signature_def_map_serialized = {} + for key, signature_def in signature_def_map.items(): + signature_def_map_serialized[key] = signature_def.SerializeToString() + + return signature_def_map_serialized + + +def _save_representative_dataset( + representative_dataset: repr_dataset.RepresentativeDatasetOrMapping, + signature_def_map: _SignatureDefMap, +) -> Mapping[str, _RepresentativeDatasetFile]: + """Saves the representative dataset to temporary TFRecord files. + + Args: + representative_dataset: Representative dataset used for the calibration + step. Representative datasets should exist for each signature def key in + `signature_def_keys`. + signature_def_map: Signature def key -> SignatureDef mapping. + + Returns: + A map from signature key to the saved representative dataset file. + """ + if isinstance(representative_dataset, Mapping): + if set(signature_def_map.keys()) != set(representative_dataset.keys()): + raise ValueError( + 'The signature keys and the keys of representative dataset map ' + f'do not match. Signature keys: {set(signature_def_map.keys())}, ' + f'representative dataset map: {set(representative_dataset.keys())}.' + ) + representative_dataset_map = representative_dataset + elif len(signature_def_map.keys()) > 1: + raise ValueError( + 'Representative dataset is not a mapping (got: ' + f'{type(representative_dataset)}), but there is more than one ' + 'signature key provided. Please provide a map of ' + '{signature_key -> dataset} with more than one signature key.' + ) + else: + representative_dataset_map = { + list(signature_def_map.keys())[0]: representative_dataset, + } + + # Save the representative dataset to temporary TFRecord files. + path_map = {} + expected_input_key_map = {} + for signature_key, signature_def in signature_def_map.items(): + # Filepath is the second return value of mkstemp. 
+ _, path_map[signature_key] = tempfile.mkstemp( + suffix='.tfrecord', prefix=signature_key + ) + expected_input_key_map[signature_key] = signature_def.inputs.keys() + + return repr_dataset.TfRecordRepresentativeDatasetSaver( + path_map=path_map, + expected_input_key_map=expected_input_key_map, + ).save(representative_dataset_map) + + +def _run_static_range_qat( + src_saved_model_path: str, + dst_saved_model_path: str, + quant_opts: _QuantizationOptions, + signature_def_map: _SignatureDefMap, +) -> None: + """Runs static-range quantization for a Quantization-Aware Trained model. + + Runs the quantization for a model trained using QAT. + + Args: + src_saved_model_path: Path to the source SavedModel directory. + dst_saved_model_path: Path to the destination SavedModel directory. + quant_opts: Quantization options. + signature_def_map: Signature def key -> SignatureDef mapping. + """ + logging.info('Running static-range quantization for QAT model.') + + pywrap_quantize_model.quantize_qat_model( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=quant_opts.SerializeToString(), + signature_keys=list(quant_opts.signature_keys), + signature_def_map_serialized=_serialize_signature_def_map( + signature_def_map + ), + py_function_library=py_function_lib.PyFunctionLibrary(), + ) + + +def _run_static_range_ptq( + src_saved_model_path: str, + dst_saved_model_path: str, + quant_opts: _QuantizationOptions, + representative_dataset: Mapping[str, _RepresentativeDatasetFile], + signature_def_map: _SignatureDefMap, +) -> None: + """Runs static-range Post-Training Quantization. + + Runs static-range PTQ for the model. Runs the calibration step with + `representative_dataset` to collect statistics required for quantization. This + produces the quantized GraphDef along with the SignatureDefs which might have + been modified according to the changes in the graph. + + Args: + src_saved_model_path: Path to the source SavedModel directory. 
+ dst_saved_model_path: Path to the destination SavedModel directory. + quant_opts: Quantization options. + representative_dataset: A map from signature key to the saved representative + dataset file. + signature_def_map: Signature def key -> SignatureDef mapping. + + Raises: + ValueError if the graph doesn't contain a valid signature. + """ + logging.info('Running static-range post-training quantization.') + + signature_def_map_serialized = _serialize_signature_def_map(signature_def_map) + + # `quantize_ptq_static_range` requires `RepresentativeDatasetFile`s to be + # serialized. Serialize the values to match the type. + dataset_file_map_serialized = { + signature_key: dataset_file.SerializeToString() + for signature_key, dataset_file in representative_dataset.items() + } + pywrap_quantize_model.quantize_ptq_static_range( + src_saved_model_path, + dst_saved_model_path, + quantization_options_serialized=quant_opts.SerializeToString(), + signature_keys=list(quant_opts.signature_keys), + signature_def_map_serialized=signature_def_map_serialized, + py_function_library=py_function_lib.PyFunctionLibrary(), + representative_dataset_file_map_serialized=dataset_file_map_serialized, + ) + + +def _static_range_quantize( + src_saved_model_path: str, + dst_saved_model_path: str, + quantization_options: _QuantizationOptions, + representative_dataset: Optional[ + repr_dataset.RepresentativeDatasetOrMapping + ] = None, +) -> autotrackable.AutoTrackable: + """Quantizes the given SavedModel via static range quantization. + + If the model is not trained with Quantization-Aware Training (QAT) technique, + it requires `representative_dataset` to collect statistics required for + quantization. If non-None `representative_dataset` is provided with a QAT + model input, `representative_dataset` will be ignored. + + Args: + src_saved_model_path: Path to the saved model. When representative_dataset + is not provided, this should be a model trained with QAT. 
def _static_range_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: _QuantizationOptions,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping
    ] = None,
) -> autotrackable.AutoTrackable:
  """Quantizes the given SavedModel via static range quantization.

  If the model is not trained with Quantization-Aware Training (QAT) technique,
  it requires `representative_dataset` to collect statistics required for
  quantization. If non-None `representative_dataset` is provided with a QAT
  model input, `representative_dataset` will be ignored.

  Args:
    src_saved_model_path: Path to the saved model. When representative_dataset
      is not provided, this should be a model trained with QAT.
    dst_saved_model_path: The path to save the output SavedModel. The directory
      will be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing quantization
      related config.
    representative_dataset: a generator that returns a dictionary in
      {input_key: input_value} format or a tuple with signature key and a
      dictionary in {input_key: input_value} format that feeds calibration data
      for quantizing model. This should be provided when the model is not a QAT
      model.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when representative_dataset is not provided for non-QAT model.
    RuntimeError: When a MetaGraphDef could not be found associated with `tags`
      in the SavedModel.
  """
  logging.info(
      'Running static range quantization on model: %s', src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  # METHOD_NO_QUANTIZE takes the same (calibration-free) path as QAT.
  is_qat_saved_model_or_method_no_quantize = _is_qat_saved_model(
      src_saved_model_path
  ) or (
      quantization_options.quantization_method.preset_method
      == _QuantizationMethod.METHOD_NO_QUANTIZE
  )
  signature_def_map = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      quantization_options.signature_keys,
      set(quantization_options.tags),
  )

  # The dataset may come from exactly one of the two sources; both at once is
  # ambiguous and rejected.
  if (
      representative_dataset is not None
      and quantization_options.representative_datasets
  ):
    raise ValueError(
        'Do not specify both the `representative_dataset` argument and'
        ' the `representative_datasets` field in `QuantizationOptions`.'
    )

  saved_representative_dataset = quantization_options.representative_datasets
  if representative_dataset is not None:
    saved_representative_dataset = _save_representative_dataset(
        representative_dataset, signature_def_map
    )

  # Checks if the model is from QAT or method is METHOD_NO_QUANTIZE.
  if (
      not saved_representative_dataset
      and not is_qat_saved_model_or_method_no_quantize
  ):
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).'
    )
  if quantization_options.min_num_elements_for_weights > 0:
    # Fix: `logging.warn` is a deprecated alias of `logging.warning`.
    logging.warning(
        'min_num_elements_for_weights is set but is not supported for the '
        'Post-training static range quantization. '
        'The flag is ignored.'
    )

  if is_qat_saved_model_or_method_no_quantize:
    _run_static_range_qat(
        src_saved_model_path,
        dst_saved_model_path,
        quantization_options,
        signature_def_map,
    )
  else:
    _run_static_range_ptq(
        src_saved_model_path,
        dst_saved_model_path,
        quantization_options,
        saved_representative_dataset,
        signature_def_map,
    )

  return saved_model_load.load(dst_saved_model_path)
def _dynamic_range_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: _QuantizationOptions,
) -> autotrackable.AutoTrackable:
  """Quantizes the given SavedModel via post-training dynamic range quantization.

  Args:
    src_saved_model_path: Path to the saved model.
    dst_saved_model_path: The path to save the output SavedModel. The directory
      will be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing quantization
      related config.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when the model is QAT model.
  """
  mode_str = 'dynamic-range quantization'
  if _is_qat_saved_model(src_saved_model_path):
    raise ValueError(
        'The models trained with quantization-aware training (QAT) is not '
        'supported for %s.'
        % mode_str
    )

  logging.info(
      'Running post-training %s on model: %s', mode_str, src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  signature_def_map = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      quantization_options.signature_keys,
      # Consistency fix: pass a set of tags, matching the static-range and
      # weight-only paths which both call with `set(quantization_options.tags)`.
      set(quantization_options.tags),
  )

  # Apply post-training dynamic range quantization to the model.
  pywrap_quantize_model.quantize_ptq_dynamic_range(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=quantization_options.SerializeToString(),
      signature_keys=list(quantization_options.signature_keys),
      signature_def_map_serialized=_serialize_signature_def_map(
          signature_def_map
      ),
      py_function_library=py_function_lib.PyFunctionLibrary(),
  )

  return saved_model_load.load(dst_saved_model_path)
def _weight_only_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
) -> autotrackable.AutoTrackable:
  """Quantizes the given SavedModel via weight-only quantization.

  Args:
    src_saved_model_path: Path to the saved model.
    dst_saved_model_path: The path to save the output SavedModel. The directory
      will be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing quantization
      related config.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when the model is QAT model.
  """
  mode_str = 'weight-only quantization'

  # QAT weight-only is not supported yet.
  if _is_qat_saved_model(src_saved_model_path):
    raise ValueError(
        'The models trained with quantization-aware training (QAT) is not '
        'supported for %s.'
        % mode_str
    )

  logging.info(
      'Running post-training %s on model: %s', mode_str, src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  signature_def_map = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      list(quantization_options.signature_keys),
      set(quantization_options.tags),
  )

  # Serialize protos before crossing into the pywrap (C++) layer.
  serialized_options = quantization_options.SerializeToString()
  serialized_signature_map = _serialize_signature_def_map(signature_def_map)

  pywrap_quantize_model.quantize_weight_only(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=serialized_options,
      signature_def_map_serialized=serialized_signature_map,
      py_function_library=py_function_lib.PyFunctionLibrary(),
  )

  return saved_model_load.load(dst_saved_model_path)


def _verify_output_dir(output_dir: Optional[str], overwrite: bool) -> None:
  """Verifies the output directory.

  Raises an error if `output_dir` is not suitable for writing the output saved
  model.

  Args:
    output_dir: Output directory.
    overwrite: An option allowing to overwrite the existing output directory if
      set to true. Does not actually create or modify the `output_dir` in this
      function.

  Raises:
    FileExistsError: Iff `output_dir` is not empty and `overwrite` is false.
  """
  # Nothing to check when there is no directory or overwriting is allowed.
  if output_dir is None or overwrite:
    return

  if file_io.file_exists_v2(output_dir) and file_io.list_directory_v2(
      output_dir
  ):
    raise FileExistsError(
        f'Output directory already exists: {output_dir} . '
        'Please set overwrite_output_directory to true to '
        'overwrite the existing directory.'
    )
def _populate_quantization_component_spec(
    quant_method: _QuantizationMethod,
) -> None:
  """Populates default values for QuantizationComponentSpec.

  Args:
    quant_method: The quantization method to be updated.

  Raises:
    ValueError: If a custom component spec uses an unsupported tensor type, or
      if the resulting number of component specs does not match the preset
      method's requirement.
  """
  # Make sure creating one spec per component.
  updated_component_spec = dict()

  # Populate default configuration.
  if (
      quant_method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      or quant_method.preset_method == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8
  ):
    updated_component_spec[_QuantizationComponent.COMPONENT_ACTIVATION] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_ACTIVATION,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )
    updated_component_spec[_QuantizationComponent.COMPONENT_WEIGHT] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_WEIGHT,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )
    updated_component_spec[_QuantizationComponent.COMPONENT_BIAS] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_BIAS,
            tensor_type=_TensorType.TENSORTYPE_INT_32,
        )
    )
  elif (
      quant_method.preset_method
      == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  ):
    updated_component_spec[_QuantizationComponent.COMPONENT_WEIGHT] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_WEIGHT,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )

  # Override if quantization_component_spec is specified.
  if quant_method.quantization_component_specs:
    # Check if the component spec is supported configuration in TF-Quant.
    for component_spec in quant_method.quantization_component_specs:
      if component_spec.quantization_component in [
          _QuantizationComponent.COMPONENT_WEIGHT,
          _QuantizationComponent.COMPONENT_ACTIVATION,
      ]:
        if component_spec.tensor_type != _TensorType.TENSORTYPE_INT_8:
          raise ValueError(
              'Only int8 precision is supported for input operands.'
          )
      else:
        if component_spec.tensor_type != _TensorType.TENSORTYPE_INT_32:
          raise ValueError('Only int32 precision is supported for bias.')
      # Update with the custom spec.
      updated_component_spec[component_spec.quantization_component] = (
          component_spec
      )

  # Update the component spec.
  del quant_method.quantization_component_specs[:]
  quant_method.quantization_component_specs.extend(
      updated_component_spec.values()
  )

  if (
      quant_method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      or quant_method.preset_method == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8
  ) and (len(quant_method.quantization_component_specs) != 3):
    # Fix: the original passed a (str, proto) tuple to ValueError, producing a
    # malformed message; format a single string instead.
    raise ValueError(
        f'Exactly 3 component specs are required for {quant_method}'
    )
  elif (
      quant_method.preset_method
      == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  ) and len(quant_method.quantization_component_specs) != 1:
    # Fix: the condition requires exactly one spec, so say so.
    raise ValueError('Exactly one component spec needs to be specified.')
def _populate_unitwise_quantization_specs(
    quantization_options: _QuantizationOptions,
) -> None:
  """Verifies and populates unitwise quantization specs.

  Each unit-wise spec must name at least one unit (by `op_type` or
  `node_name`), and its quantization method, once populated with defaults,
  must either be empty (NO_QUANTIZE) or match the top-level
  `quantization_method` component specs.

  Args:
    quantization_options: QuantizationOptions whose
      `unit_wise_quantization_specs` are validated and populated in place.

  Raises:
    ValueError: If a spec has no unit, a unit names neither `op_type` nor
      `node_name`, or a spec's component specs differ from the top-level ones.
  """
  if not quantization_options.unit_wise_quantization_specs:
    return

  # Sort by component so the later comparison is order-insensitive.
  sorted_top_level_component_specs = sorted(
      quantization_options.quantization_method.quantization_component_specs,
      key=lambda x: x.quantization_component,
  )

  for unitwise_spec in quantization_options.unit_wise_quantization_specs:
    if not unitwise_spec.unit:
      raise ValueError(
          'UnitWiseQuantizationSpec must contain at least one unit.'
      )

    for unit in unitwise_spec.unit:
      if not unit.op_type and not unit.node_name:
        raise ValueError('Either `op_type` or `node_name` must be specified.')

    # Fill in default component specs for this unit's method before comparing.
    _populate_quantization_component_spec(unitwise_spec.quantization_method)

    component_specs = (
        unitwise_spec.quantization_method.quantization_component_specs
    )
    # Empty component_specs means NO_QUANTIZE, which is always allowed.
    if component_specs and (
        sorted_top_level_component_specs
        != sorted(component_specs, key=lambda x: x.quantization_component)
    ):
      raise ValueError(
          'Currently unit-wise quantization spec only supports NO_QUANTIZE and'
          ' same quantization method as the top-level `quantization_method`'
      )
def _populate_calibration_options(
    quantization_options: quant_opts_pb2.QuantizationOptions,
):
  """Populates default values for CalibrationOptions.

  Chooses MIN_MAX when no method is specified; fills in histogram parameters
  (bins, percentiles) for the percentile method; validates the activation
  tensor type and fills in bins for the HISTOGRAM_MSE family.

  Args:
    quantization_options: An instance of QuantizationOptions with a field
      specifying CalibrationOptions
  """

  calib_opts = quantization_options.calibration_options
  if (
      calib_opts.calibration_method
      == _CalibrationMethod.CALIBRATION_METHOD_UNSPECIFIED
  ):
    calib_opts.calibration_method = (
        _CalibrationMethod.CALIBRATION_METHOD_MIN_MAX
    )
  elif (
      calib_opts.calibration_method
      == _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE
  ):
    # Only fill parameters the caller left unset (zero/empty is "unset").
    if not calib_opts.calibration_parameters.num_bins:
      calib_opts.calibration_parameters.num_bins = 512
    if not calib_opts.calibration_parameters.min_percentile:
      calib_opts.calibration_parameters.min_percentile = 0.001
    if not calib_opts.calibration_parameters.max_percentile:
      calib_opts.calibration_parameters.max_percentile = 99.999
  # Check the activation_tensor_type of HISTOGRAM_MSE methods.
  elif calib_opts.calibration_method in [
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE,
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY,
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC,
  ]:
    # NOTE(review): this indexes the repeated `quantization_component_specs`
    # field with the COMPONENT_ACTIVATION enum *value* as a positional index.
    # It presumably relies on the insertion order established by
    # `_populate_quantization_component_spec` — verify that the activation
    # spec is actually at that position.
    activation_tensor_type = (
        quantization_options.quantization_method.quantization_component_specs[
            _QuantizationComponent.COMPONENT_ACTIVATION
        ].tensor_type
    )
    # Unlike the HISTOGRAM_PERCENTILE method, the HISTOGRAM_MSE method uses
    # num_bits because it actually quantizes and dequantizes values.
    if activation_tensor_type != _TensorType.TENSORTYPE_INT_8:
      raise ValueError(
          'Only TENSORTYPE_INT_8 is supported for HISTOGRAM_MSE calibration'
          f' methods. calibration_method={calib_opts.calibration_method}'
      )

    if not calib_opts.calibration_parameters.num_bins:
      calib_opts.calibration_parameters.num_bins = 512

  if calib_opts.calibration_data_dir:
    save_model.create_empty_output_dir(
        calib_opts.calibration_data_dir,
        overwrite=calib_opts.force_regenerate_calibration_data,
    )
+ """ + if quantization_options.op_set == quant_opts_pb2.OpSet.OP_SET_UNSPECIFIED: + quantization_options.op_set = quant_opts_pb2.OpSet.XLA + + if not quantization_options.tags: + quantization_options.tags.append(tag_constants.SERVING) + + if not quantization_options.signature_keys: + quantization_options.signature_keys.append( + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + ) + + if not quantization_options.HasField('freeze_all_variables'): + quantization_options.freeze_all_variables = True + + if quantization_options.enable_legacy_weight_only: + raise ValueError( + 'Legacy weight-only is deprecated. Use weight-only quantization method.' + ) + + # Converter assumes options are specified. So set SRQ explicitly. + if ( + quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_UNSPECIFIED + ): + logging.debug( + '"preset_method" for QuantizationMethod is not specified.' + 'Static range quantization is used by default.' + ) + quantization_options.quantization_method.preset_method = ( + _PresetMethod.METHOD_STATIC_RANGE_INT8 + ) + + # Check default quantization option values for weight-only quantization. + # TODO(b/242805842): Find good minimum_elements_for_weights number for server. + # please also update default value in tflite converter: + # tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc;l=201 + if quantization_options.min_num_elements_for_weights == 0: + quantization_options.min_num_elements_for_weights = ( + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS + ) + logging.warning( + ( + 'QuantizationOptions.min_num_elements_for_weights is not set (0).' + ' Setting to the default value: %d.' 
+ ), + _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS, + ) + + if not quantization_options.HasField('enable_per_channel_quantization'): + quantization_options.enable_per_channel_quantization = False + + if quantization_options.enable_per_channel_quantization and not ( + ( + quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + or quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 + ) + or ( + quantization_options.op_set + in (quant_opts_pb2.OpSet.XLA, quant_opts_pb2.OpSet.STABLEHLO) + and quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_STATIC_RANGE_INT8 + ) + ): + raise ValueError( + 'Currently, per-channel quantization is supported for Uniform Quantized' + ' opset, weight only quantization, or XLA/StableHLO opset with static' + ' range quantization.' + ) + + if ( + quantization_options.quantization_method.preset_method + == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 + and ( + quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED + or quantization_options.op_set == quant_opts_pb2.OpSet.TF + ) + ): + raise ValueError('TF/Uniform quantized opset does not support weight-only.') + + if (quantization_options.op_set == quant_opts_pb2.OpSet.STABLEHLO) and ( + quantization_options.quantization_method.preset_method + != _PresetMethod.METHOD_STATIC_RANGE_INT8 + and quantization_options.quantization_method.preset_method + != _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8 + ): + raise ValueError( + 'StableHLO quantized opset currently only supports static range' + ' quantization and weight-only quantizationvia TF Quantizer.' + ) + + # Set `force_graph_mode_calibration` to True to avoid skipping op execution, + # which are not connected to return ops, during calibration execution. + # TODO: b/335031954 - Bring back support to run calibration in Eager mode. 
@tf_export.tf_export('quantization.experimental.quantize_saved_model')
def quantize(
    saved_model_path: str,
    output_directory: Optional[str] = None,
    quantization_options: Optional[_QuantizationOptions] = None,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping
    ] = None,
    *,
    overwrite_output_directory: bool = False,
) -> autotrackable.AutoTrackable:
  """Quantizes the SavedModel with the given quantization options.

  Example usage:
  ```python
  # Quantizing a model trained with QAT.
  quantization_options = tf.quantization.experimental.QuantizationOptions(
      signature_keys=['your_signature_key'],
  )
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
  )

  # When quantizing a model trained without QAT (Post-Training Quantization),
  # a representative dataset is required.
  representative_dataset = [{"input": tf.random.uniform(shape=(3, 3))}
                        for _ in range(256)]
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
      representative_dataset={'your_signature_key': representative_dataset},
    )

  # In addition to preset quantization methods, fine-grained control of
  # quantization for each component is also supported.
  _QuantizationComponentSpec = (
      tf.quantization.experimental.QuantizationComponentSpec
  )
  quantization_options = tf.quantization.experimental.QuantizationOptions(
      signature_keys=['your_signature_key'],
      quantization_method=tf.quantization.experimental.QuantizationMethod(
          quantization_component_specs=[
              _QuantizationComponentSpec(
                  quantization_component=(
                      _QuantizationComponentSpec.COMPONENT_ACTIVATION
                  ),
                  tensor_type=_QuantizationComponentSpec.TENSORTYPE_INT_8,
              )
          ]
      )
  )
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
  )
  ```

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    output_directory: The path to save the output SavedModel. Set
      `overwrite_output_directory` to `True` to overwrite any existing contents
      in the directory if not empty.
    quantization_options: A set of options for quantization. If None, it uses
      post-training static range quantization with XLA opset by default.
    representative_dataset: an iterator that returns a dictionary of
      {input_key: input_value} or a map from signature key to a dictionary of
      {input_key: input_value} that feeds calibration data for quantizing
      model. The representative should be provided when the model is a PTQ
      model. It can be provided either via this parameter or via the
      `representative_datasets` field in `QuantizationOptions`.
    overwrite_output_directory: If set to true, overwrites the output directory
      iff it isn't empty. The default value is false.

  Returns:
    A SavedModel object with TF quantization applied, or None if no
    quantization is performed.

  Raises:
    ValueError: When 1) representative_dataset is not provided for non QAT
      model for enabling static range quantization, 2) invalid value is
      provided as a quantization method, or 3) provide representative dataset
      via both argument and QuantizationOptions.
    ValueError: When the specified quantization method is not yet supported.
  """
  _verify_output_dir(output_directory, overwrite_output_directory)

  # Set default values for None arguments.
  if output_directory is None:
    output_directory = tempfile.mkdtemp()

  if quantization_options is None:
    quantization_options = _QuantizationOptions()

  _populate_quantization_options_default_values(quantization_options)

  method: _QuantizationMethod = quantization_options.quantization_method
  if (
      method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      or method.preset_method == _PresetMethod.METHOD_NO_QUANTIZE
  ):
    return _static_range_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
        representative_dataset,
    )
  elif method.preset_method == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8:
    return _dynamic_range_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
    )
  elif (
      method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  ):
    return _weight_only_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
    )
  else:
    # Fix: the original string lacked the `f` prefix, so the placeholder
    # `{method.preset_method}` was emitted literally instead of interpolated.
    raise ValueError(
        f'Quantization method {method.preset_method} is not supported.'
    )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines types required for representative datasets for quantization.""" + +from collections.abc import Collection, Sized +import os +from typing import Iterable, Mapping, Optional, Union + +import numpy as np + +from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.data.ops import readers +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_util +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.types import core +from tensorflow.python.util import tf_export + +# A representative sample is a map of: input_key -> input_value. +# Ex.: {'dense_input': tf.constant([1, 2, 3])} +# Ex.: {'x1': np.ndarray([4, 5, 6]} +RepresentativeSample = Mapping[str, core.TensorLike] + +# A representative dataset is an iterable of representative samples. +RepresentativeDataset = Iterable[RepresentativeSample] + +# A type representing a map from: signature key -> representative dataset. +# Ex.: {'serving_default': [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])], +# 'other_signature_key': [tf.constant([[2, 2], [9, 9]])]} +RepresentativeDatasetMapping = Mapping[str, RepresentativeDataset] + +# A type alias expressing that it can be either a RepresentativeDataset or +# a mapping of signature key to RepresentativeDataset. 
class RepresentativeDatasetSaver:
  """Base interface for representative dataset savers.

  Subclasses implement `save`, which persists the provided representative
  dataset (a signature-def-key -> dataset mapping) to files. Keeping a file
  snapshot of the calibration data is useful when the dataset must be passed
  around as files rather than in-memory iterables.
  """

  def save(
      self, representative_dataset: RepresentativeDatasetMapping
  ) -> Mapping[str, _RepresentativeDatasetFile]:
    """Saves the representative dataset.

    Args:
      representative_dataset: RepresentativeDatasetMapping which is a
        signature_def_key -> representative dataset mapping.
    """
    # Abstract: concrete savers (e.g. the TFRecord saver) override this.
    raise NotImplementedError('Method "save" is not implemented.')
+ quantization_options = tf.quantization.experimental.QuantizationOptions( + signature_keys=['serving_default'], + representative_datasets=dataset_file_map, + ) + tf.quantization.experimental.quantize_saved_model( + '/tmp/input_model', + '/tmp/output_model', + quantization_options=quantization_options, + ) + ``` + """ + + def __init__( + self, + path_map: Mapping[str, os.PathLike[str]], + expected_input_key_map: Optional[Mapping[str, Collection[str]]] = None, + ): + """Initializes TFRecord represenatative dataset saver. + + Args: + path_map: Signature def key -> path mapping. Each path is a TFRecord file + to which a `RepresentativeDataset` is saved. The signature def keys + should be a subset of the `SignatureDef` keys of the + `representative_dataset` argument of the `save()` call. + expected_input_key_map: Signature def key -> expected input keys. If set, + validate that the sample has same set of input keys before saving. + + Raises: + KeyError: If path_map and expected_input_key_map have different keys. + """ + self.path_map: Mapping[str, os.PathLike[str]] = path_map + self.expected_input_key_map: Mapping[str, Collection[str]] = {} + if expected_input_key_map is not None: + if set(path_map.keys()) != set(expected_input_key_map.keys()): + raise KeyError( + 'The `path_map` and `expected_input_key_map` should have the same' + ' set of keys.' + ) + + self.expected_input_key_map = expected_input_key_map + + def _save_tf_record_dataset( + self, + repr_ds: RepresentativeDataset, + signature_def_key: str, + ) -> _RepresentativeDatasetFile: + """Saves `repr_ds` to a TFRecord file. + + Each sample in `repr_ds` is serialized as `RepresentativeDataSample`. + + Args: + repr_ds: `RepresentativeDataset` to save. + signature_def_key: The signature def key associated with `repr_ds`. + + Returns: + a RepresentativeDatasetFile instance contains the path to the saved file. 
  def _save_tf_record_dataset(
      self,
      repr_ds: RepresentativeDataset,
      signature_def_key: str,
  ) -> _RepresentativeDatasetFile:
    """Saves `repr_ds` to a TFRecord file.

    Each sample in `repr_ds` is serialized as `RepresentativeDataSample`.

    Args:
      repr_ds: `RepresentativeDataset` to save.
      signature_def_key: The signature def key associated with `repr_ds`.

    Returns:
      a RepresentativeDatasetFile instance contains the path to the saved file.

    Raises:
      KeyError: If the set of input keys in the dataset samples doesn't match
        the set of expected input keys.
    """
    # When running in graph mode (TF1), tf.Tensor types should be converted to
    # numpy ndarray types to be compatible with `make_tensor_proto`.
    if not context.executing_eagerly():
      with session.Session() as sess:
        repr_ds = replace_tensors_by_numpy_ndarrays(repr_ds, sess)

    # None means "no validation" for this signature key.
    expected_input_keys = self.expected_input_key_map.get(
        signature_def_key, None
    )
    tfrecord_file_path = self.path_map[signature_def_key]
    with python_io.TFRecordWriter(tfrecord_file_path) as writer:
      for repr_sample in repr_ds:
        # NOTE(review): `set(...) != expected_input_keys` compares a set to an
        # arbitrary Collection; for a plain list this is always unequal, so
        # every sample would be rejected. Presumably callers pass a set or a
        # set-like keys view — verify against callers.
        if (
            expected_input_keys is not None
            and set(repr_sample.keys()) != expected_input_keys
        ):
          raise KeyError(
              'Invalid input keys for representative sample. The function'
              f' expects input keys of: {set(expected_input_keys)}. Got:'
              f' {set(repr_sample.keys())}. Please provide correct input keys'
              ' for representative samples.'
          )

        # Serialize each input tensor into the sample proto, one record per
        # representative sample.
        sample = _RepresentativeDataSample()
        for input_name, input_value in repr_sample.items():
          sample.tensor_proto_inputs[input_name].CopyFrom(
              tensor_util.make_tensor_proto(input_value)
          )

        writer.write(sample.SerializeToString())

    logging.info(
        'Saved representative dataset for signature def: %s to: %s',
        signature_def_key,
        tfrecord_file_path,
    )
    return _RepresentativeDatasetFile(
        tfrecord_file_path=str(tfrecord_file_path)
    )
class RepresentativeDatasetLoader:
  """Base interface for representative dataset loaders.

  Subclasses implement `load`, which reads previously saved representative
  datasets back from files.
  """

  def load(self) -> RepresentativeDatasetMapping:
    """Loads the representative datasets.

    Returns:
      representative dataset mapping: A loaded signature def key ->
      representative mapping.
    """
    # Abstract: concrete loaders (e.g. the TFRecord loader) override this.
    raise NotImplementedError('Method "load" is not implemented.')
+ """ + self.dataset_file_map = dataset_file_map + + def _load_tf_record(self, tf_record_path: str) -> RepresentativeDataset: + """Loads TFRecord containing samples of type`RepresentativeDataSample`.""" + samples = [] + with context.eager_mode(): + for sample_bytes in readers.TFRecordDatasetV2(filenames=[tf_record_path]): + sample_proto = _RepresentativeDataSample.FromString( + sample_bytes.numpy() + ) + sample = {} + for input_key, tensor_proto in sample_proto.tensor_proto_inputs.items(): + sample[input_key] = tensor_util.MakeNdarray(tensor_proto) + samples.append(sample) + return samples + + def load(self) -> RepresentativeDatasetMapping: + """Loads the representative datasets. + + Returns: + representative dataset mapping: A signature def key -> representative + mapping. The loader loads `RepresentativeDataset` for each path in + `self.dataset_file_map` and associates the loaded dataset to the + corresponding signature def key. + """ + repr_dataset_map = {} + for signature_def_key, dataset_file in self.dataset_file_map.items(): + if dataset_file.HasField('tfrecord_file_path'): + repr_dataset_map[signature_def_key] = self._load_tf_record( + dataset_file.tfrecord_file_path + ) + else: + raise ValueError('Unsupported Representative Dataset filetype') + + return repr_dataset_map + + +def replace_tensors_by_numpy_ndarrays( + repr_ds: RepresentativeDataset, sess: session.Session +) -> RepresentativeDataset: + """Replaces tf.Tensors in samples by their evaluated numpy arrays. + + Note: This should be run in graph mode (default in TF1) only. + + Args: + repr_ds: Representative dataset to replace the tf.Tensors with their + evaluated values. `repr_ds` is iterated through, so it may not be reusable + (e.g. if it is a generator object). + sess: Session instance used to evaluate tf.Tensors. + + Returns: + The new representative dataset where each tf.Tensor is replaced by its + evaluated numpy ndarrays. 
+ """ + new_repr_ds = [] + for sample in repr_ds: + new_sample = {} + for input_key, input_data in sample.items(): + # Evaluate the Tensor to get the actual value. + if isinstance(input_data, core.Tensor): + input_data = input_data.eval(session=sess) + + new_sample[input_key] = input_data + + new_repr_ds.append(new_sample) + return new_repr_ds + + +def get_num_samples(repr_ds: RepresentativeDataset) -> Optional[int]: + """Returns the number of samples if known. + + Args: + repr_ds: Representative dataset. + + Returns: + Returns the total number of samples in `repr_ds` if it can be determined + without iterating the entier dataset. Returns None iff otherwise. When it + returns None it does not mean the representative dataset is infinite or it + is malformed; it simply means the size cannot be determined without + iterating the whole dataset. + """ + if isinstance(repr_ds, Sized): + try: + return len(repr_ds) + except Exception as ex: # pylint: disable=broad-except + # There are some cases where calling __len__() raises an exception. + # Handle this as if the size is unknown. + logging.info('Cannot determine the size of the dataset (%s).', ex) + return None + else: + return None + + +def create_feed_dict_from_input_data( + input_data: RepresentativeSample, + signature_def: meta_graph_pb2.SignatureDef, +) -> Mapping[str, np.ndarray]: + """Constructs a feed_dict from input data. + + Note: This function should only be used in graph mode. + + This is a helper function that converts an 'input key -> input value' mapping + to a feed dict. A feed dict is an 'input tensor name -> input value' mapping + and can be directly passed to the `feed_dict` argument of `sess.run()`. + + Args: + input_data: Input key -> input value mapping. The input keys should match + the input keys of `signature_def`. + signature_def: A SignatureDef representing the function that `input_data` is + an input to. + + Returns: + Feed dict, which is intended to be used as input for `sess.run`. 
It is + essentially a mapping: input tensor name -> input value. Note that the input + value in the feed dict is not a `Tensor`. + """ + feed_dict = {} + for input_key, input_value in input_data.items(): + input_tensor_name = signature_def.inputs[input_key].name + + value = input_value + if isinstance(input_value, core.Tensor): + # Take the data out of the tensor. + value = input_value.eval() + + feed_dict[input_tensor_name] = value + + return feed_dict diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py new file mode 100644 index 0000000000000000000000000000000000000000..87ad7a11f2e6770461f5220cda31b37930744c1f --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/python/save_model.py @@ -0,0 +1,346 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Defines utilities involving SavedModel.""" +from typing import Collection, Dict, Mapping, Optional, Sequence + +from absl import logging + +# pylint: disable=g-importing-member +from google.protobuf.any_pb2 import Any +# pylint: enable=g-importing-member +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import builder +from tensorflow.python.saved_model import constants as saved_model_constants +from tensorflow.python.saved_model import loader_impl as saved_model_loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.training import saver + +# Mapping of signature def key -> SignatureDef. +_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef] + + +def get_signatures_from_saved_model( + saved_model_path: str, + signature_keys: Optional[Sequence[str]] = None, + tags: Optional[Collection[str]] = None, +) -> Dict[str, meta_graph_pb2.SignatureDef]: + """Gets a map from signature keys to their SignatureDef. + + Args: + saved_model_path: Path to the saved model. + signature_keys: List of keys identifying SignatureDef to retrieve. If None, + retrieve all except the init signature. + tags: Set of tags identifying the MetaGraphDef within the SavedModel. + + Returns: + A map from signature_key to its SignatureDef. 
+ """ + if tags is None: + tags = {tag_constants.SERVING} + + loader = saved_model_loader.SavedModelLoader(saved_model_path) + meta_graphdef = loader.get_meta_graph_def_from_tags(tags) + signatures = {} + for key, signature_def in meta_graphdef.signature_def.items(): + if key == saved_model_constants.INIT_OP_SIGNATURE_KEY: + continue + if signature_keys is not None and key not in signature_keys: + continue + signatures[key] = signature_def + + return signatures + + +def _restore_output_tensor_names( + graph_def: graph_pb2.GraphDef, +) -> graph_pb2.GraphDef: + """Restores the output tensor names of the converted model. + + During the conversion, the output tensor names of the original model are + embedded in the `tf_saved_model.index_path` attribute of the RetVal nodes and + might become the name of Retval nodes as well (with an index suffix if there + are multiple output tensors from one node). Since Retval nodes are not used in + SavedModel, this function removes them and restore the names to the actual + output tensors. + + Args: + graph_def: the converted GraphDef. + + Returns: + The GraphDef with Retval nodes removed and output tensor names restored. + """ + output_renaming_map = {} + with session.Session(graph=ops.Graph()): + importer.import_graph_def(graph_def, name='') + graph = ops.get_default_graph() + for op in graph.get_operations(): + if op.type == '_Retval': + expected_node_name = op.name + if op.get_attr('tf_saved_model.index_path') is not None: + index_path_name = op.get_attr('tf_saved_model.index_path')[0] + index_path_name = index_path_name.decode('utf-8').split(':')[0] + try: + # Only use the index_path name if it points to a Retval node. 
+ index_path_node = graph.get_operation_by_name(index_path_name) + if index_path_node.type == '_Retval': + expected_node_name = index_path_name + except KeyError: + pass + retval_input_node_name = op.inputs[0].op.name + output_renaming_map[retval_input_node_name] = expected_node_name + + for node in reversed(graph_def.node): + if node.name in output_renaming_map: + node.name = output_renaming_map[node.name] + elif node.op == '_Retval': + graph_def.node.remove(node) + else: + # Update the inputs referring to the pre-renaming node. + for idx, input_name in enumerate(node.input): + if input_name in output_renaming_map: + node.input[idx] = output_renaming_map[input_name] + # Update the control inputs referring to the pre-renaming node. + updating_inputs = [] + for input_name in reversed(node.input): + if input_name.startswith('^') and input_name[1:] in output_renaming_map: + updating_inputs.append(input_name[1:]) + node.input.remove(input_name) + for updating_input in updating_inputs: + node.input.append('^' + output_renaming_map[updating_input]) + return graph_def + + +def create_empty_output_dir( + output_directory: str, overwrite: bool = True +) -> None: + """Creates the `output_directory`. + + If `output_directory` already exists, it recursively deletes all contents + inside the directory. + + Also creates the parent & intermediate directories. + + Args: + output_directory: Output directory. + overwrite: Where to clean the output directory if exists. + """ + if overwrite and file_io.file_exists_v2(output_directory): + logging.info( + 'Deleting existing output directory: %s .', + output_directory, + ) + file_io.delete_recursively_v2(output_directory) + + file_io.recursive_create_dir_v2(output_directory) + + +def _validate_signatures( + signature_def_map: _SignatureDefMap, exported_graph: ops.Graph +) -> _SignatureDefMap: + """Validates if the tensor names in signatures are consistent with the graph. 
+ + This function checks if the input and output tensor names in the signatures + exist if the graph. The output tensor names might change during conversion, + we try to fix that with `_restore_output_tensor_names`. Besides, if there + are duplicated tensor names, they we will be prefixed with the signature name. + However, if that doesn't work the signatures can't be used with the converted + graph. + + Args: + signature_def_map: the signatures to validate. + exported_graph: The PTQ-exported GraphDef. + + Returns: + The signatures with tensor names prefixed with signature name if necessary. + + Raises: + ValueError: Iff the signatures are not consistent with the graph. + """ + for signature_key, signature_def in signature_def_map.items(): + for tensor_info in signature_def.inputs.values(): + try: + exported_graph.get_tensor_by_name(tensor_info.name) + except KeyError as exc: + try: + prefixed_name = signature_key + '_' + tensor_info.name + exported_graph.get_tensor_by_name(prefixed_name) + tensor_info.name = prefixed_name + except KeyError: + raise ValueError( + 'Cannot find the input tensor with name %s in the graph.' + % tensor_info.name + ) from exc + + for tensor_info in signature_def.outputs.values(): + try: + exported_graph.get_tensor_by_name(tensor_info.name) + except KeyError as exc: + try: + prefixed_name = signature_key + '_' + tensor_info.name + exported_graph.get_tensor_by_name(prefixed_name) + tensor_info.name = prefixed_name + except KeyError: + raise ValueError( + 'Cannot find the output tensor with name %s in the graph.' + % tensor_info.name + ) from exc + + return signature_def_map + + +def _find_op( + graph: ops.Graph, op_name: Optional[str] +) -> Optional[ops.Operation]: + """Finds the operation with `op_name`. + + Args: + graph: The graph to find from. + op_name: Name of the node. + + Returns: + The operation that corresponds to `op_name`. Returns None iff op_name is an + empty string or None. + + Raises: + ValueError: `op_name` is malformed. 
+ """ + if not op_name: + return None + + init_op = graph.get_operation_by_name(op_name) + logging.debug('Op found in the graph: %s', op_name) + + return init_op + + +def _save_function_alias( + saved_model_dir: str, + tags: Collection[str], + function_aliases: Mapping[str, str], +) -> None: + """Saves the function alias to the SavedModel. + + SavedModelBuilder (TF1 saved model saver) does not support saving function + aliases, so this function loads the SavedModel proto and adds the + `function_aliases` field. + + Args: + saved_model_dir: Path to the saved model directory. + tags: A collection of tags to specify the meta graph. + function_aliases: Function name -> function alias mapping. + """ + loader = saved_model_loader.SavedModelLoader(saved_model_dir) + meta_graph_def = loader.get_meta_graph_def_from_tags(tags) + + for function_name, function_alias in function_aliases.items(): + meta_graph_def.meta_info_def.function_aliases[function_name] = ( + function_alias + ) + + saved_model_proto_serialized = loader.saved_model.SerializeToString() + + # TODO(b/266015731): Also update and set the SavedModel fingerprint. + path = file_io.join( + saved_model_dir, saved_model_constants.SAVED_MODEL_FILENAME_PB + ) + file_io.atomic_write_string_to_file(path, saved_model_proto_serialized) + + +def save_model_v1( + graph_def: graph_pb2.GraphDef, + output_dir: str, + signature_def_map: _SignatureDefMap, + tags: Collection[str], + init_op_name: Optional[str] = None, + saver_def: Optional[saver_pb2.SaverDef] = None, + checkpoint_dir: Optional[str] = None, + function_aliases: Optional[Mapping[str, str]] = None, + asset_file_defs: Sequence[meta_graph_pb2.AssetFileDef] = (), +) -> None: + """Saves the model. + + Saves the provided graph def as SavedModel. + Uses TF1 SavedModel semantics (i.e. no object graph). + + Args: + graph_def: Graph to save. + output_dir: Output directory for the SavedModel. + signature_def_map: Mapping of signature def key -> SignatureDef. 
+ tags: Tags for the meta graph def. + init_op_name: Name of the node for initialization. + saver_def: `saver_pb2.SaverDef` to create a `saver.Saver` from. The created + saver will be used to save and load variables. This may be `None` if no + variables exist in the graph. + checkpoint_dir: Path to checkpoint file where variable values are saved. + function_aliases: Function name -> function alias mapping. + asset_file_defs: `AssetFileDef`s that associates the asset files and the + name of the tensors to which the asset file names should be fed. The + caller should make sure the asset files exist in the output saved model + directory. + + Raises: + ValueError iff the graph does not contain a valid signature or the file + prefix tensor is not found in the graph. + """ + create_empty_output_dir(output_dir) + v1_builder = builder.SavedModelBuilder(output_dir) + + graph_def = _restore_output_tensor_names(graph_def) + with session.Session(graph=ops.Graph()) as sess: + importer.import_graph_def(graph_def, name='') + + signature_def_map = _validate_signatures( + signature_def_map, ops.get_default_graph() + ) + + # Add `AssetFileDef`s to the collection so that correct values are fed to + # the tensors that accept asset file paths. + for asset_file_def in asset_file_defs: + asset_any_proto = Any() + asset_any_proto.Pack(asset_file_def) + ops.add_to_collection( + saved_model_constants.ASSETS_KEY, + asset_any_proto, + ) + + model_saver = None + # If `saver_def` is not None, it means there are variables in the graph. + if saver_def: + model_saver = saver.Saver(saver_def=saver_def) + logging.info('Saver created with SaverDef: %s', saver_def) + + # Variables should be restored once before exporting as saved model + # because the variables are not initialized when the GraphDef was + # imported. 
+ model_saver.restore(sess, checkpoint_dir) + + v1_builder.add_meta_graph_and_variables( + sess, + tags, + signature_def_map=signature_def_map, + main_op=_find_op(sess.graph, op_name=init_op_name), + saver=model_saver, + ) + + v1_builder.save() + + if function_aliases: + _save_function_alias(output_dir, tags, function_aliases) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..d20b4536379cef96ae337626f51a35eb762e59b4 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options_pb2.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__config__pb2 +from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2 + + +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\nKtensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto\x12\x17tensorflow.quantization\x1aItensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto\x1a&tensorflow/core/framework/tensor.proto\"\xed\x02\n\x12QuantizationMethod\x12O\n\rpreset_method\x18\x04 \x01(\x0e\x32\x38.tensorflow.quantization.QuantizationMethod.PresetMethod\x12X\n\x1cquantization_component_specs\x18\x03 \x03(\x0b\x32\x32.tensorflow.quantization.QuantizationComponentSpec\"\xa5\x01\n\x0cPresetMethod\x12\x16\n\x12METHOD_UNSPECIFIED\x10\x00\x12\x16\n\x12METHOD_NO_QUANTIZE\x10\x01\x12\x1c\n\x18METHOD_STATIC_RANGE_INT8\x10\x02\x12\x1d\n\x19METHOD_DYNAMIC_RANGE_INT8\x10\x03\x12(\n$METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8\x10\x04J\x04\x08\x01\x10\x03\"\xbe\x03\n\x19QuantizationComponentSpec\x12h\n\x16quantization_component\x18\x01 \x01(\x0e\x32H.tensorflow.quantization.QuantizationComponentSpec.QuantizationComponent\x12R\n\x0btensor_type\x18\x02 \x01(\x0e\x32=.tensorflow.quantization.QuantizationComponentSpec.TensorType\"v\n\x15QuantizationComponent\x12\x19\n\x15\x43OMPONENT_UNSPECIFIED\x10\x00\x12\x18\n\x14\x43OMPONENT_ACTIVATION\x10\x01\x12\x14\n\x10\x43OMPONENT_WEIGHT\x10\x02\x12\x12\n\x0e\x43OMPONENT_BIAS\x10\x03\"k\n\nTensorType\x12\x1a\n\x16TENSORTYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10TENSORTYPE_INT_4\x10\x01\x12\x14\n\x10TENSORTYPE_INT_8\x10\x02\x12\x15\n\x11TENSORTYPE_INT_32\x10\x03\"\x87\x02\n\x18UnitWiseQuantizationSpec\x12P\n\x04unit\x18\x05 \x03(\x0b\x32\x42.tensorflow.quantization.UnitWiseQuantizationSpec.QuantizationUnit\x12H\n\x13quantization_method\x18\x06 \x01(\x0b\x32+.tensorflow.quantization.QuantizationMethod\x1aI\n\x10QuantizationUnit\x12\x0f\n\x07op_type\x18\x01 \x01(\t\x12\x11\n\tnode_name\x18\x02 \x01(\t\x12\x11\n\tfunc_name\x18\x03 \x01(\tJ\x04\x08\x01\x10\x05\"\xd4\x01\n\x18RepresentativeDataSample\x12\x65\n\x13tensor_proto_inputs\x18\x02 
\x03(\x0b\x32H.tensorflow.quantization.RepresentativeDataSample.TensorProtoInputsEntry\x1aQ\n\x16TensorProtoInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"I\n\x19RepresentativeDatasetFile\x12\x1c\n\x12tfrecord_file_path\x18\x01 \x01(\tH\x00\x42\x0e\n\x0c\x64\x61taset_file\"\xca\x07\n\x13QuantizationOptions\x12H\n\x13quantization_method\x18\x01 \x01(\x0b\x32+.tensorflow.quantization.QuantizationMethod\x12.\n\x06op_set\x18\x02 \x01(\x0e\x32\x1e.tensorflow.quantization.OpSet\x12W\n\x1cunit_wise_quantization_specs\x18\x11 \x03(\x0b\x32\x31.tensorflow.quantization.UnitWiseQuantizationSpec\x12\x0c\n\x04tags\x18\x05 \x03(\t\x12\x16\n\x0esignature_keys\x18\x06 \x03(\t\x12i\n\x17representative_datasets\x18\x07 \x03(\x0b\x32H.tensorflow.quantization.QuantizationOptions.RepresentativeDatasetsEntry\x12$\n\x1cmin_num_elements_for_weights\x18\x08 \x01(\x03\x12!\n\x14\x66reeze_all_variables\x18\t \x01(\x08H\x00\x88\x01\x01\x12,\n\x1f\x65nable_per_channel_quantization\x18\n \x01(\x08H\x01\x88\x01\x01\x12 \n\x18\x65nable_two_input_tensors\x18\x0b \x01(\x08\x12-\n%experimental_enable_tpu_model_support\x18\x0c \x01(\x08\x12!\n\x19\x65nable_legacy_weight_only\x18\r \x01(\x08\x12$\n\x1c\x66orce_graph_mode_calibration\x18\x0e \x01(\x08\x12G\n\x13\x63\x61libration_options\x18\x0f \x01(\x0b\x32*.stablehlo.quantization.CalibrationOptions\x12?\n\x0f\x64\x65\x62ugger_config\x18\x10 \x01(\x0b\x32&.stablehlo.quantization.DebuggerConfig\x1aq\n\x1bRepresentativeDatasetsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32\x32.tensorflow.quantization.RepresentativeDatasetFile:\x02\x38\x01\x42\x17\n\x15_freeze_all_variablesB\"\n _enable_per_channel_quantizationJ\x04\x08\x03\x10\x04*V\n\x05OpSet\x12\x16\n\x12OP_SET_UNSPECIFIED\x10\x00\x12\x06\n\x02TF\x10\x01\x12\x07\n\x03XLA\x10\x02\x12\x15\n\x11UNIFORM_QUANTIZED\x10\x03\x12\r\n\tSTABLEHLO\x10\x04\x42\x03\xf8\x01\x01\x62\x06proto3') + 
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.quantization_options_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\370\001\001' + _REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY._options = None + _REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY._serialized_options = b'8\001' + _QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY._options = None + _QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY._serialized_options = b'8\001' + _OPSET._serialized_start=2565 + _OPSET._serialized_end=2651 + _QUANTIZATIONMETHOD._serialized_start=220 + _QUANTIZATIONMETHOD._serialized_end=585 + _QUANTIZATIONMETHOD_PRESETMETHOD._serialized_start=414 + _QUANTIZATIONMETHOD_PRESETMETHOD._serialized_end=579 + _QUANTIZATIONCOMPONENTSPEC._serialized_start=588 + _QUANTIZATIONCOMPONENTSPEC._serialized_end=1034 + _QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT._serialized_start=807 + _QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT._serialized_end=925 + _QUANTIZATIONCOMPONENTSPEC_TENSORTYPE._serialized_start=927 + _QUANTIZATIONCOMPONENTSPEC_TENSORTYPE._serialized_end=1034 + _UNITWISEQUANTIZATIONSPEC._serialized_start=1037 + _UNITWISEQUANTIZATIONSPEC._serialized_end=1300 + _UNITWISEQUANTIZATIONSPEC_QUANTIZATIONUNIT._serialized_start=1221 + _UNITWISEQUANTIZATIONSPEC_QUANTIZATIONUNIT._serialized_end=1294 + _REPRESENTATIVEDATASAMPLE._serialized_start=1303 + _REPRESENTATIVEDATASAMPLE._serialized_end=1515 + _REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY._serialized_start=1434 + _REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY._serialized_end=1515 + _REPRESENTATIVEDATASETFILE._serialized_start=1517 + _REPRESENTATIVEDATASETFILE._serialized_end=1590 + _QUANTIZATIONOPTIONS._serialized_start=1593 + _QUANTIZATIONOPTIONS._serialized_end=2563 + 
_QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY._serialized_start=2383 + _QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY._serialized_end=2496 +# @@protoc_insertion_point(module_scope) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec83ebf0f4bc8dc3799b460835f1602b93ddf81b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/stablehlo.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/stablehlo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..478c165ec7540ecc2334d5a9b0afb548ec8aa3e7 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/__pycache__/stablehlo.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo.py new file mode 100644 index 0000000000000000000000000000000000000000..64c3f1b7be30734d5066a0bcc3f372edaf39dad9 --- /dev/null 
+++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo.py @@ -0,0 +1,26 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""StableHLO Portable Python APIs. + +This setup only exports the the StableHLO Portable C++ APIs, which have +signatures that do not rely on MLIR classes. + +Exporting all of MLIR Python bindings to TF OSS has high maintenance +implications, especially given the frequency that TF updates the revision of +LLVM used. 
+""" + +# pylint: disable=wildcard-import +from .stablehlo_extension import * diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo_extension.so b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo_extension.so new file mode 100644 index 0000000000000000000000000000000000000000..fd4aef69caafb3846ac8abd31cf5327c88f95931 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/stablehlo/stablehlo_extension.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56d69c76b5706bf8e9d77ffca1fe5d719bdc0018b36dfee10d8afb2dce3b4180 +size 28852296 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5dc333f7f5fc486f3b88752db2f6e644534c139 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe4682a9e192a1303afe234606d777f2b3fa717 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi new file mode 100644 index 0000000000000000000000000000000000000000..ec5eaad7983bf005e559f489346b46b7a3a8608e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.pyi @@ -0,0 +1,30 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# LINT.IfChange(savedmodel_to_stablehlo) +def savedmodel_to_stablehlo( + input_path: str, + exported_model_signatures: list[str] = ["serving_default"], + tag_names: list[str] = ["serve"], + input_arg_shapes_str: str = "", +) -> bytes: ... +# LINT.ThenChange() + +# LINT.IfChange(tensorflow_module_to_stablehlo) +def tensorflow_module_to_stablehlo( + module: str, + input_arg_shapes_str: str = "", +) -> bytes: ... +# LINT.ThenChange() diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.so b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.so new file mode 100644 index 0000000000000000000000000000000000000000..ea714b6c30fa63f9ec260c3e754704236f6e6492 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/pywrap_tensorflow_to_stablehlo.so differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..059e07756918bdcefd085ecc3da206227c506aba Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563bf97920f25027e00a1ade544f203b2d60b02f Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/gen_trt_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/gen_trt_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa2ad7574b7b4f3874804b39ae044665931fec9 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/__pycache__/gen_trt_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/gen_trt_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/gen_trt_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c81575c632181b03299a44abf4ba4e6b717c8321 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2tensorrt/ops/gen_trt_ops.py @@ -0,0 +1,23 @@ +"""Python wrappers around TensorFlow ops. + +This file is MACHINE GENERATED! Do not edit. 
+""" + +import collections + +from tensorflow.python import pywrap_tfe as pywrap_tfe +from tensorflow.python.eager import context as _context +from tensorflow.python.eager import core as _core +from tensorflow.python.eager import execute as _execute +from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.security.fuzzing.py import annotation_types as _atypes + +from tensorflow.python.framework import op_def_registry as _op_def_registry +from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.deprecation import deprecated_endpoints +from tensorflow.python.util import dispatch as _dispatch +from tensorflow.python.util.tf_export import tf_export + +from typing import TypeVar, List, Any +from typing_extensions import Annotated diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f62d2d93941f8c3458282db038c9f7c7b5ea62 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/tf2xla_pb2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/tf2xla_pb2.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..66a6b850162bd326554fe55db0b6db63a727e5eb Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/__pycache__/tf2xla_pb2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..254c8f2a268282a83367bdb06bf833a1d3aafed3 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/gen_xla_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/gen_xla_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36864b7aaa0c55c8b2d757064afbabd7e2a23ce7 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/__pycache__/gen_xla_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/_xla_ops.so b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/_xla_ops.so new file mode 100644 index 
0000000000000000000000000000000000000000..e94878b0a23723c36b3481b1cea25691ad6703ea --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/_xla_ops.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f5d3759bfef2be9134f4a0d0ea0fc3f3ad139c6172893b62febe1155a2570d7d +size 8233648 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/gen_xla_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/gen_xla_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..f363b162f33959a465335dacde370a3884759782 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/ops/gen_xla_ops.py @@ -0,0 +1,4855 @@ +"""Python wrappers around TensorFlow ops. + +This file is MACHINE GENERATED! Do not edit. +""" + +import collections + +from tensorflow.python import pywrap_tfe as pywrap_tfe +from tensorflow.python.eager import context as _context +from tensorflow.python.eager import core as _core +from tensorflow.python.eager import execute as _execute +from tensorflow.python.framework import dtypes as _dtypes +from tensorflow.security.fuzzing.py import annotation_types as _atypes + +from tensorflow.python.framework import op_def_registry as _op_def_registry +from tensorflow.python.framework import ops as _ops +from tensorflow.python.framework import op_def_library as _op_def_library +from tensorflow.python.util.deprecation import deprecated_endpoints +from tensorflow.python.util import dispatch as _dispatch +from tensorflow.python.util.tf_export import tf_export + +from typing import TypeVar, List, Any +from typing_extensions import Annotated + +TV_XlaAllReduce_T = TypeVar("TV_XlaAllReduce_T", _atypes.BFloat16, _atypes.Float32, _atypes.Half, _atypes.Int32, _atypes.UInt32) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher 
+@tf_export('xla_all_reduce') +def xla_all_reduce(input: Annotated[Any, TV_XlaAllReduce_T], group_assignment: Annotated[Any, _atypes.Int32], reduce_op: str, mode: str, name=None) -> Annotated[Any, TV_XlaAllReduce_T]: + r"""Wraps the XLA AllReduce operator + + documented at https://www.tensorflow.org/xla/operation_semantics#allreduce. + + Args: + input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`, `float32`, `int32`, `uint32`. + Array or a non-empty tuple of arrays to reduce across replicas. + group_assignment: A `Tensor` of type `int32`. + Groups between which the reductions are performed. + reduce_op: A `string` from: `"Min", "Max", "Mul", "Add", "Mean"`. + Reduction computation. + mode: A `string` from: `"CrossReplica", "CrossReplicaAndPartition"`. + group mode. + CrossReplica: group_assignment contains replica_id. Each group contains the + replicas for the current partition. + CrossReplicaAndPartition: group_assignment contains replica_id. Each group + contains the replicas for all partitions. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaAllReduce", name, input, group_assignment, "reduce_op", + reduce_op, "mode", mode) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_all_reduce( + (input, group_assignment, reduce_op, mode, name,), None) + if _result is not NotImplemented: + return _result + return xla_all_reduce_eager_fallback( + input, group_assignment, reduce_op=reduce_op, mode=mode, name=name, + ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_all_reduce, (), dict(input=input, + group_assignment=group_assignment, + reduce_op=reduce_op, mode=mode, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_all_reduce( + (input, group_assignment, reduce_op, mode, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + reduce_op = _execute.make_str(reduce_op, "reduce_op") + mode = _execute.make_str(mode, "mode") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaAllReduce", input=input, group_assignment=group_assignment, + reduce_op=reduce_op, mode=mode, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_all_reduce, (), dict(input=input, + group_assignment=group_assignment, + reduce_op=reduce_op, mode=mode, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "reduce_op", + _op.get_attr("reduce_op"), "mode", _op.get_attr("mode")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaAllReduce", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaAllReduce = tf_export("raw_ops.XlaAllReduce")(_ops.to_raw_op(xla_all_reduce)) +_dispatcher_for_xla_all_reduce = xla_all_reduce._tf_type_based_dispatcher.Dispatch + + +def xla_all_reduce_eager_fallback(input: Annotated[Any, TV_XlaAllReduce_T], group_assignment: Annotated[Any, _atypes.Int32], reduce_op: str, mode: str, name, ctx) -> Annotated[Any, TV_XlaAllReduce_T]: + reduce_op = _execute.make_str(reduce_op, "reduce_op") + mode = _execute.make_str(mode, "mode") + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.half, _dtypes.bfloat16, _dtypes.float32, _dtypes.int32, _dtypes.uint32, ]) + group_assignment = 
_ops.convert_to_tensor(group_assignment, _dtypes.int32) + _inputs_flat = [input, group_assignment] + _attrs = ("T", _attr_T, "reduce_op", reduce_op, "mode", mode) + _result = _execute.execute(b"XlaAllReduce", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaAllReduce", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +_XlaBroadcastHelperOutput = collections.namedtuple( + "XlaBroadcastHelper", + ["lhs_output", "rhs_output"]) + + +TV_XlaBroadcastHelper_T = TypeVar("TV_XlaBroadcastHelper_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaBroadcastHelper_Tindices = TypeVar("TV_XlaBroadcastHelper_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_broadcast_helper') +def xla_broadcast_helper(lhs: Annotated[Any, TV_XlaBroadcastHelper_T], rhs: Annotated[Any, TV_XlaBroadcastHelper_T], broadcast_dims: Annotated[Any, TV_XlaBroadcastHelper_Tindices], name=None): + r"""Helper operator for performing XLA-style broadcasts + + Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to + whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules + for binary operators. + + Args: + lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the LHS input tensor + rhs: A `Tensor`. Must have the same type as `lhs`. the RHS input tensor + broadcast_dims: A `Tensor`. 
Must be one of the following types: `int32`, `int64`. + an XLA-style broadcast dimension specification + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (lhs_output, rhs_output). + + lhs_output: A `Tensor`. Has the same type as `lhs`. the broadcasted LHS tensor + rhs_output: A `Tensor`. Has the same type as `lhs`. the broadcasted RHS tensor + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaBroadcastHelper", name, lhs, rhs, broadcast_dims) + _result = _XlaBroadcastHelperOutput._make(_result) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_broadcast_helper( + (lhs, rhs, broadcast_dims, name,), None) + if _result is not NotImplemented: + return _result + return xla_broadcast_helper_eager_fallback( + lhs, rhs, broadcast_dims, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_broadcast_helper, (), dict(lhs=lhs, rhs=rhs, + broadcast_dims=broadcast_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_broadcast_helper( + (lhs, rhs, broadcast_dims, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaBroadcastHelper", lhs=lhs, rhs=rhs, broadcast_dims=broadcast_dims, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_broadcast_helper, (), dict(lhs=lhs, rhs=rhs, + broadcast_dims=broadcast_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaBroadcastHelper", _inputs_flat, _attrs, _result) + _result = _XlaBroadcastHelperOutput._make(_result) + return _result + +XlaBroadcastHelper = tf_export("raw_ops.XlaBroadcastHelper")(_ops.to_raw_op(xla_broadcast_helper)) +_dispatcher_for_xla_broadcast_helper = xla_broadcast_helper._tf_type_based_dispatcher.Dispatch + + +def xla_broadcast_helper_eager_fallback(lhs: Annotated[Any, TV_XlaBroadcastHelper_T], rhs: Annotated[Any, TV_XlaBroadcastHelper_T], broadcast_dims: Annotated[Any, TV_XlaBroadcastHelper_Tindices], name, ctx): + _attr_T, _inputs_T = _execute.args_to_matching_eager([lhs, rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + (lhs, rhs) = _inputs_T + _attr_Tindices, (broadcast_dims,) = _execute.args_to_matching_eager([broadcast_dims], ctx, [_dtypes.int32, _dtypes.int64, ]) + _inputs_flat = [lhs, rhs, broadcast_dims] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices) + _result = _execute.execute(b"XlaBroadcastHelper", 2, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaBroadcastHelper", _inputs_flat, _attrs, 
_result) + _result = _XlaBroadcastHelperOutput._make(_result) + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_call_module') +def xla_call_module(args, version: int, module: str, Sout, Tout, dim_args_spec=[], platforms=[], function_list=[], has_token_input_output:bool=False, disabled_checks=[], name=None): + r"""Invokes a StableHLO module. + + This op is used with JAX native serialization in a TensorFlow context with + stability guarantees. + + Args: + args: A list of `Tensor` objects. + A list of `Tensor` with possibly different types to be passed as arguments + to the `module`. These are the actual arguments and do not include the + platform argument (see `platforms`) nor the dimension arguments (see + `dim_args_spec`). + version: An `int`. + Tracks changes the semantics of the op, to support backwards + compatibility. Minimum supported version is 2. From + version 2, the op carries a StableHLO text or bytecode `module`. From + version 3, the op also supports the `platforms` attribute. From version 4, + the op carries a StableHLO module with compatibility guarantees. From version + 5, XLACallModule can include `stablehlo.custom_call` op to execute tf + functions. From version 6 the op supports the `disabled_checks` attribute. + See more versioning details at https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code. + module: A `string`. + A serialized computation, a text or bytecode representation of + an mlir.Module. The return type must be a tuple if and only if the `Sout` is + a list with 0 or more than 1 elements. The length of `Tout` and + `Sout` must match. This op always returns a tuple of results, even if the + module returns a single result. + Sout: A list of shapes (each a `tf.TensorShape` or list of `ints`). + List of output tensor shapes. + Tout: A list of `tf.DTypes`. List of output tensor data types. 
+ dim_args_spec: An optional list of `strings`. Defaults to `[]`. + this attribute is not supported anymore. + platforms: An optional list of `strings`. Defaults to `[]`. + the list of platforms supported by `module`. The list can contain + the strings "CPU", "CUDA", "ROCM", or "TPU". It is an error to compile + this op for a platform that does not appear in the list. This check can be + disabled using `disabled_checks`. If the list contains more than + one platform, then the `module` takes one additional 0-dimensional + integer-tensor parameter in the first position, encoding the index in + `platforms` of the current compilation platform. This parameter has value 0 + if the plaform is not among `platforms` and the check has been disabled. + The list can be empty in old versions (earlier than 6) to denote that no + platform checking must be performed at loading time. + function_list: An optional list of functions decorated with @Defun. Defaults to `[]`. + This list contains the TensorFlow FunctionDefs that are used by + the XLACallModule. If the XLACallModule contains `stablehlo.custom_call` + operations, they can call TensorFlow graph functions outside of the + XLACallModule. This `function_list` attribute registers the dependency of the + XLACallModule on those functions. This attribute was added in version 5. + has_token_input_output: An optional `bool`. Defaults to `False`. + If true, the embedded StableHLO module's main function + must take a `!stablehlo.token` as its first argument and returns a token as + its first result. This can be used in conjunction with the TF2XLA's side + effect mechanism in order to model side effects. This is used only in versions + prior to version 9. After that, the number and position of tokens among + the arguments and results are obtained from the main function type. This + allows us to support more than one token and not necessarily at the start. + disabled_checks: An optional list of `strings`. Defaults to `[]`. 
+ A list of strings describing the safety checks that were + disabled at serialization time. This attribute was added in version 6. + For more details see + https://github.com/search?q=repo%3Agoogle%2Fjax+path%3Ajax_export+%22class+DisabledSafetyCheck%22&type=code. + This list, supplemented with a comma-separate list of directives specified + using the flag --tf_xla_call_module_disabled_checks, + is used at module loading time to skip the corresponding checks. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects of type `Tout`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaCallModule", name, args, "version", version, "module", + module, "Sout", Sout, "Tout", Tout, "dim_args_spec", dim_args_spec, + "platforms", platforms, "function_list", function_list, + "has_token_input_output", has_token_input_output, "disabled_checks", + disabled_checks) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_call_module( + (args, version, module, Sout, Tout, dim_args_spec, platforms, + function_list, has_token_input_output, disabled_checks, name,), None) + if _result is not NotImplemented: + return _result + return xla_call_module_eager_fallback( + args, version=version, module=module, Sout=Sout, Tout=Tout, + dim_args_spec=dim_args_spec, platforms=platforms, + function_list=function_list, + has_token_input_output=has_token_input_output, + disabled_checks=disabled_checks, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_call_module, (), dict(args=args, version=version, + module=module, Sout=Sout, Tout=Tout, + dim_args_spec=dim_args_spec, + platforms=platforms, + function_list=function_list, + has_token_input_output=has_token_input_output, + disabled_checks=disabled_checks, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_call_module( + (args, version, module, Sout, Tout, dim_args_spec, platforms, + function_list, has_token_input_output, disabled_checks, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + version = _execute.make_int(version, "version") + module = _execute.make_str(module, "module") + if not isinstance(Sout, (list, tuple)): + raise TypeError( + "Expected list for 'Sout' argument to " + "'xla_call_module' Op, not %r." % Sout) + Sout = [_execute.make_shape(_s, "Sout") for _s in Sout] + if not isinstance(Tout, (list, tuple)): + raise TypeError( + "Expected list for 'Tout' argument to " + "'xla_call_module' Op, not %r." % Tout) + Tout = [_execute.make_type(_t, "Tout") for _t in Tout] + if dim_args_spec is None: + dim_args_spec = [] + if not isinstance(dim_args_spec, (list, tuple)): + raise TypeError( + "Expected list for 'dim_args_spec' argument to " + "'xla_call_module' Op, not %r." % dim_args_spec) + dim_args_spec = [_execute.make_str(_s, "dim_args_spec") for _s in dim_args_spec] + if platforms is None: + platforms = [] + if not isinstance(platforms, (list, tuple)): + raise TypeError( + "Expected list for 'platforms' argument to " + "'xla_call_module' Op, not %r." % platforms) + platforms = [_execute.make_str(_s, "platforms") for _s in platforms] + if function_list is None: + function_list = [] + if not isinstance(function_list, (list, tuple)): + raise TypeError( + "Expected list for 'function_list' argument to " + "'xla_call_module' Op, not %r." 
% function_list) + if has_token_input_output is None: + has_token_input_output = False + has_token_input_output = _execute.make_bool(has_token_input_output, "has_token_input_output") + if disabled_checks is None: + disabled_checks = [] + if not isinstance(disabled_checks, (list, tuple)): + raise TypeError( + "Expected list for 'disabled_checks' argument to " + "'xla_call_module' Op, not %r." % disabled_checks) + disabled_checks = [_execute.make_str(_s, "disabled_checks") for _s in disabled_checks] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaCallModule", args=args, version=version, module=module, Sout=Sout, + Tout=Tout, dim_args_spec=dim_args_spec, + platforms=platforms, function_list=function_list, + has_token_input_output=has_token_input_output, + disabled_checks=disabled_checks, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_call_module, (), dict(args=args, version=version, module=module, + Sout=Sout, Tout=Tout, + dim_args_spec=dim_args_spec, + platforms=platforms, + function_list=function_list, + has_token_input_output=has_token_input_output, + disabled_checks=disabled_checks, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if not _result: + return _op + if _execute.must_record_gradient(): + _attrs = ("version", _op._get_attr_int("version"), "module", + _op.get_attr("module"), "Sout", _op.get_attr("Sout"), "Tout", + _op.get_attr("Tout"), "Tin", _op.get_attr("Tin"), + "dim_args_spec", _op.get_attr("dim_args_spec"), "platforms", + _op.get_attr("platforms"), "function_list", + _op.get_attr("function_list"), "has_token_input_output", + _op._get_attr_bool("has_token_input_output"), "disabled_checks", + _op.get_attr("disabled_checks")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaCallModule", _inputs_flat, _attrs, _result) + return _result + +XlaCallModule = 
tf_export("raw_ops.XlaCallModule")(_ops.to_raw_op(xla_call_module)) +_dispatcher_for_xla_call_module = xla_call_module._tf_type_based_dispatcher.Dispatch + + +def xla_call_module_eager_fallback(args, version: int, module: str, Sout, Tout, dim_args_spec, platforms, function_list, has_token_input_output: bool, disabled_checks, name, ctx): + version = _execute.make_int(version, "version") + module = _execute.make_str(module, "module") + if not isinstance(Sout, (list, tuple)): + raise TypeError( + "Expected list for 'Sout' argument to " + "'xla_call_module' Op, not %r." % Sout) + Sout = [_execute.make_shape(_s, "Sout") for _s in Sout] + if not isinstance(Tout, (list, tuple)): + raise TypeError( + "Expected list for 'Tout' argument to " + "'xla_call_module' Op, not %r." % Tout) + Tout = [_execute.make_type(_t, "Tout") for _t in Tout] + if dim_args_spec is None: + dim_args_spec = [] + if not isinstance(dim_args_spec, (list, tuple)): + raise TypeError( + "Expected list for 'dim_args_spec' argument to " + "'xla_call_module' Op, not %r." % dim_args_spec) + dim_args_spec = [_execute.make_str(_s, "dim_args_spec") for _s in dim_args_spec] + if platforms is None: + platforms = [] + if not isinstance(platforms, (list, tuple)): + raise TypeError( + "Expected list for 'platforms' argument to " + "'xla_call_module' Op, not %r." % platforms) + platforms = [_execute.make_str(_s, "platforms") for _s in platforms] + if function_list is None: + function_list = [] + if not isinstance(function_list, (list, tuple)): + raise TypeError( + "Expected list for 'function_list' argument to " + "'xla_call_module' Op, not %r." 
% function_list) + if has_token_input_output is None: + has_token_input_output = False + has_token_input_output = _execute.make_bool(has_token_input_output, "has_token_input_output") + if disabled_checks is None: + disabled_checks = [] + if not isinstance(disabled_checks, (list, tuple)): + raise TypeError( + "Expected list for 'disabled_checks' argument to " + "'xla_call_module' Op, not %r." % disabled_checks) + disabled_checks = [_execute.make_str(_s, "disabled_checks") for _s in disabled_checks] + _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx) + _inputs_flat = list(args) + _attrs = ("version", version, "module", module, "Sout", Sout, "Tout", Tout, + "Tin", _attr_Tin, "dim_args_spec", dim_args_spec, "platforms", platforms, + "function_list", function_list, "has_token_input_output", + has_token_input_output, "disabled_checks", disabled_checks) + _result = _execute.execute(b"XlaCallModule", len(Tout), inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaCallModule", _inputs_flat, _attrs, _result) + return _result + + +TV_XlaConv_T = TypeVar("TV_XlaConv_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaConv_Tindices = TypeVar("TV_XlaConv_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_conv') +def xla_conv(lhs: Annotated[Any, TV_XlaConv_T], rhs: Annotated[Any, TV_XlaConv_T], window_strides: Annotated[Any, TV_XlaConv_Tindices], padding: Annotated[Any, TV_XlaConv_Tindices], lhs_dilation: Annotated[Any, TV_XlaConv_Tindices], rhs_dilation: Annotated[Any, TV_XlaConv_Tindices], feature_group_count: Annotated[Any, 
TV_XlaConv_Tindices], dimension_numbers: str, precision_config: str, name=None) -> Annotated[Any, TV_XlaConv_T]: + r"""Wraps the XLA ConvGeneralDilated operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + . + + Args: + lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the input tensor + rhs: A `Tensor`. Must have the same type as `lhs`. the kernel tensor + window_strides: A `Tensor`. Must be one of the following types: `int32`, `int64`. + the inter-window strides + padding: A `Tensor`. Must have the same type as `window_strides`. + the padding to apply at the start and end of each input dimensions + lhs_dilation: A `Tensor`. Must have the same type as `window_strides`. + dilation to apply between input elements + rhs_dilation: A `Tensor`. Must have the same type as `window_strides`. + dilation to apply between kernel elements + feature_group_count: A `Tensor`. Must have the same type as `window_strides`. + number of feature groups for grouped convolution. + dimension_numbers: A `string`. + a serialized xla::ConvolutionDimensionNumbers proto. + precision_config: A `string`. a serialized xla::PrecisionConfig proto. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `lhs`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaConv", name, lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, feature_group_count, "dimension_numbers", + dimension_numbers, "precision_config", precision_config) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_conv( + (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers, precision_config, name,), + None) + if _result is not NotImplemented: + return _result + return xla_conv_eager_fallback( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_conv, (), dict(lhs=lhs, rhs=rhs, + window_strides=window_strides, padding=padding, + lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_conv( + (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers, precision_config, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaConv", lhs=lhs, rhs=rhs, window_strides=window_strides, + padding=padding, lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_conv, (), dict(lhs=lhs, rhs=rhs, window_strides=window_strides, + padding=padding, lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices"), "dimension_numbers", + _op.get_attr("dimension_numbers"), "precision_config", + _op.get_attr("precision_config")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaConv", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaConv = tf_export("raw_ops.XlaConv")(_ops.to_raw_op(xla_conv)) +_dispatcher_for_xla_conv = xla_conv._tf_type_based_dispatcher.Dispatch + + +def xla_conv_eager_fallback(lhs: Annotated[Any, TV_XlaConv_T], rhs: Annotated[Any, TV_XlaConv_T], window_strides: Annotated[Any, TV_XlaConv_Tindices], padding: Annotated[Any, TV_XlaConv_Tindices], lhs_dilation: Annotated[Any, TV_XlaConv_Tindices], rhs_dilation: Annotated[Any, TV_XlaConv_Tindices], feature_group_count: Annotated[Any, TV_XlaConv_Tindices], dimension_numbers: str, precision_config: str, name, ctx) -> Annotated[Any, TV_XlaConv_T]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + 
precision_config = _execute.make_str(precision_config, "precision_config") + _attr_T, _inputs_T = _execute.args_to_matching_eager([lhs, rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + (lhs, rhs) = _inputs_T + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count], ctx, [_dtypes.int32, _dtypes.int64, ]) + (window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count) = _inputs_Tindices + _inputs_flat = [lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices, "dimension_numbers", + dimension_numbers, "precision_config", precision_config) + _result = _execute.execute(b"XlaConv", 1, inputs=_inputs_flat, attrs=_attrs, + ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaConv", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaConvV2_LhsT = TypeVar("TV_XlaConvV2_LhsT", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaConvV2_RhsT = TypeVar("TV_XlaConvV2_RhsT", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaConvV2_Tindices = 
TypeVar("TV_XlaConvV2_Tindices", _atypes.Int32, _atypes.Int64) +TV_XlaConvV2_preferred_element_type = TypeVar("TV_XlaConvV2_preferred_element_type", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_conv_v2') +def xla_conv_v2(lhs: Annotated[Any, TV_XlaConvV2_LhsT], rhs: Annotated[Any, TV_XlaConvV2_RhsT], window_strides: Annotated[Any, TV_XlaConvV2_Tindices], padding: Annotated[Any, TV_XlaConvV2_Tindices], lhs_dilation: Annotated[Any, TV_XlaConvV2_Tindices], rhs_dilation: Annotated[Any, TV_XlaConvV2_Tindices], feature_group_count: Annotated[Any, TV_XlaConvV2_Tindices], dimension_numbers: str, precision_config: str, preferred_element_type: TV_XlaConvV2_preferred_element_type, batch_group_count:int=1, name=None) -> Annotated[Any, TV_XlaConvV2_preferred_element_type]: + r"""Wraps the XLA ConvGeneralDilated operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution + . + + Args: + lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + input tensor + rhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + kernel tensor + window_strides: A `Tensor`. Must be one of the following types: `int32`, `int64`. + inter-window strides + padding: A `Tensor`. Must have the same type as `window_strides`. 
+ padding to apply at the start and end of each input dimensions + lhs_dilation: A `Tensor`. Must have the same type as `window_strides`. + dilation to apply between input elements + rhs_dilation: A `Tensor`. Must have the same type as `window_strides`. + dilation to apply between kernel elements + feature_group_count: A `Tensor`. Must have the same type as `window_strides`. + number of feature groups for grouped convolution. + dimension_numbers: A `string`. + serialized xla::ConvolutionDimensionNumbers proto. + precision_config: A `string`. serialized xla::PrecisionConfig proto. + preferred_element_type: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64, tf.qint8, tf.quint8, tf.qint32, tf.bfloat16, tf.qint16, tf.quint16, tf.uint16, tf.complex128, tf.half, tf.uint32, tf.uint64`. + type of the tensor. + batch_group_count: An optional `int`. Defaults to `1`. + number of batch groups or grouped filters. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `preferred_element_type`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaConvV2", name, lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, feature_group_count, "dimension_numbers", + dimension_numbers, "precision_config", precision_config, + "preferred_element_type", preferred_element_type, "batch_group_count", + batch_group_count) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_conv_v2( + (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers, precision_config, + preferred_element_type, batch_group_count, name,), None) + if _result is not NotImplemented: + return _result + return xla_conv_v2_eager_fallback( + lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, + batch_group_count=batch_group_count, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_conv_v2, (), dict(lhs=lhs, rhs=rhs, + window_strides=window_strides, + padding=padding, lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, + batch_group_count=batch_group_count, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_conv_v2( + (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, + feature_group_count, dimension_numbers, precision_config, + preferred_element_type, batch_group_count, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type") + if batch_group_count is None: + batch_group_count = 1 + batch_group_count = _execute.make_int(batch_group_count, "batch_group_count") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaConvV2", lhs=lhs, rhs=rhs, window_strides=window_strides, + padding=padding, lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, + batch_group_count=batch_group_count, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_conv_v2, (), dict(lhs=lhs, rhs=rhs, + window_strides=window_strides, + padding=padding, lhs_dilation=lhs_dilation, + rhs_dilation=rhs_dilation, + feature_group_count=feature_group_count, + dimension_numbers=dimension_numbers, + precision_config=precision_config, + 
preferred_element_type=preferred_element_type, + batch_group_count=batch_group_count, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("LhsT", _op._get_attr_type("LhsT"), "RhsT", + _op._get_attr_type("RhsT"), "Tindices", + _op._get_attr_type("Tindices"), "dimension_numbers", + _op.get_attr("dimension_numbers"), "precision_config", + _op.get_attr("precision_config"), "preferred_element_type", + _op._get_attr_type("preferred_element_type"), + "batch_group_count", _op._get_attr_int("batch_group_count")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaConvV2", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaConvV2 = tf_export("raw_ops.XlaConvV2")(_ops.to_raw_op(xla_conv_v2)) +_dispatcher_for_xla_conv_v2 = xla_conv_v2._tf_type_based_dispatcher.Dispatch + + +def xla_conv_v2_eager_fallback(lhs: Annotated[Any, TV_XlaConvV2_LhsT], rhs: Annotated[Any, TV_XlaConvV2_RhsT], window_strides: Annotated[Any, TV_XlaConvV2_Tindices], padding: Annotated[Any, TV_XlaConvV2_Tindices], lhs_dilation: Annotated[Any, TV_XlaConvV2_Tindices], rhs_dilation: Annotated[Any, TV_XlaConvV2_Tindices], feature_group_count: Annotated[Any, TV_XlaConvV2_Tindices], dimension_numbers: str, precision_config: str, preferred_element_type: TV_XlaConvV2_preferred_element_type, batch_group_count: int, name, ctx) -> Annotated[Any, TV_XlaConvV2_preferred_element_type]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type") + if batch_group_count is None: + batch_group_count = 1 + batch_group_count = _execute.make_int(batch_group_count, "batch_group_count") + _attr_LhsT, (lhs,) = _execute.args_to_matching_eager([lhs], ctx, [_dtypes.float32, _dtypes.float64, 
_dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _attr_RhsT, (rhs,) = _execute.args_to_matching_eager([rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count], ctx, [_dtypes.int32, _dtypes.int64, ]) + (window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count) = _inputs_Tindices + _inputs_flat = [lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count] + _attrs = ("LhsT", _attr_LhsT, "RhsT", _attr_RhsT, "Tindices", + _attr_Tindices, "dimension_numbers", dimension_numbers, "precision_config", + precision_config, "preferred_element_type", preferred_element_type, + "batch_group_count", batch_group_count) + _result = _execute.execute(b"XlaConvV2", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaConvV2", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaCustomCall_dtype = TypeVar("TV_XlaCustomCall_dtype", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, 
_atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_custom_call') +def xla_custom_call(args, target_name: str, backend_config: str, dtype: TV_XlaCustomCall_dtype, shape, name=None) -> Annotated[Any, TV_XlaCustomCall_dtype]: + r"""Wraps the XLA CustomCall operator + + documented at https://www.tensorflow.org/xla/operation_semantics#customcall. + + Args: + args: A list of `Tensor` objects. + A list of `Tensor` with possibly different types. + target_name: A `string`. + Name of the function. A call instruction will be emitted which + targets this symbol name. + backend_config: A `string`. + String, used to encode serialized metadata to the backend. + dtype: A `tf.DType`. Output tensor data type. + shape: A `tf.TensorShape` or list of `ints`. Output tensor shape. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `dtype`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaCustomCall", name, args, "target_name", target_name, + "backend_config", backend_config, "dtype", dtype, "shape", shape) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_custom_call( + (args, target_name, backend_config, dtype, shape, name,), None) + if _result is not NotImplemented: + return _result + return xla_custom_call_eager_fallback( + args, target_name=target_name, backend_config=backend_config, + dtype=dtype, shape=shape, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_custom_call, (), dict(args=args, target_name=target_name, + backend_config=backend_config, + dtype=dtype, shape=shape, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_custom_call( + (args, target_name, backend_config, dtype, shape, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + target_name = _execute.make_str(target_name, "target_name") + backend_config = _execute.make_str(backend_config, "backend_config") + dtype = _execute.make_type(dtype, "dtype") + shape = _execute.make_shape(shape, "shape") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaCustomCall", args=args, target_name=target_name, + backend_config=backend_config, dtype=dtype, + shape=shape, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_custom_call, (), dict(args=args, target_name=target_name, + backend_config=backend_config, + dtype=dtype, shape=shape, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("target_name", _op.get_attr("target_name"), "backend_config", + _op.get_attr("backend_config"), "T", _op.get_attr("T"), "dtype", + _op._get_attr_type("dtype"), "shape", _op.get_attr("shape")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaCustomCall", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaCustomCall = tf_export("raw_ops.XlaCustomCall")(_ops.to_raw_op(xla_custom_call)) +_dispatcher_for_xla_custom_call = xla_custom_call._tf_type_based_dispatcher.Dispatch + + +def xla_custom_call_eager_fallback(args, target_name: str, backend_config: str, dtype: TV_XlaCustomCall_dtype, shape, name, ctx) -> Annotated[Any, TV_XlaCustomCall_dtype]: + target_name = 
_execute.make_str(target_name, "target_name") + backend_config = _execute.make_str(backend_config, "backend_config") + dtype = _execute.make_type(dtype, "dtype") + shape = _execute.make_shape(shape, "shape") + _attr_T, args = _execute.convert_to_mixed_eager_tensors(args, ctx) + _inputs_flat = list(args) + _attrs = ("target_name", target_name, "backend_config", backend_config, "T", + _attr_T, "dtype", dtype, "shape", shape) + _result = _execute.execute(b"XlaCustomCall", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaCustomCall", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_custom_call_v2') +def xla_custom_call_v2(operands, call_target_name: str, backend_config: str, has_side_effect: bool, result_dtypes, result_shapes, name=None): + r"""Emits an HLO `CustomCall` operation with multiple outputs. + + As opposed to `XlaCustomCall`, this operation supports multiple outputs. + + See `CustomCall` specification at + https://tensorflow.org/xla/operation_semantics#customcall, + and `mhlo.custom_call` specification at + https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop. + + Args: + operands: A list of `Tensor` objects. + A sequence of tensors with possibly different types. + call_target_name: A `string`. + Name of the user function. The function signature must conform + to version 3 of the API, see `API_VERSION_STATUS_RETURNING_UNIFIED`. All + operands and results assumed to be in the default layout. + backend_config: A `string`. + A string that encodes a metadata for the backend. + has_side_effect: A `bool`. + Indicates whether the custom call has side effects. + result_dtypes: A list of `tf.DTypes`. Types of all results. + result_shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`). + Shapes of all results. 
+ name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects of type `result_dtypes`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaCustomCallV2", name, operands, "call_target_name", + call_target_name, "backend_config", backend_config, "has_side_effect", + has_side_effect, "result_dtypes", result_dtypes, "result_shapes", + result_shapes) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_custom_call_v2( + (operands, call_target_name, backend_config, has_side_effect, + result_dtypes, result_shapes, name,), None) + if _result is not NotImplemented: + return _result + return xla_custom_call_v2_eager_fallback( + operands, call_target_name=call_target_name, + backend_config=backend_config, has_side_effect=has_side_effect, + result_dtypes=result_dtypes, result_shapes=result_shapes, name=name, + ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_custom_call_v2, (), dict(operands=operands, + call_target_name=call_target_name, + backend_config=backend_config, + has_side_effect=has_side_effect, + result_dtypes=result_dtypes, + result_shapes=result_shapes, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_custom_call_v2( + (operands, call_target_name, backend_config, has_side_effect, + result_dtypes, result_shapes, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ call_target_name = _execute.make_str(call_target_name, "call_target_name") + backend_config = _execute.make_str(backend_config, "backend_config") + has_side_effect = _execute.make_bool(has_side_effect, "has_side_effect") + if not isinstance(result_dtypes, (list, tuple)): + raise TypeError( + "Expected list for 'result_dtypes' argument to " + "'xla_custom_call_v2' Op, not %r." % result_dtypes) + result_dtypes = [_execute.make_type(_t, "result_dtypes") for _t in result_dtypes] + if not isinstance(result_shapes, (list, tuple)): + raise TypeError( + "Expected list for 'result_shapes' argument to " + "'xla_custom_call_v2' Op, not %r." % result_shapes) + result_shapes = [_execute.make_shape(_s, "result_shapes") for _s in result_shapes] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaCustomCallV2", operands=operands, + call_target_name=call_target_name, + backend_config=backend_config, + has_side_effect=has_side_effect, + result_dtypes=result_dtypes, + result_shapes=result_shapes, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_custom_call_v2, (), dict(operands=operands, + call_target_name=call_target_name, + backend_config=backend_config, + has_side_effect=has_side_effect, + result_dtypes=result_dtypes, + result_shapes=result_shapes, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("call_target_name", _op.get_attr("call_target_name"), + "backend_config", _op.get_attr("backend_config"), + "has_side_effect", _op._get_attr_bool("has_side_effect"), + "operand_dtypes", _op.get_attr("operand_dtypes"), + "result_dtypes", _op.get_attr("result_dtypes"), "result_shapes", + _op.get_attr("result_shapes")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaCustomCallV2", _inputs_flat, _attrs, _result) + return _result + +XlaCustomCallV2 = 
tf_export("raw_ops.XlaCustomCallV2")(_ops.to_raw_op(xla_custom_call_v2)) +_dispatcher_for_xla_custom_call_v2 = xla_custom_call_v2._tf_type_based_dispatcher.Dispatch + + +def xla_custom_call_v2_eager_fallback(operands, call_target_name: str, backend_config: str, has_side_effect: bool, result_dtypes, result_shapes, name, ctx): + call_target_name = _execute.make_str(call_target_name, "call_target_name") + backend_config = _execute.make_str(backend_config, "backend_config") + has_side_effect = _execute.make_bool(has_side_effect, "has_side_effect") + if not isinstance(result_dtypes, (list, tuple)): + raise TypeError( + "Expected list for 'result_dtypes' argument to " + "'xla_custom_call_v2' Op, not %r." % result_dtypes) + result_dtypes = [_execute.make_type(_t, "result_dtypes") for _t in result_dtypes] + if not isinstance(result_shapes, (list, tuple)): + raise TypeError( + "Expected list for 'result_shapes' argument to " + "'xla_custom_call_v2' Op, not %r." % result_shapes) + result_shapes = [_execute.make_shape(_s, "result_shapes") for _s in result_shapes] + _attr_operand_dtypes, operands = _execute.convert_to_mixed_eager_tensors(operands, ctx) + _inputs_flat = list(operands) + _attrs = ("call_target_name", call_target_name, "backend_config", + backend_config, "has_side_effect", has_side_effect, "operand_dtypes", + _attr_operand_dtypes, "result_dtypes", result_dtypes, "result_shapes", + result_shapes) + _result = _execute.execute(b"XlaCustomCallV2", len(result_dtypes), + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaCustomCallV2", _inputs_flat, _attrs, _result) + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_dequantize') +def xla_dequantize(input: Annotated[Any, _atypes.UInt32], min_range: float, max_range: float, mode: str, transpose_output: bool, name=None) -> Annotated[Any, _atypes.BFloat16]: + r"""Takes the packed 
uint32 input and unpacks the input to uint8 to do + + Dequantization on device. + + Args: + input: A `Tensor` of type `uint32`. + Input tensors whose types is uint32, shape is [d0, ..., dn]. + min_range: A `float`. + The minimum scalar value possibly produced for the input. + max_range: A `float`. + The maximum scalar value possibly produced for the input. + mode: A `string`. + String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}. + transpose_output: A `bool`. + Boolean to determine if output is transposed. transpose_output + is faster when input is large and rank of input is higher than 1. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `bfloat16`. + Output tensors whose types is bfloat16. If transpose_output is true, + output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output + is false, output shape is [d0,..., dn * 4]. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaDequantize", name, input, "min_range", min_range, + "max_range", max_range, "mode", mode, "transpose_output", + transpose_output) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_dequantize( + (input, min_range, max_range, mode, transpose_output, name,), None) + if _result is not NotImplemented: + return _result + return xla_dequantize_eager_fallback( + input, min_range=min_range, max_range=max_range, mode=mode, + transpose_output=transpose_output, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dequantize, (), dict(input=input, min_range=min_range, + max_range=max_range, mode=mode, + transpose_output=transpose_output, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_dequantize( + (input, min_range, max_range, mode, transpose_output, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + min_range = _execute.make_float(min_range, "min_range") + max_range = _execute.make_float(max_range, "max_range") + mode = _execute.make_str(mode, "mode") + transpose_output = _execute.make_bool(transpose_output, "transpose_output") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaDequantize", input=input, min_range=min_range, + max_range=max_range, mode=mode, + transpose_output=transpose_output, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dequantize, (), dict(input=input, min_range=min_range, + max_range=max_range, mode=mode, + transpose_output=transpose_output, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("min_range", _op.get_attr("min_range"), "max_range", + _op.get_attr("max_range"), "mode", _op.get_attr("mode"), + "transpose_output", _op._get_attr_bool("transpose_output")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaDequantize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaDequantize = tf_export("raw_ops.XlaDequantize")(_ops.to_raw_op(xla_dequantize)) +_dispatcher_for_xla_dequantize = xla_dequantize._tf_type_based_dispatcher.Dispatch + + +def xla_dequantize_eager_fallback(input: Annotated[Any, _atypes.UInt32], min_range: float, max_range: float, mode: str, transpose_output: bool, name, ctx) -> Annotated[Any, 
_atypes.BFloat16]: + min_range = _execute.make_float(min_range, "min_range") + max_range = _execute.make_float(max_range, "max_range") + mode = _execute.make_str(mode, "mode") + transpose_output = _execute.make_bool(transpose_output, "transpose_output") + input = _ops.convert_to_tensor(input, _dtypes.uint32) + _inputs_flat = [input] + _attrs = ("min_range", min_range, "max_range", max_range, "mode", mode, + "transpose_output", transpose_output) + _result = _execute.execute(b"XlaDequantize", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaDequantize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaDot_T = TypeVar("TV_XlaDot_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_dot') +def xla_dot(lhs: Annotated[Any, TV_XlaDot_T], rhs: Annotated[Any, TV_XlaDot_T], dimension_numbers: str, precision_config: str, name=None) -> Annotated[Any, TV_XlaDot_T]: + r"""Wraps the XLA DotGeneral operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral + . + + Args: + lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the LHS tensor + rhs: A `Tensor`. Must have the same type as `lhs`. the RHS tensor + dimension_numbers: A `string`. + a serialized xla::DotDimensionNumbers proto. + precision_config: A `string`. a serialized xla::PrecisionConfig proto. 
+ name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `lhs`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaDot", name, lhs, rhs, "dimension_numbers", + dimension_numbers, "precision_config", precision_config) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_dot( + (lhs, rhs, dimension_numbers, precision_config, name,), None) + if _result is not NotImplemented: + return _result + return xla_dot_eager_fallback( + lhs, rhs, dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dot, (), dict(lhs=lhs, rhs=rhs, + dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_dot( + (lhs, rhs, dimension_numbers, precision_config, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaDot", lhs=lhs, rhs=rhs, dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dot, (), dict(lhs=lhs, rhs=rhs, + dimension_numbers=dimension_numbers, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "dimension_numbers", + _op.get_attr("dimension_numbers"), "precision_config", + _op.get_attr("precision_config")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaDot", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaDot = tf_export("raw_ops.XlaDot")(_ops.to_raw_op(xla_dot)) +_dispatcher_for_xla_dot = xla_dot._tf_type_based_dispatcher.Dispatch + + +def xla_dot_eager_fallback(lhs: Annotated[Any, TV_XlaDot_T], rhs: Annotated[Any, TV_XlaDot_T], dimension_numbers: str, precision_config: str, name, ctx) -> Annotated[Any, TV_XlaDot_T]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + _attr_T, _inputs_T = _execute.args_to_matching_eager([lhs, rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + (lhs, rhs) = _inputs_T + _inputs_flat = [lhs, rhs] + _attrs = ("T", _attr_T, "dimension_numbers", dimension_numbers, + "precision_config", precision_config) + 
_result = _execute.execute(b"XlaDot", 1, inputs=_inputs_flat, attrs=_attrs, + ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaDot", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaDotV2_LhsT = TypeVar("TV_XlaDotV2_LhsT", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaDotV2_RhsT = TypeVar("TV_XlaDotV2_RhsT", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaDotV2_preferred_element_type = TypeVar("TV_XlaDotV2_preferred_element_type", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_dot_v2') +def xla_dot_v2(lhs: Annotated[Any, TV_XlaDotV2_LhsT], rhs: Annotated[Any, TV_XlaDotV2_RhsT], dimension_numbers: str, precision_config: str, preferred_element_type: TV_XlaDotV2_preferred_element_type, name=None) -> Annotated[Any, TV_XlaDotV2_preferred_element_type]: + r"""Wraps the XLA DotGeneral operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral + . + + Args: + lhs: A `Tensor`. 
Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the LHS tensor + rhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the RHS tensor + dimension_numbers: A `string`. + a serialized xla::DotDimensionNumbers proto. + precision_config: A `string`. a serialized xla::PrecisionConfig proto. + preferred_element_type: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64, tf.qint8, tf.quint8, tf.qint32, tf.bfloat16, tf.qint16, tf.quint16, tf.uint16, tf.complex128, tf.half, tf.uint32, tf.uint64`. + The type of the tensor. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `preferred_element_type`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaDotV2", name, lhs, rhs, "dimension_numbers", + dimension_numbers, "precision_config", precision_config, + "preferred_element_type", preferred_element_type) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_dot_v2( + (lhs, rhs, dimension_numbers, precision_config, + preferred_element_type, name,), None) + if _result is not NotImplemented: + return _result + return xla_dot_v2_eager_fallback( + lhs, rhs, dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dot_v2, (), dict(lhs=lhs, rhs=rhs, + dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_dot_v2( + (lhs, rhs, dimension_numbers, precision_config, + preferred_element_type, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaDotV2", lhs=lhs, rhs=rhs, dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dot_v2, (), dict(lhs=lhs, rhs=rhs, + dimension_numbers=dimension_numbers, + precision_config=precision_config, + preferred_element_type=preferred_element_type, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("LhsT", _op._get_attr_type("LhsT"), "RhsT", + _op._get_attr_type("RhsT"), "dimension_numbers", + _op.get_attr("dimension_numbers"), "precision_config", + _op.get_attr("precision_config"), "preferred_element_type", + _op._get_attr_type("preferred_element_type")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaDotV2", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaDotV2 = tf_export("raw_ops.XlaDotV2")(_ops.to_raw_op(xla_dot_v2)) +_dispatcher_for_xla_dot_v2 = 
xla_dot_v2._tf_type_based_dispatcher.Dispatch + + +def xla_dot_v2_eager_fallback(lhs: Annotated[Any, TV_XlaDotV2_LhsT], rhs: Annotated[Any, TV_XlaDotV2_RhsT], dimension_numbers: str, precision_config: str, preferred_element_type: TV_XlaDotV2_preferred_element_type, name, ctx) -> Annotated[Any, TV_XlaDotV2_preferred_element_type]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + precision_config = _execute.make_str(precision_config, "precision_config") + preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type") + _attr_LhsT, (lhs,) = _execute.args_to_matching_eager([lhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _attr_RhsT, (rhs,) = _execute.args_to_matching_eager([rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _inputs_flat = [lhs, rhs] + _attrs = ("LhsT", _attr_LhsT, "RhsT", _attr_RhsT, "dimension_numbers", + dimension_numbers, "precision_config", precision_config, + "preferred_element_type", preferred_element_type) + _result = _execute.execute(b"XlaDotV2", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaDotV2", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaDynamicSlice_T = TypeVar("TV_XlaDynamicSlice_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, 
_atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) +TV_XlaDynamicSlice_Tindices = TypeVar("TV_XlaDynamicSlice_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_dynamic_slice') +def xla_dynamic_slice(input: Annotated[Any, TV_XlaDynamicSlice_T], start_indices: Annotated[Any, TV_XlaDynamicSlice_Tindices], size_indices: Annotated[Any, TV_XlaDynamicSlice_Tindices], name=None) -> Annotated[Any, TV_XlaDynamicSlice_T]: + r"""Wraps the XLA DynamicSlice operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice + . + + DynamicSlice extracts a sub-array from the input array at dynamic + start_indices. The size of the slice in each dimension is passed in + size_indices, which specify the end point of exclusive slice intervals in each + dimension -- [start, start + size). The shape of start_indices must have rank 1, + with dimension size equal to the rank of operand. + + Args: + input: A `Tensor`. A `Tensor` of type T. + start_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + List of N integers containing the slice size for each + dimension. Each value must be strictly greater than zero, and start + size + must be less than or equal to the size of the dimension to avoid + implementation defined behavior. + size_indices: A `Tensor`. Must have the same type as `start_indices`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaDynamicSlice", name, input, start_indices, size_indices) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_dynamic_slice( + (input, start_indices, size_indices, name,), None) + if _result is not NotImplemented: + return _result + return xla_dynamic_slice_eager_fallback( + input, start_indices, size_indices, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dynamic_slice, (), dict(input=input, + start_indices=start_indices, + size_indices=size_indices, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_dynamic_slice( + (input, start_indices, size_indices, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaDynamicSlice", input=input, start_indices=start_indices, + size_indices=size_indices, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dynamic_slice, (), dict(input=input, + start_indices=start_indices, + size_indices=size_indices, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaDynamicSlice", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaDynamicSlice = tf_export("raw_ops.XlaDynamicSlice")(_ops.to_raw_op(xla_dynamic_slice)) +_dispatcher_for_xla_dynamic_slice = xla_dynamic_slice._tf_type_based_dispatcher.Dispatch + + +def xla_dynamic_slice_eager_fallback(input: Annotated[Any, TV_XlaDynamicSlice_T], start_indices: Annotated[Any, TV_XlaDynamicSlice_Tindices], size_indices: Annotated[Any, TV_XlaDynamicSlice_Tindices], name, ctx) -> Annotated[Any, TV_XlaDynamicSlice_T]: + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([start_indices, size_indices], ctx, [_dtypes.int32, _dtypes.int64, ]) + (start_indices, size_indices) = _inputs_Tindices + _inputs_flat = [input, start_indices, size_indices] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices) + _result = _execute.execute(b"XlaDynamicSlice", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaDynamicSlice", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaDynamicUpdateSlice_T = TypeVar("TV_XlaDynamicUpdateSlice_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, 
_atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) +TV_XlaDynamicUpdateSlice_Tindices = TypeVar("TV_XlaDynamicUpdateSlice_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_dynamic_update_slice') +def xla_dynamic_update_slice(input: Annotated[Any, TV_XlaDynamicUpdateSlice_T], update: Annotated[Any, TV_XlaDynamicUpdateSlice_T], indices: Annotated[Any, TV_XlaDynamicUpdateSlice_Tindices], name=None) -> Annotated[Any, TV_XlaDynamicUpdateSlice_T]: + r"""Wraps the XLA DynamicUpdateSlice operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice + . + + XlaDynamicUpdateSlice generates a result which is the value of the `input` + operand, with a slice update overwritten at `indices`. The shape of `update` + determines the shape of the sub-array of the result which is updated. The shape + of indices must be rank == 1, with dimension size equal to the rank of `input`. + + Handling of out-of-bounds slice indices is implementation-defined. + + Args: + input: A `Tensor`. A `Tensor` of type T. + update: A `Tensor`. Must have the same type as `input`. + A `Tensor` of type T. Same rank as `input`. + indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + A vector of indices into `input`. Must have length equal to the rank of + `input`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. A `Tensor` of type T. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaDynamicUpdateSlice", name, input, update, indices) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_dynamic_update_slice( + (input, update, indices, name,), None) + if _result is not NotImplemented: + return _result + return xla_dynamic_update_slice_eager_fallback( + input, update, indices, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dynamic_update_slice, (), dict(input=input, update=update, + indices=indices, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_dynamic_update_slice( + (input, update, indices, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaDynamicUpdateSlice", input=input, update=update, indices=indices, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_dynamic_update_slice, (), dict(input=input, update=update, + indices=indices, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaDynamicUpdateSlice", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaDynamicUpdateSlice = tf_export("raw_ops.XlaDynamicUpdateSlice")(_ops.to_raw_op(xla_dynamic_update_slice)) +_dispatcher_for_xla_dynamic_update_slice = xla_dynamic_update_slice._tf_type_based_dispatcher.Dispatch + + +def xla_dynamic_update_slice_eager_fallback(input: Annotated[Any, TV_XlaDynamicUpdateSlice_T], update: Annotated[Any, TV_XlaDynamicUpdateSlice_T], indices: Annotated[Any, TV_XlaDynamicUpdateSlice_Tindices], name, ctx) -> Annotated[Any, TV_XlaDynamicUpdateSlice_T]: + _attr_T, _inputs_T = _execute.args_to_matching_eager([input, update], ctx, []) + (input, update) = _inputs_T + _attr_Tindices, (indices,) = _execute.args_to_matching_eager([indices], ctx, [_dtypes.int32, _dtypes.int64, ]) + _inputs_flat = [input, update, indices] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices) + _result = _execute.execute(b"XlaDynamicUpdateSlice", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaDynamicUpdateSlice", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaEinsum_T = TypeVar("TV_XlaEinsum_T", _atypes.BFloat16, _atypes.Complex64, _atypes.Float32) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_einsum') +def 
xla_einsum(a: Annotated[Any, TV_XlaEinsum_T], b: Annotated[Any, TV_XlaEinsum_T], equation: str, name=None) -> Annotated[Any, TV_XlaEinsum_T]: + r"""An op which supports basic einsum op with 2 inputs and 1 output. + + This op has better TPU performance since it doesn't have explicitly reshape and + transpose operations as tf.einsum does. + + Args: + a: A `Tensor`. Must be one of the following types: `complex64`, `bfloat16`, `float32`. + b: A `Tensor`. Must have the same type as `a`. + equation: A `string`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `a`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaEinsum", name, a, b, "equation", equation) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_einsum( + (a, b, equation, name,), None) + if _result is not NotImplemented: + return _result + return xla_einsum_eager_fallback( + a, b, equation=equation, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_einsum, (), dict(a=a, b=b, equation=equation, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_einsum( + (a, b, equation, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ equation = _execute.make_str(equation, "equation") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaEinsum", a=a, b=b, equation=equation, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_einsum, (), dict(a=a, b=b, equation=equation, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("equation", _op.get_attr("equation"), "T", + _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaEinsum", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaEinsum = tf_export("raw_ops.XlaEinsum")(_ops.to_raw_op(xla_einsum)) +_dispatcher_for_xla_einsum = xla_einsum._tf_type_based_dispatcher.Dispatch + + +def xla_einsum_eager_fallback(a: Annotated[Any, TV_XlaEinsum_T], b: Annotated[Any, TV_XlaEinsum_T], equation: str, name, ctx) -> Annotated[Any, TV_XlaEinsum_T]: + equation = _execute.make_str(equation, "equation") + _attr_T, _inputs_T = _execute.args_to_matching_eager([a, b], ctx, [_dtypes.complex64, _dtypes.bfloat16, _dtypes.float32, ]) + (a, b) = _inputs_T + _inputs_flat = [a, b] + _attrs = ("equation", equation, "T", _attr_T) + _result = _execute.execute(b"XlaEinsum", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaEinsum", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaGather_T = TypeVar("TV_XlaGather_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaGather_Tindices = TypeVar("TV_XlaGather_Tindices", _atypes.Int32, _atypes.Int64) + 
+@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_gather') +def xla_gather(operand: Annotated[Any, TV_XlaGather_T], start_indices: Annotated[Any, TV_XlaGather_Tindices], slice_sizes: Annotated[Any, TV_XlaGather_Tindices], dimension_numbers: str, indices_are_sorted: bool, name=None) -> Annotated[Any, TV_XlaGather_T]: + r"""Wraps the XLA Gather operator documented at + + https://www.tensorflow.org/xla/operation_semantics#gather + + Args: + operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`. + The array we're gathering from. + start_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + Array containing the starting indices of the slices we gather. + slice_sizes: A `Tensor`. Must have the same type as `start_indices`. + slice_sizes[i] is the bounds for the slice on dimension i. + dimension_numbers: A `string`. + A serialized xla::GatherDimensionNumbers proto. + indices_are_sorted: A `bool`. + Boolean indicating if the indices are sorted. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `operand`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaGather", name, operand, start_indices, slice_sizes, + "dimension_numbers", dimension_numbers, "indices_are_sorted", + indices_are_sorted) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_gather( + (operand, start_indices, slice_sizes, dimension_numbers, + indices_are_sorted, name,), None) + if _result is not NotImplemented: + return _result + return xla_gather_eager_fallback( + operand, start_indices, slice_sizes, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_gather, (), dict(operand=operand, start_indices=start_indices, + slice_sizes=slice_sizes, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_gather( + (operand, start_indices, slice_sizes, dimension_numbers, + indices_are_sorted, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaGather", operand=operand, start_indices=start_indices, + slice_sizes=slice_sizes, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_gather, (), dict(operand=operand, start_indices=start_indices, + slice_sizes=slice_sizes, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("dimension_numbers", _op.get_attr("dimension_numbers"), + "indices_are_sorted", _op._get_attr_bool("indices_are_sorted"), + "T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaGather", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaGather = tf_export("raw_ops.XlaGather")(_ops.to_raw_op(xla_gather)) +_dispatcher_for_xla_gather = xla_gather._tf_type_based_dispatcher.Dispatch + + +def xla_gather_eager_fallback(operand: Annotated[Any, TV_XlaGather_T], start_indices: Annotated[Any, TV_XlaGather_Tindices], slice_sizes: Annotated[Any, TV_XlaGather_Tindices], dimension_numbers: str, indices_are_sorted: bool, name, ctx) -> Annotated[Any, TV_XlaGather_T]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted") + _attr_T, (operand,) = _execute.args_to_matching_eager([operand], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, 
_dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ]) + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([start_indices, slice_sizes], ctx, [_dtypes.int32, _dtypes.int64, ]) + (start_indices, slice_sizes) = _inputs_Tindices + _inputs_flat = [operand, start_indices, slice_sizes] + _attrs = ("dimension_numbers", dimension_numbers, "indices_are_sorted", + indices_are_sorted, "T", _attr_T, "Tindices", _attr_Tindices) + _result = _execute.execute(b"XlaGather", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaGather", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaIf_Tcond = TypeVar("TV_XlaIf_Tcond", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_if') +def xla_if(cond: Annotated[Any, TV_XlaIf_Tcond], inputs, then_branch, else_branch, Tout, name=None): + r"""output = cond ? then_branch(inputs) : else_branch(inputs). + + Args: + cond: A `Tensor`. A boolean scalar. + inputs: A list of `Tensor` objects. A list of input tensors. + then_branch: A function decorated with @Defun. + A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. + else_branch: A function decorated with @Defun. + A function takes 'inputs' and returns a list of tensors. 
+ whose types are the same as what then_branch returns. + Tout: A list of `tf.DTypes`. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects of type `Tout`. + A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaIf", name, cond, inputs, "then_branch", then_branch, + "else_branch", else_branch, "Tout", Tout) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_if( + (cond, inputs, then_branch, else_branch, Tout, name,), None) + if _result is not NotImplemented: + return _result + return xla_if_eager_fallback( + cond, inputs, then_branch=then_branch, else_branch=else_branch, + Tout=Tout, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_if, (), dict(cond=cond, inputs=inputs, + then_branch=then_branch, else_branch=else_branch, + Tout=Tout, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_if( + (cond, inputs, then_branch, else_branch, Tout, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + if not isinstance(Tout, (list, tuple)): + raise TypeError( + "Expected list for 'Tout' argument to " + "'xla_if' Op, not %r." 
% Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaIf", cond=cond, inputs=inputs, then_branch=then_branch,
                 else_branch=else_branch, Tout=Tout, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_if, (), dict(cond=cond, inputs=inputs,
                           then_branch=then_branch, else_branch=else_branch,
                           Tout=Tout, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if not _result:
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "then_branch",
              _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "Tin", _op.get_attr("Tin"), "Tout",
              _op.get_attr("Tout"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaIf", _inputs_flat, _attrs, _result)
  return _result

XlaIf = tf_export("raw_ops.XlaIf")(_ops.to_raw_op(xla_if))
_dispatcher_for_xla_if = xla_if._tf_type_based_dispatcher.Dispatch


def xla_if_eager_fallback(cond: Annotated[Any, TV_XlaIf_Tcond], inputs, then_branch, else_branch, Tout, name, ctx):
  # NOTE(review): this file appears to be machine-generated TensorFlow op
  # wrapper code (gen-op pattern: TFE_Py_FastPathExecute fast path plus
  # _apply_op_helper graph fallback) vendored inside a committed virtualenv.
  # Do not hand-edit — changes would be lost on regeneration or reinstall.
  # TODO confirm provenance against the file's generator header.
  #
  # Eager-mode fallback for XlaIf: validates and normalizes the `Tout` type
  # list, coerces `cond` and `inputs` to eager tensors, then executes the op
  # directly via _execute.execute.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'xla_if' Op, not %r."
        % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _attr_Tcond, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
  _attr_Tin, inputs = _execute.convert_to_mixed_eager_tensors(inputs, ctx)
  _inputs_flat = [cond] + list(inputs)
  _attrs = ("Tcond", _attr_Tcond, "then_branch", then_branch, "else_branch",
  else_branch, "Tin", _attr_Tin, "Tout", Tout)
  _result = _execute.execute(b"XlaIf", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaIf", _inputs_flat, _attrs, _result)
  return _result

# Output alias for the two-result XlaKeyValueSort op (sorted_keys, sorted_values).
_XlaKeyValueSortOutput = collections.namedtuple(
    "XlaKeyValueSort",
    ["sorted_keys", "sorted_values"])


TV_XlaKeyValueSort_K = TypeVar("TV_XlaKeyValueSort_K", _atypes.BFloat16, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8)
TV_XlaKeyValueSort_V = TypeVar("TV_XlaKeyValueSort_V", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_key_value_sort')
def xla_key_value_sort(keys: Annotated[Any, TV_XlaKeyValueSort_K], values: Annotated[Any, TV_XlaKeyValueSort_V], name=None):
  r"""Wraps the XLA Sort operator, documented at

   https://www.tensorflow.org/performance/xla/operation_semantics#sort
  .

  Sorts a tensor. Currently only sorts in ascending order are supported.

  Args:
    keys: A `Tensor`.
Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`.
      A `Tensor` of type K.
    values: A `Tensor`. A `Tensor` of type V.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (sorted_keys, sorted_values).

    sorted_keys: A `Tensor`. Has the same type as `keys`. A `Tensor` of type K.
    sorted_values: A `Tensor`. Has the same type as `values`. A `Tensor` of type V.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first, then type-based dispatch, then
    # the Python eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaKeyValueSort", name, keys, values)
      _result = _XlaKeyValueSortOutput._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_key_value_sort(
          (keys, values, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_key_value_sort_eager_fallback(
          keys, values, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_key_value_sort, (), dict(keys=keys, values=values, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_key_value_sort(
        (keys, values, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaKeyValueSort", keys=keys, values=values, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_key_value_sort, (), dict(keys=keys, values=values, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("K", _op._get_attr_type("K"), "V", _op._get_attr_type("V"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaKeyValueSort", _inputs_flat, _attrs, _result)
  _result = _XlaKeyValueSortOutput._make(_result)
  return _result

XlaKeyValueSort = tf_export("raw_ops.XlaKeyValueSort")(_ops.to_raw_op(xla_key_value_sort))
_dispatcher_for_xla_key_value_sort = xla_key_value_sort._tf_type_based_dispatcher.Dispatch


def xla_key_value_sort_eager_fallback(keys: Annotated[Any, TV_XlaKeyValueSort_K], values: Annotated[Any, TV_XlaKeyValueSort_V], name, ctx):
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  # Eager fallback: resolves K/V dtype attrs and executes XlaKeyValueSort
  # directly, wrapping the two outputs in the namedtuple result.
  _attr_K, (keys,) = _execute.args_to_matching_eager([keys], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.int64, _dtypes.bfloat16, _dtypes.uint16, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  _attr_V, (values,) = _execute.args_to_matching_eager([values], ctx, [])
  _inputs_flat = [keys, values]
  _attrs = ("K", _attr_K, "V", _attr_V)
  _result = _execute.execute(b"XlaKeyValueSort", 2, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaKeyValueSort", _inputs_flat, _attrs, _result)
  _result = _XlaKeyValueSortOutput._make(_result)
  return _result


@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_optimization_barrier')
def xla_optimization_barrier(input, name=None):
  r"""Wraps the XLA OptimizationBarrier operator.

  Documented at https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier.

  Args:
    input: A list of `Tensor` objects. A Tuple of Arrays of any type.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaOptimizationBarrier", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_optimization_barrier(
          (input, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_optimization_barrier_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_optimization_barrier, (), dict(input=input, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_optimization_barrier(
        (input, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaOptimizationBarrier", input=input, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_optimization_barrier, (), dict(input=input, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op.get_attr("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaOptimizationBarrier", _inputs_flat, _attrs, _result)
  return _result

XlaOptimizationBarrier = tf_export("raw_ops.XlaOptimizationBarrier")(_ops.to_raw_op(xla_optimization_barrier))
_dispatcher_for_xla_optimization_barrier = xla_optimization_barrier._tf_type_based_dispatcher.Dispatch


def xla_optimization_barrier_eager_fallback(input, name, ctx):
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  # Eager fallback: converts the heterogeneous input list to eager tensors
  # and executes XlaOptimizationBarrier directly (one output per input).
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = list(input)
  _attrs = ("T", _attr_T)
  _result = _execute.execute(b"XlaOptimizationBarrier", len(input),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaOptimizationBarrier", _inputs_flat, _attrs, _result)
  return _result


TV_XlaPad_T = TypeVar("TV_XlaPad_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)
TV_XlaPad_Tindices = TypeVar("TV_XlaPad_Tindices", _atypes.Int32, _atypes.Int64)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_pad')
def xla_pad(input: Annotated[Any, TV_XlaPad_T], padding_value:
Annotated[Any, TV_XlaPad_T], padding_low: Annotated[Any, TV_XlaPad_Tindices], padding_high: Annotated[Any, TV_XlaPad_Tindices], padding_interior: Annotated[Any, TV_XlaPad_Tindices], name=None) -> Annotated[Any, TV_XlaPad_T]:
  r"""Wraps the XLA Pad operator, documented at

   https://www.tensorflow.org/performance/xla/operation_semantics#pad
  .

  Args:
    input: A `Tensor`. A `Tensor` of type T.
    padding_value: A `Tensor`. Must have the same type as `input`.
      A scalar `Tensor` of type T.
    padding_low: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      the padding to apply at the start of each input dimensions. Must
      be a compile-time constant 1D tensor of length equal to rank of input.
    padding_high: A `Tensor`. Must have the same type as `padding_low`.
      the padding to apply at the end of each input dimension. Must
      be a compile-time constant 1D tensor of length equal to rank of input.
    padding_interior: A `Tensor`. Must have the same type as `padding_low`.
      the padding to apply between each input element. Must
      be a compile-time constant 1D tensor of length equal to rank of input,
      containing only non-negative values.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`. A `Tensor` of type T.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaPad", name, input, padding_value, padding_low, padding_high,
        padding_interior)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_pad(
          (input, padding_value, padding_low, padding_high, padding_interior,
          name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_pad_eager_fallback(
          input, padding_value, padding_low, padding_high, padding_interior,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_pad, (), dict(input=input, padding_value=padding_value,
                              padding_low=padding_low,
                              padding_high=padding_high,
                              padding_interior=padding_interior, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_pad(
        (input, padding_value, padding_low, padding_high, padding_interior,
        name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaPad", input=input, padding_value=padding_value,
                  padding_low=padding_low, padding_high=padding_high,
                  padding_interior=padding_interior, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_pad, (), dict(input=input, padding_value=padding_value,
                            padding_low=padding_low,
                            padding_high=padding_high,
                            padding_interior=padding_interior, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaPad", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

XlaPad = tf_export("raw_ops.XlaPad")(_ops.to_raw_op(xla_pad))
_dispatcher_for_xla_pad = xla_pad._tf_type_based_dispatcher.Dispatch


def xla_pad_eager_fallback(input: Annotated[Any, TV_XlaPad_T], padding_value: Annotated[Any, TV_XlaPad_T], padding_low: Annotated[Any, TV_XlaPad_Tindices], padding_high: Annotated[Any, TV_XlaPad_Tindices], padding_interior: Annotated[Any, TV_XlaPad_Tindices], name, ctx) -> Annotated[Any, TV_XlaPad_T]:
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  # Eager fallback: resolves T and Tindices attrs, then executes XlaPad.
  _attr_T, _inputs_T = _execute.args_to_matching_eager([input, padding_value], ctx, [])
  (input, padding_value) = _inputs_T
  _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([padding_low, padding_high, padding_interior], ctx, [_dtypes.int32, _dtypes.int64, ])
  (padding_low, padding_high, padding_interior) = _inputs_Tindices
  _inputs_flat = [input, padding_value, padding_low, padding_high, padding_interior]
  _attrs = ("T", _attr_T, "Tindices", _attr_Tindices)
  _result = _execute.execute(b"XlaPad", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaPad", _inputs_flat, _attrs, _result)
  _result, = _result
  return
_result


TV_XlaRecv_dtype = TypeVar("TV_XlaRecv_dtype", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_recv')
def xla_recv(dtype: TV_XlaRecv_dtype, tensor_name: str, shape, name=None) -> Annotated[Any, TV_XlaRecv_dtype]:
  r"""Receives the named tensor from another XLA computation. Wraps the XLA Recv

  operator documented at
   https://www.tensorflow.org/performance/xla/operation_semantics#recv .

  Args:
    dtype: A `tf.DType`. The type of the tensor.
    tensor_name: A `string`. A string key that identifies the channel.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`. The tensor to receive.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaRecv", name, "dtype", dtype, "tensor_name", tensor_name,
        "shape", shape)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_recv(
          (dtype, tensor_name, shape, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_recv_eager_fallback(
          dtype=dtype, tensor_name=tensor_name, shape=shape, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_recv, (), dict(dtype=dtype, tensor_name=tensor_name,
                               shape=shape, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_recv(
        (dtype, tensor_name, shape, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  dtype = _execute.make_type(dtype, "dtype")
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  shape = _execute.make_shape(shape, "shape")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaRecv", dtype=dtype, tensor_name=tensor_name, shape=shape,
                   name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_recv, (), dict(dtype=dtype, tensor_name=tensor_name,
                             shape=shape, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("dtype", _op._get_attr_type("dtype"), "tensor_name",
              _op.get_attr("tensor_name"), "shape", _op.get_attr("shape"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

XlaRecv = tf_export("raw_ops.XlaRecv")(_ops.to_raw_op(xla_recv))
_dispatcher_for_xla_recv = xla_recv._tf_type_based_dispatcher.Dispatch


def xla_recv_eager_fallback(dtype: TV_XlaRecv_dtype, tensor_name: str, shape, name, ctx) -> Annotated[Any, TV_XlaRecv_dtype]:
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  # Eager fallback: XlaRecv takes no tensor inputs; all configuration is
  # passed via the dtype / tensor_name / shape attrs.
  dtype = _execute.make_type(dtype, "dtype")
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  shape = _execute.make_shape(shape, "shape")
  _inputs_flat = []
  _attrs = ("dtype", dtype, "tensor_name", tensor_name, "shape", shape)
  _result = _execute.execute(b"XlaRecv", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaRecv", _inputs_flat, _attrs,
_result)
  _result, = _result
  return _result


TV_XlaReduce_T = TypeVar("TV_XlaReduce_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce')
def xla_reduce(input: Annotated[Any, TV_XlaReduce_T], init_value: Annotated[Any, TV_XlaReduce_T], dimensions_to_reduce, reducer, name=None) -> Annotated[Any, TV_XlaReduce_T]:
  r"""Wraps the XLA Reduce operator, documented at

   https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`.
      the input tensor
    init_value: A `Tensor`. Must have the same type as `input`.
      a scalar representing the initial value for the reduction
    dimensions_to_reduce: A list of `ints`.
      dimension numbers over which to reduce
    reducer: A function decorated with @Defun. a reducer function to apply
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReduce", name, input, init_value, "dimensions_to_reduce",
        dimensions_to_reduce, "reducer", reducer)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_reduce(
          (input, init_value, dimensions_to_reduce, reducer, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_eager_fallback(
          input, init_value, dimensions_to_reduce=dimensions_to_reduce,
          reducer=reducer, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_reduce, (), dict(input=input, init_value=init_value,
                                 dimensions_to_reduce=dimensions_to_reduce,
                                 reducer=reducer, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_reduce(
        (input, init_value, dimensions_to_reduce, reducer, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_reduce' Op, not %r."
        % dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReduce", input=input, init_value=init_value,
                     dimensions_to_reduce=dimensions_to_reduce,
                     reducer=reducer, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce, (), dict(input=input, init_value=init_value,
                               dimensions_to_reduce=dimensions_to_reduce,
                               reducer=reducer, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "dimensions_to_reduce",
              _op.get_attr("dimensions_to_reduce"), "reducer",
              _op.get_attr("reducer"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

XlaReduce = tf_export("raw_ops.XlaReduce")(_ops.to_raw_op(xla_reduce))
_dispatcher_for_xla_reduce = xla_reduce._tf_type_based_dispatcher.Dispatch


def xla_reduce_eager_fallback(input: Annotated[Any, TV_XlaReduce_T], init_value: Annotated[Any, TV_XlaReduce_T], dimensions_to_reduce, reducer, name, ctx) -> Annotated[Any, TV_XlaReduce_T]:
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_reduce' Op, not %r."
% dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce]
  _attr_T, _inputs_T = _execute.args_to_matching_eager([input, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ])
  (input, init_value) = _inputs_T
  _inputs_flat = [input, init_value]
  _attrs = ("T", _attr_T, "dimensions_to_reduce", dimensions_to_reduce,
  "reducer", reducer)
  _result = _execute.execute(b"XlaReduce", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


TV_XlaReducePrecision_T = TypeVar("TV_XlaReducePrecision_T", _atypes.BFloat16, _atypes.Float32, _atypes.Float64, _atypes.Half)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce_precision')
def xla_reduce_precision(operand: Annotated[Any, TV_XlaReducePrecision_T], exponent_bits: int, mantissa_bits: int, name=None) -> Annotated[Any, TV_XlaReducePrecision_T]:
  r"""Wraps the XLA ReducePrecision operator

  documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.

  Args:
    operand: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.
      array of floating-point type.
    exponent_bits: An `int`. number of exponent bits in lower-precision format
    mantissa_bits: An `int`. number of mantissa bits in lower-precision format
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `operand`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReducePrecision", name, operand, "exponent_bits",
        exponent_bits, "mantissa_bits", mantissa_bits)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_reduce_precision(
          (operand, exponent_bits, mantissa_bits, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_precision_eager_fallback(
          operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_reduce_precision, (), dict(operand=operand,
                                           exponent_bits=exponent_bits,
                                           mantissa_bits=mantissa_bits,
                                           name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_reduce_precision(
        (operand, exponent_bits, mantissa_bits, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  exponent_bits = _execute.make_int(exponent_bits, "exponent_bits")
  mantissa_bits = _execute.make_int(mantissa_bits, "mantissa_bits")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReducePrecision", operand=operand, exponent_bits=exponent_bits,
                              mantissa_bits=mantissa_bits, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce_precision, (), dict(operand=operand,
                                         exponent_bits=exponent_bits,
                                         mantissa_bits=mantissa_bits,
                                         name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "exponent_bits",
              _op._get_attr_int("exponent_bits"), "mantissa_bits",
              _op._get_attr_int("mantissa_bits"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReducePrecision", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

XlaReducePrecision = tf_export("raw_ops.XlaReducePrecision")(_ops.to_raw_op(xla_reduce_precision))
_dispatcher_for_xla_reduce_precision = xla_reduce_precision._tf_type_based_dispatcher.Dispatch


def xla_reduce_precision_eager_fallback(operand: Annotated[Any, TV_XlaReducePrecision_T], exponent_bits: int, mantissa_bits: int, name, ctx) -> Annotated[Any, TV_XlaReducePrecision_T]:
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  exponent_bits = _execute.make_int(exponent_bits, "exponent_bits")
  mantissa_bits = _execute.make_int(mantissa_bits, "mantissa_bits")
  _attr_T, (operand,) = _execute.args_to_matching_eager([operand], ctx, [_dtypes.bfloat16, _dtypes.half, _dtypes.float32, _dtypes.float64, ])
  _inputs_flat = [operand]
  _attrs = ("T", _attr_T, "exponent_bits", exponent_bits, "mantissa_bits",
  mantissa_bits)
  _result = _execute.execute(b"XlaReducePrecision", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReducePrecision", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


TV_XlaReduceScatter_T = TypeVar("TV_XlaReduceScatter_T", _atypes.BFloat16, _atypes.Float32, _atypes.Half, _atypes.Int32, _atypes.UInt32)

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce_scatter')
def xla_reduce_scatter(input: Annotated[Any, TV_XlaReduceScatter_T], group_assignment: Annotated[Any, _atypes.Int32], scatter_dimension: Annotated[Any, _atypes.Int32], reduce_op: str, name=None) -> Annotated[Any, TV_XlaReduceScatter_T]:
  r"""Wraps the XLA ReduceScatter operator

  documented at https://www.tensorflow.org/xla/operation_semantics#reducescatter.

  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`, `float32`, `int32`, `uint32`.
      Array or a non-empty tuple of arrays to reduce across replicas.
    group_assignment: A `Tensor` of type `int32`.
      Groups between which the reductions are performed.
    scatter_dimension: A `Tensor` of type `int32`. Dimension to scatter.
    reduce_op: A `string` from: `"Min", "Max", "Mul", "Add", "Mean"`.
      Reduction computation.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: C fast path, then type-based dispatch, then eager fallback.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReduceScatter", name, input, group_assignment,
        scatter_dimension, "reduce_op", reduce_op)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_xla_reduce_scatter(
          (input, group_assignment, scatter_dimension, reduce_op, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_scatter_eager_fallback(
          input, group_assignment, scatter_dimension, reduce_op=reduce_op,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            xla_reduce_scatter, (), dict(input=input,
                                         group_assignment=group_assignment,
                                         scatter_dimension=scatter_dimension,
                                         reduce_op=reduce_op, name=name)
        )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_reduce_scatter(
        (input, group_assignment, scatter_dimension, reduce_op, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReduceScatter", input=input, group_assignment=group_assignment,
                            scatter_dimension=scatter_dimension,
                            reduce_op=reduce_op, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce_scatter, (), dict(input=input,
                                       group_assignment=group_assignment,
                                       scatter_dimension=scatter_dimension,
                                       reduce_op=reduce_op, name=name)
      )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "reduce_op",
              _op.get_attr("reduce_op"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReduceScatter", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

XlaReduceScatter = tf_export("raw_ops.XlaReduceScatter")(_ops.to_raw_op(xla_reduce_scatter))
_dispatcher_for_xla_reduce_scatter = xla_reduce_scatter._tf_type_based_dispatcher.Dispatch


def xla_reduce_scatter_eager_fallback(input: Annotated[Any, TV_XlaReduceScatter_T], group_assignment: Annotated[Any, _atypes.Int32], scatter_dimension: Annotated[Any, _atypes.Int32], reduce_op: str, name, ctx) -> Annotated[Any, TV_XlaReduceScatter_T]:
  # NOTE(review): machine-generated op wrapper — do not hand-edit.
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.half, _dtypes.bfloat16,
_dtypes.float32, _dtypes.int32, _dtypes.uint32, ]) + group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32) + scatter_dimension = _ops.convert_to_tensor(scatter_dimension, _dtypes.int32) + _inputs_flat = [input, group_assignment, scatter_dimension] + _attrs = ("T", _attr_T, "reduce_op", reduce_op) + _result = _execute.execute(b"XlaReduceScatter", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaReduceScatter", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaReduceWindow_T = TypeVar("TV_XlaReduceWindow_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaReduceWindow_Tindices = TypeVar("TV_XlaReduceWindow_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_reduce_window') +def xla_reduce_window(input: Annotated[Any, TV_XlaReduceWindow_T], init_value: Annotated[Any, TV_XlaReduceWindow_T], window_dimensions: Annotated[Any, TV_XlaReduceWindow_Tindices], window_strides: Annotated[Any, TV_XlaReduceWindow_Tindices], base_dilations: Annotated[Any, TV_XlaReduceWindow_Tindices], window_dilations: Annotated[Any, TV_XlaReduceWindow_Tindices], padding: Annotated[Any, TV_XlaReduceWindow_Tindices], computation, name=None) -> Annotated[Any, TV_XlaReduceWindow_T]: + r"""Wraps the XLA ReduceWindow operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow . + + Args: + input: A `Tensor`. 
Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`. + the input tensor + init_value: A `Tensor`. Must have the same type as `input`. + a scalar representing the initial value for the reduction + window_dimensions: A `Tensor`. Must be one of the following types: `int32`, `int64`. + the shape of the window + window_strides: A `Tensor`. Must have the same type as `window_dimensions`. + the inter-window strides + base_dilations: A `Tensor`. Must have the same type as `window_dimensions`. + window_dilations: A `Tensor`. Must have the same type as `window_dimensions`. + padding: A `Tensor`. Must have the same type as `window_dimensions`. + the padding to apply at the start and end of each input dimensions + computation: A function decorated with @Defun. a reducer function to apply + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaReduceWindow", name, input, init_value, window_dimensions, + window_strides, base_dilations, window_dilations, padding, + "computation", computation) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_reduce_window( + (input, init_value, window_dimensions, window_strides, + base_dilations, window_dilations, padding, computation, name,), None) + if _result is not NotImplemented: + return _result + return xla_reduce_window_eager_fallback( + input, init_value, window_dimensions, window_strides, + base_dilations, window_dilations, padding, computation=computation, + name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_reduce_window, (), dict(input=input, init_value=init_value, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, + padding=padding, + computation=computation, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_reduce_window( + (input, init_value, window_dimensions, window_strides, base_dilations, + window_dilations, padding, computation, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaReduceWindow", input=input, init_value=init_value, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, + computation=computation, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_reduce_window, (), dict(input=input, init_value=init_value, + window_dimensions=window_dimensions, + window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, + padding=padding, + computation=computation, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices"), "computation", + _op.get_attr("computation")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaReduceWindow", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaReduceWindow = tf_export("raw_ops.XlaReduceWindow")(_ops.to_raw_op(xla_reduce_window)) +_dispatcher_for_xla_reduce_window = xla_reduce_window._tf_type_based_dispatcher.Dispatch + + +def xla_reduce_window_eager_fallback(input: Annotated[Any, TV_XlaReduceWindow_T], init_value: Annotated[Any, TV_XlaReduceWindow_T], window_dimensions: Annotated[Any, TV_XlaReduceWindow_Tindices], window_strides: Annotated[Any, TV_XlaReduceWindow_Tindices], base_dilations: Annotated[Any, TV_XlaReduceWindow_Tindices], window_dilations: Annotated[Any, TV_XlaReduceWindow_Tindices], padding: Annotated[Any, TV_XlaReduceWindow_Tindices], computation, name, ctx) -> Annotated[Any, TV_XlaReduceWindow_T]: + _attr_T, _inputs_T = _execute.args_to_matching_eager([input, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, 
_dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ]) + (input, init_value) = _inputs_T + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_dimensions, window_strides, base_dilations, window_dilations, padding], ctx, [_dtypes.int32, _dtypes.int64, ]) + (window_dimensions, window_strides, base_dilations, window_dilations, padding) = _inputs_Tindices + _inputs_flat = [input, init_value, window_dimensions, window_strides, base_dilations, window_dilations, padding] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices, "computation", + computation) + _result = _execute.execute(b"XlaReduceWindow", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaReduceWindow", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaRemoveDynamicDimensionSize_T = TypeVar("TV_XlaRemoveDynamicDimensionSize_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_remove_dynamic_dimension_size') +def xla_remove_dynamic_dimension_size(input: Annotated[Any, TV_XlaRemoveDynamicDimensionSize_T], dim_index: Annotated[Any, _atypes.Int32], name=None) -> Annotated[Any, TV_XlaRemoveDynamicDimensionSize_T]: + r"""Inverse of XlaSetDynamicDimensionSize. + + Make an xla bounded dynamic dimension into a static dimension. 
The bound of the + size of dimension `dim_index` becomes the static dimension size. + + Args: + input: A `Tensor`. + dim_index: A `Tensor` of type `int32`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaRemoveDynamicDimensionSize", name, input, dim_index) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_remove_dynamic_dimension_size( + (input, dim_index, name,), None) + if _result is not NotImplemented: + return _result + return xla_remove_dynamic_dimension_size_eager_fallback( + input, dim_index, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_remove_dynamic_dimension_size, (), dict(input=input, + dim_index=dim_index, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_remove_dynamic_dimension_size( + (input, dim_index, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaRemoveDynamicDimensionSize", input=input, dim_index=dim_index, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_remove_dynamic_dimension_size, (), dict(input=input, + dim_index=dim_index, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaRemoveDynamicDimensionSize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaRemoveDynamicDimensionSize = tf_export("raw_ops.XlaRemoveDynamicDimensionSize")(_ops.to_raw_op(xla_remove_dynamic_dimension_size)) +_dispatcher_for_xla_remove_dynamic_dimension_size = xla_remove_dynamic_dimension_size._tf_type_based_dispatcher.Dispatch + + +def xla_remove_dynamic_dimension_size_eager_fallback(input: Annotated[Any, TV_XlaRemoveDynamicDimensionSize_T], dim_index: Annotated[Any, _atypes.Int32], name, ctx) -> Annotated[Any, TV_XlaRemoveDynamicDimensionSize_T]: + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + dim_index = _ops.convert_to_tensor(dim_index, _dtypes.int32) + _inputs_flat = [input, dim_index] + _attrs = ("T", _attr_T) + _result = _execute.execute(b"XlaRemoveDynamicDimensionSize", 1, + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaRemoveDynamicDimensionSize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_replica_id') +def xla_replica_id(name=None) -> Annotated[Any, _atypes.Int32]: + r"""Replica ID. + + Args: + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `int32`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaReplicaId", name) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_replica_id( + (name,), None) + if _result is not NotImplemented: + return _result + return xla_replica_id_eager_fallback( + name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_replica_id, (), dict(name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_replica_id( + (name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaReplicaId", name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_replica_id, (), dict(name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = () + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaReplicaId", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaReplicaId = tf_export("raw_ops.XlaReplicaId")(_ops.to_raw_op(xla_replica_id)) +_dispatcher_for_xla_replica_id = xla_replica_id._tf_type_based_dispatcher.Dispatch + + +def xla_replica_id_eager_fallback(name, ctx) -> Annotated[Any, _atypes.Int32]: + _inputs_flat = [] + _attrs = None + _result = _execute.execute(b"XlaReplicaId", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaReplicaId", _inputs_flat, _attrs, _result) + _result, = _result + return _result 
+ +_XlaRngBitGeneratorOutput = collections.namedtuple( + "XlaRngBitGenerator", + ["output_key", "output"]) + + +TV_XlaRngBitGenerator_dtype = TypeVar("TV_XlaRngBitGenerator_dtype", _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaRngBitGenerator_Tshape = TypeVar("TV_XlaRngBitGenerator_Tshape", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_rng_bit_generator') +def xla_rng_bit_generator(algorithm: Annotated[Any, _atypes.Int32], initial_state: Annotated[Any, _atypes.UInt64], shape: Annotated[Any, TV_XlaRngBitGenerator_Tshape], dtype:TV_XlaRngBitGenerator_dtype=_dtypes.uint64, name=None): + r"""Stateless PRNG bit generator. + + Wraps the XLA RngBitGenerator operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator. + + Args: + algorithm: A `Tensor` of type `int32`. The PRNG algorithm to use, one of + tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}. + initial_state: A `Tensor` of type `uint64`. + Initial state for the PRNG algorithm. For THREEFRY, it should be + a u64[2] and for PHILOX a u64[3]. + shape: A `Tensor`. Must be one of the following types: `int32`, `int64`. + The output shape of the generated data. + dtype: An optional `tf.DType` from: `tf.uint8, tf.int8, tf.int32, tf.int64, tf.uint32, tf.uint64`. Defaults to `tf.uint64`. + The type of the tensor. + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (output_key, output). + + output_key: A `Tensor` of type `uint64`. + output: A `Tensor` of type `dtype`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaRngBitGenerator", name, algorithm, initial_state, shape, + "dtype", dtype) + _result = _XlaRngBitGeneratorOutput._make(_result) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_rng_bit_generator( + (algorithm, initial_state, shape, dtype, name,), None) + if _result is not NotImplemented: + return _result + return xla_rng_bit_generator_eager_fallback( + algorithm, initial_state, shape, dtype=dtype, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_rng_bit_generator, (), dict(algorithm=algorithm, + initial_state=initial_state, + shape=shape, dtype=dtype, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_rng_bit_generator( + (algorithm, initial_state, shape, dtype, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ if dtype is None: + dtype = _dtypes.uint64 + dtype = _execute.make_type(dtype, "dtype") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaRngBitGenerator", algorithm=algorithm, + initial_state=initial_state, shape=shape, + dtype=dtype, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_rng_bit_generator, (), dict(algorithm=algorithm, + initial_state=initial_state, + shape=shape, dtype=dtype, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("dtype", _op._get_attr_type("dtype"), "Tshape", + _op._get_attr_type("Tshape")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaRngBitGenerator", _inputs_flat, _attrs, _result) + _result = _XlaRngBitGeneratorOutput._make(_result) + return _result + +XlaRngBitGenerator = tf_export("raw_ops.XlaRngBitGenerator")(_ops.to_raw_op(xla_rng_bit_generator)) +_dispatcher_for_xla_rng_bit_generator = xla_rng_bit_generator._tf_type_based_dispatcher.Dispatch + + +def xla_rng_bit_generator_eager_fallback(algorithm: Annotated[Any, _atypes.Int32], initial_state: Annotated[Any, _atypes.UInt64], shape: Annotated[Any, TV_XlaRngBitGenerator_Tshape], dtype: TV_XlaRngBitGenerator_dtype, name, ctx): + if dtype is None: + dtype = _dtypes.uint64 + dtype = _execute.make_type(dtype, "dtype") + _attr_Tshape, (shape,) = _execute.args_to_matching_eager([shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32) + algorithm = _ops.convert_to_tensor(algorithm, _dtypes.int32) + initial_state = _ops.convert_to_tensor(initial_state, _dtypes.uint64) + _inputs_flat = [algorithm, initial_state, shape] + _attrs = ("dtype", dtype, "Tshape", _attr_Tshape) + _result = _execute.execute(b"XlaRngBitGenerator", 2, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaRngBitGenerator", _inputs_flat, _attrs, 
_result) + _result = _XlaRngBitGeneratorOutput._make(_result) + return _result + + +TV_XlaScatter_T = TypeVar("TV_XlaScatter_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaScatter_Tindices = TypeVar("TV_XlaScatter_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_scatter') +def xla_scatter(operand: Annotated[Any, TV_XlaScatter_T], scatter_indices: Annotated[Any, TV_XlaScatter_Tindices], updates: Annotated[Any, TV_XlaScatter_T], update_computation, dimension_numbers: str, indices_are_sorted: bool, name=None) -> Annotated[Any, TV_XlaScatter_T]: + r"""Wraps the XLA Scatter operator documented at + + https://www.tensorflow.org/xla/operation_semantics#scatter. + + Args: + operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`. + Array to be scattered into. + scatter_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. + Array containing the starting indices of the slices that must + be scattered to. + updates: A `Tensor`. Must have the same type as `operand`. + Array containing the values that must be used for scattering. + update_computation: A function decorated with @Defun. + Computation to be used for combining the existing values in + the input array and the updates during scatter. + dimension_numbers: A `string`. + A serialized xla::ScatterDimensionNumbers proto. + indices_are_sorted: A `bool`. + Boolean indicating if the indices are sorted. 
+ name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `operand`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaScatter", name, operand, scatter_indices, updates, + "update_computation", update_computation, "dimension_numbers", + dimension_numbers, "indices_are_sorted", indices_are_sorted) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_scatter( + (operand, scatter_indices, updates, update_computation, + dimension_numbers, indices_are_sorted, name,), None) + if _result is not NotImplemented: + return _result + return xla_scatter_eager_fallback( + operand, scatter_indices, updates, + update_computation=update_computation, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_scatter, (), dict(operand=operand, + scatter_indices=scatter_indices, + updates=updates, + update_computation=update_computation, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_scatter( + (operand, scatter_indices, updates, update_computation, + dimension_numbers, indices_are_sorted, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaScatter", operand=operand, scatter_indices=scatter_indices, + updates=updates, update_computation=update_computation, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_scatter, (), dict(operand=operand, + scatter_indices=scatter_indices, + updates=updates, + update_computation=update_computation, + dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("update_computation", _op.get_attr("update_computation"), + "dimension_numbers", _op.get_attr("dimension_numbers"), + "indices_are_sorted", _op._get_attr_bool("indices_are_sorted"), + "T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaScatter", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaScatter = tf_export("raw_ops.XlaScatter")(_ops.to_raw_op(xla_scatter)) +_dispatcher_for_xla_scatter = xla_scatter._tf_type_based_dispatcher.Dispatch + + +def xla_scatter_eager_fallback(operand: Annotated[Any, TV_XlaScatter_T], scatter_indices: Annotated[Any, TV_XlaScatter_Tindices], updates: Annotated[Any, TV_XlaScatter_T], update_computation, dimension_numbers: str, indices_are_sorted: bool, name, ctx) -> Annotated[Any, TV_XlaScatter_T]: + dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers") + indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted") + _attr_T, _inputs_T = _execute.args_to_matching_eager([operand, updates], ctx, 
[_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ]) + (operand, updates) = _inputs_T + _attr_Tindices, (scatter_indices,) = _execute.args_to_matching_eager([scatter_indices], ctx, [_dtypes.int32, _dtypes.int64, ]) + _inputs_flat = [operand, scatter_indices, updates] + _attrs = ("update_computation", update_computation, "dimension_numbers", + dimension_numbers, "indices_are_sorted", indices_are_sorted, "T", _attr_T, + "Tindices", _attr_Tindices) + _result = _execute.execute(b"XlaScatter", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaScatter", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSelectAndScatter_T = TypeVar("TV_XlaSelectAndScatter_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) +TV_XlaSelectAndScatter_Tindices = TypeVar("TV_XlaSelectAndScatter_Tindices", _atypes.Int32, _atypes.Int64) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_select_and_scatter') +def xla_select_and_scatter(operand: Annotated[Any, TV_XlaSelectAndScatter_T], window_dimensions: Annotated[Any, TV_XlaSelectAndScatter_Tindices], window_strides: Annotated[Any, TV_XlaSelectAndScatter_Tindices], padding: Annotated[Any, TV_XlaSelectAndScatter_Tindices], source: Annotated[Any, TV_XlaSelectAndScatter_T], init_value: Annotated[Any, TV_XlaSelectAndScatter_T], select, scatter, name=None) -> 
Annotated[Any, TV_XlaSelectAndScatter_T]: + r"""Wraps the XLA SelectAndScatter operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter + . + + Args: + operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the input tensor + window_dimensions: A `Tensor`. Must be one of the following types: `int32`, `int64`. + the shape of the window + window_strides: A `Tensor`. Must have the same type as `window_dimensions`. + the inter-window strides + padding: A `Tensor`. Must have the same type as `window_dimensions`. + the padding to apply at the start and end of each input dimensions + source: A `Tensor`. Must have the same type as `operand`. + a tensor of values to scatter + init_value: A `Tensor`. Must have the same type as `operand`. + a scalar representing the initial value for the output tensor + select: A function decorated with @Defun. a selection function to apply + scatter: A function decorated with @Defun. a scatter function to apply + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `operand`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSelectAndScatter", name, operand, window_dimensions, + window_strides, padding, source, init_value, "select", select, + "scatter", scatter) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_select_and_scatter( + (operand, window_dimensions, window_strides, padding, source, + init_value, select, scatter, name,), None) + if _result is not NotImplemented: + return _result + return xla_select_and_scatter_eager_fallback( + operand, window_dimensions, window_strides, padding, source, + init_value, select=select, scatter=scatter, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_select_and_scatter, (), dict(operand=operand, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, source=source, + init_value=init_value, + select=select, scatter=scatter, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_select_and_scatter( + (operand, window_dimensions, window_strides, padding, source, + init_value, select, scatter, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSelectAndScatter", operand=operand, + window_dimensions=window_dimensions, + window_strides=window_strides, padding=padding, + source=source, init_value=init_value, + select=select, scatter=scatter, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_select_and_scatter, (), dict(operand=operand, + window_dimensions=window_dimensions, + window_strides=window_strides, + padding=padding, source=source, + init_value=init_value, + select=select, scatter=scatter, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "Tindices", + _op._get_attr_type("Tindices"), "select", + _op.get_attr("select"), "scatter", _op.get_attr("scatter")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSelectAndScatter", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSelectAndScatter = tf_export("raw_ops.XlaSelectAndScatter")(_ops.to_raw_op(xla_select_and_scatter)) +_dispatcher_for_xla_select_and_scatter = xla_select_and_scatter._tf_type_based_dispatcher.Dispatch + + +def xla_select_and_scatter_eager_fallback(operand: Annotated[Any, TV_XlaSelectAndScatter_T], window_dimensions: Annotated[Any, TV_XlaSelectAndScatter_Tindices], window_strides: Annotated[Any, TV_XlaSelectAndScatter_Tindices], padding: Annotated[Any, TV_XlaSelectAndScatter_Tindices], source: Annotated[Any, TV_XlaSelectAndScatter_T], init_value: Annotated[Any, TV_XlaSelectAndScatter_T], select, scatter, name, ctx) -> Annotated[Any, TV_XlaSelectAndScatter_T]: + _attr_T, _inputs_T = _execute.args_to_matching_eager([operand, source, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, 
_dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + (operand, source, init_value) = _inputs_T + _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_dimensions, window_strides, padding], ctx, [_dtypes.int32, _dtypes.int64, ]) + (window_dimensions, window_strides, padding) = _inputs_Tindices + _inputs_flat = [operand, window_dimensions, window_strides, padding, source, init_value] + _attrs = ("T", _attr_T, "Tindices", _attr_Tindices, "select", select, + "scatter", scatter) + _result = _execute.execute(b"XlaSelectAndScatter", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSelectAndScatter", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +_XlaSelfAdjointEigOutput = collections.namedtuple( + "XlaSelfAdjointEig", + ["w", "v"]) + + +TV_XlaSelfAdjointEig_T = TypeVar("TV_XlaSelfAdjointEig_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_self_adjoint_eig') +def xla_self_adjoint_eig(a: Annotated[Any, TV_XlaSelfAdjointEig_T], lower: bool, max_iter: int, epsilon: float, name=None): + r"""Computes the eigen decomposition of a batch of self-adjoint matrices + + (Note: Only real inputs are supported). + + Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in + tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for + i=0...N-1. + + Args: + a: A `Tensor`. 
Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the input tensor. + lower: A `bool`. + a boolean specifies whether the calculation is done with the lower + triangular part or the upper triangular part. + max_iter: An `int`. + maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximately logN sweeps are needed in practice (Ref: Golub & + van Loan "Matrix Computation"). + epsilon: A `float`. the tolerance ratio. + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (w, v). + + w: A `Tensor`. Has the same type as `a`. The eigenvalues in ascending order, each repeated according to its + multiplicity. + v: A `Tensor`. Has the same type as `a`. The column v[..., :, i] is the normalized eigenvector corresponding to the + eigenvalue w[..., i]. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSelfAdjointEig", name, a, "lower", lower, "max_iter", + max_iter, "epsilon", epsilon) + _result = _XlaSelfAdjointEigOutput._make(_result) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_self_adjoint_eig( + (a, lower, max_iter, epsilon, name,), None) + if _result is not NotImplemented: + return _result + return xla_self_adjoint_eig_eager_fallback( + a, lower=lower, max_iter=max_iter, epsilon=epsilon, name=name, + ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_self_adjoint_eig, (), dict(a=a, lower=lower, + max_iter=max_iter, epsilon=epsilon, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_self_adjoint_eig( + (a, lower, max_iter, epsilon, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + lower = _execute.make_bool(lower, "lower") + max_iter = _execute.make_int(max_iter, "max_iter") + epsilon = _execute.make_float(epsilon, "epsilon") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSelfAdjointEig", a=a, lower=lower, max_iter=max_iter, + epsilon=epsilon, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_self_adjoint_eig, (), dict(a=a, lower=lower, max_iter=max_iter, + epsilon=epsilon, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("lower", _op._get_attr_bool("lower"), "max_iter", + _op._get_attr_int("max_iter"), "epsilon", + _op.get_attr("epsilon"), "T", _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSelfAdjointEig", _inputs_flat, _attrs, _result) + _result = _XlaSelfAdjointEigOutput._make(_result) + return _result + +XlaSelfAdjointEig = tf_export("raw_ops.XlaSelfAdjointEig")(_ops.to_raw_op(xla_self_adjoint_eig)) +_dispatcher_for_xla_self_adjoint_eig = xla_self_adjoint_eig._tf_type_based_dispatcher.Dispatch + + +def xla_self_adjoint_eig_eager_fallback(a: Annotated[Any, TV_XlaSelfAdjointEig_T], lower: bool, max_iter: int, epsilon: float, name, ctx): + lower = _execute.make_bool(lower, "lower") + max_iter = _execute.make_int(max_iter, "max_iter") + epsilon = _execute.make_float(epsilon, "epsilon") + _attr_T, (a,) = _execute.args_to_matching_eager([a], ctx, [_dtypes.float32, _dtypes.float64, 
_dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _inputs_flat = [a] + _attrs = ("lower", lower, "max_iter", max_iter, "epsilon", epsilon, "T", + _attr_T) + _result = _execute.execute(b"XlaSelfAdjointEig", 2, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSelfAdjointEig", _inputs_flat, _attrs, _result) + _result = _XlaSelfAdjointEigOutput._make(_result) + return _result + + +TV_XlaSend_T = TypeVar("TV_XlaSend_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_send') +def xla_send(tensor: Annotated[Any, TV_XlaSend_T], tensor_name: str, name=None): + r"""Sends the named tensor to another XLA computation. Wraps the XLA Send operator + + documented at + https://www.tensorflow.org/performance/xla/operation_semantics#send . + + Args: + tensor: A `Tensor`. The tensor to send. + tensor_name: A `string`. A string key that identifies the channel. + name: A name for the operation (optional). + + Returns: + The created Operation. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSend", name, tensor, "tensor_name", tensor_name) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_send( + (tensor, tensor_name, name,), None) + if _result is not NotImplemented: + return _result + return xla_send_eager_fallback( + tensor, tensor_name=tensor_name, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_send, (), dict(tensor=tensor, tensor_name=tensor_name, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_send( + (tensor, tensor_name, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ tensor_name = _execute.make_str(tensor_name, "tensor_name") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSend", tensor=tensor, tensor_name=tensor_name, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_send, (), dict(tensor=tensor, tensor_name=tensor_name, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + return _op +XlaSend = tf_export("raw_ops.XlaSend")(_ops.to_raw_op(xla_send)) +_dispatcher_for_xla_send = xla_send._tf_type_based_dispatcher.Dispatch + + +def xla_send_eager_fallback(tensor: Annotated[Any, TV_XlaSend_T], tensor_name: str, name, ctx): + tensor_name = _execute.make_str(tensor_name, "tensor_name") + _attr_T, (tensor,) = _execute.args_to_matching_eager([tensor], ctx, []) + _inputs_flat = [tensor] + _attrs = ("T", _attr_T, "tensor_name", tensor_name) + _result = _execute.execute(b"XlaSend", 0, inputs=_inputs_flat, attrs=_attrs, + ctx=ctx, name=name) + _result = None + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_set_bound') +def xla_set_bound(input: Annotated[Any, _atypes.Int32], bound: Annotated[Any, _atypes.Int32], name=None) -> Annotated[Any, _atypes.Int32]: + r"""Set a bound for the given input value as a hint to Xla compiler, + + returns the same value. + + Args: + input: A `Tensor` of type `int32`. + bound: A `Tensor` of type `int32`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of type `int32`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSetBound", name, input, bound) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_set_bound( + (input, bound, name,), None) + if _result is not NotImplemented: + return _result + return xla_set_bound_eager_fallback( + input, bound, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_set_bound, (), dict(input=input, bound=bound, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_set_bound( + (input, bound, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSetBound", input=input, bound=bound, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_set_bound, (), dict(input=input, bound=bound, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = () + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSetBound", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSetBound = tf_export("raw_ops.XlaSetBound")(_ops.to_raw_op(xla_set_bound)) +_dispatcher_for_xla_set_bound = xla_set_bound._tf_type_based_dispatcher.Dispatch + + +def xla_set_bound_eager_fallback(input: Annotated[Any, _atypes.Int32], bound: Annotated[Any, _atypes.Int32], name, ctx) -> Annotated[Any, _atypes.Int32]: + input = _ops.convert_to_tensor(input, _dtypes.int32) + bound = _ops.convert_to_tensor(bound, _dtypes.int32) + _inputs_flat = [input, bound] + _attrs = None + _result = _execute.execute(b"XlaSetBound", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSetBound", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSetDynamicDimensionSize_T = TypeVar("TV_XlaSetDynamicDimensionSize_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_set_dynamic_dimension_size') +def 
xla_set_dynamic_dimension_size(input: Annotated[Any, TV_XlaSetDynamicDimensionSize_T], dim_index: Annotated[Any, _atypes.Int32], size: Annotated[Any, _atypes.Int32], name=None) -> Annotated[Any, TV_XlaSetDynamicDimensionSize_T]: + r"""Make a static dimension into a xla bounded dynamic dimension. + + The current static dimension size will become the bound and the second + operand becomes the dynamic size of the dimension. + + Args: + input: A `Tensor`. + dim_index: A `Tensor` of type `int32`. + size: A `Tensor` of type `int32`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSetDynamicDimensionSize", name, input, dim_index, size) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_set_dynamic_dimension_size( + (input, dim_index, size, name,), None) + if _result is not NotImplemented: + return _result + return xla_set_dynamic_dimension_size_eager_fallback( + input, dim_index, size, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_set_dynamic_dimension_size, (), dict(input=input, + dim_index=dim_index, + size=size, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_set_dynamic_dimension_size( + (input, dim_index, size, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSetDynamicDimensionSize", input=input, dim_index=dim_index, + size=size, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_set_dynamic_dimension_size, (), dict(input=input, + dim_index=dim_index, + size=size, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSetDynamicDimensionSize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSetDynamicDimensionSize = tf_export("raw_ops.XlaSetDynamicDimensionSize")(_ops.to_raw_op(xla_set_dynamic_dimension_size)) +_dispatcher_for_xla_set_dynamic_dimension_size = xla_set_dynamic_dimension_size._tf_type_based_dispatcher.Dispatch + + +def xla_set_dynamic_dimension_size_eager_fallback(input: Annotated[Any, TV_XlaSetDynamicDimensionSize_T], dim_index: Annotated[Any, _atypes.Int32], size: Annotated[Any, _atypes.Int32], name, ctx) -> Annotated[Any, TV_XlaSetDynamicDimensionSize_T]: + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + dim_index = _ops.convert_to_tensor(dim_index, _dtypes.int32) + size = _ops.convert_to_tensor(size, _dtypes.int32) + _inputs_flat = [input, dim_index, size] + _attrs = ("T", _attr_T) + _result = _execute.execute(b"XlaSetDynamicDimensionSize", 1, + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSetDynamicDimensionSize", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSharding_T = TypeVar("TV_XlaSharding_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, 
_atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_sharding') +def xla_sharding(input: Annotated[Any, TV_XlaSharding_T], sharding:str="", unspecified_dims=[], name=None) -> Annotated[Any, TV_XlaSharding_T]: + r"""An op which shards the input based on the given sharding attribute. It can + + selectively annotate a subset of tensor dimensions by skipping unspecified_dims, + and the sharding annotation should be replicated in those dims. + + Args: + input: A `Tensor`. + sharding: An optional `string`. Defaults to `""`. + unspecified_dims: An optional list of `ints`. Defaults to `[]`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSharding", name, input, "sharding", sharding, + "unspecified_dims", unspecified_dims) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_sharding( + (input, sharding, unspecified_dims, name,), None) + if _result is not NotImplemented: + return _result + return xla_sharding_eager_fallback( + input, sharding=sharding, unspecified_dims=unspecified_dims, + name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. 
+ except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_sharding, (), dict(input=input, sharding=sharding, + unspecified_dims=unspecified_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_sharding( + (input, sharding, unspecified_dims, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + if sharding is None: + sharding = "" + sharding = _execute.make_str(sharding, "sharding") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_sharding' Op, not %r." % unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSharding", input=input, sharding=sharding, + unspecified_dims=unspecified_dims, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_sharding, (), dict(input=input, sharding=sharding, + unspecified_dims=unspecified_dims, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "sharding", + _op.get_attr("sharding"), "unspecified_dims", + _op.get_attr("unspecified_dims")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSharding", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSharding = tf_export("raw_ops.XlaSharding")(_ops.to_raw_op(xla_sharding)) +_dispatcher_for_xla_sharding = xla_sharding._tf_type_based_dispatcher.Dispatch + + +def xla_sharding_eager_fallback(input: Annotated[Any, TV_XlaSharding_T], sharding: str, unspecified_dims, name, ctx) -> Annotated[Any, TV_XlaSharding_T]: + if sharding is None: + sharding = "" + 
sharding = _execute.make_str(sharding, "sharding") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_sharding' Op, not %r." % unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + _inputs_flat = [input] + _attrs = ("T", _attr_T, "sharding", sharding, "unspecified_dims", + unspecified_dims) + _result = _execute.execute(b"XlaSharding", 1, inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSharding", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSort_T = TypeVar("TV_XlaSort_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_sort') +def xla_sort(input: Annotated[Any, TV_XlaSort_T], name=None) -> Annotated[Any, TV_XlaSort_T]: + r"""Wraps the XLA Sort operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#sort + . + + Sorts a tensor. Currently only sorts in ascending order are supported. + + Args: + input: A `Tensor`. A `Tensor` of type T. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. A `Tensor` of type T. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSort", name, input) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_sort( + (input, name,), None) + if _result is not NotImplemented: + return _result + return xla_sort_eager_fallback( + input, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_sort, (), dict(input=input, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_sort( + (input, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSort", input=input, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_sort, (), dict(input=input, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSort", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSort = tf_export("raw_ops.XlaSort")(_ops.to_raw_op(xla_sort)) +_dispatcher_for_xla_sort = xla_sort._tf_type_based_dispatcher.Dispatch + + +def xla_sort_eager_fallback(input: Annotated[Any, TV_XlaSort_T], name, ctx) -> Annotated[Any, TV_XlaSort_T]: + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + _inputs_flat = [input] + _attrs = ("T", _attr_T) + _result = _execute.execute(b"XlaSort", 1, inputs=_inputs_flat, attrs=_attrs, + ctx=ctx, name=name) + if 
_execute.must_record_gradient(): + _execute.record_gradient( + "XlaSort", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSpmdFullToShardShape_T = TypeVar("TV_XlaSpmdFullToShardShape_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_spmd_full_to_shard_shape') +def xla_spmd_full_to_shard_shape(input: Annotated[Any, TV_XlaSpmdFullToShardShape_T], manual_sharding: str, dim:int=-1, unspecified_dims=[], name=None) -> Annotated[Any, TV_XlaSpmdFullToShardShape_T]: + r"""An op used by XLA SPMD partitioner to switch from automatic partitioning to + + manual partitioning. It annotates the input (full-shape, to be automatically + partitioned) with the same sharding used by manual partitioning, and outputs a + shard-shaped tensor to be consumed by later manually-partitioned ops. If the + shape is not evenly partitionable, the padding region will be masked with 0s. + The conversion can happen partially in subgroups, by specifying the dim + attribute, where only that dim will be converted. + + Args: + input: A `Tensor`. + manual_sharding: A `string`. + dim: An optional `int`. Defaults to `-1`. + unspecified_dims: An optional list of `ints`. Defaults to `[]`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSpmdFullToShardShape", name, input, "manual_sharding", + manual_sharding, "dim", dim, "unspecified_dims", unspecified_dims) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_spmd_full_to_shard_shape( + (input, manual_sharding, dim, unspecified_dims, name,), None) + if _result is not NotImplemented: + return _result + return xla_spmd_full_to_shard_shape_eager_fallback( + input, manual_sharding=manual_sharding, dim=dim, + unspecified_dims=unspecified_dims, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_spmd_full_to_shard_shape, (), dict(input=input, + manual_sharding=manual_sharding, + dim=dim, + unspecified_dims=unspecified_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_spmd_full_to_shard_shape( + (input, manual_sharding, dim, unspecified_dims, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + manual_sharding = _execute.make_str(manual_sharding, "manual_sharding") + if dim is None: + dim = -1 + dim = _execute.make_int(dim, "dim") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_spmd_full_to_shard_shape' Op, not %r." 
% unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSpmdFullToShardShape", input=input, + manual_sharding=manual_sharding, dim=dim, + unspecified_dims=unspecified_dims, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_spmd_full_to_shard_shape, (), dict(input=input, + manual_sharding=manual_sharding, + dim=dim, + unspecified_dims=unspecified_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "manual_sharding", + _op.get_attr("manual_sharding"), "dim", + _op._get_attr_int("dim"), "unspecified_dims", + _op.get_attr("unspecified_dims")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSpmdFullToShardShape", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSpmdFullToShardShape = tf_export("raw_ops.XlaSpmdFullToShardShape")(_ops.to_raw_op(xla_spmd_full_to_shard_shape)) +_dispatcher_for_xla_spmd_full_to_shard_shape = xla_spmd_full_to_shard_shape._tf_type_based_dispatcher.Dispatch + + +def xla_spmd_full_to_shard_shape_eager_fallback(input: Annotated[Any, TV_XlaSpmdFullToShardShape_T], manual_sharding: str, dim: int, unspecified_dims, name, ctx) -> Annotated[Any, TV_XlaSpmdFullToShardShape_T]: + manual_sharding = _execute.make_str(manual_sharding, "manual_sharding") + if dim is None: + dim = -1 + dim = _execute.make_int(dim, "dim") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_spmd_full_to_shard_shape' Op, not %r." 
% unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + _inputs_flat = [input] + _attrs = ("T", _attr_T, "manual_sharding", manual_sharding, "dim", dim, + "unspecified_dims", unspecified_dims) + _result = _execute.execute(b"XlaSpmdFullToShardShape", 1, + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSpmdFullToShardShape", _inputs_flat, _attrs, _result) + _result, = _result + return _result + + +TV_XlaSpmdShardToFullShape_T = TypeVar("TV_XlaSpmdShardToFullShape_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float16, _atypes.Float32, _atypes.Float64, _atypes.Float8e4m3fn, _atypes.Float8e5m2, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int4, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.Resource, _atypes.String, _atypes.UInt16, _atypes.UInt32, _atypes.UInt4, _atypes.UInt64, _atypes.UInt8, _atypes.Variant) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_spmd_shard_to_full_shape') +def xla_spmd_shard_to_full_shape(input: Annotated[Any, TV_XlaSpmdShardToFullShape_T], manual_sharding: str, full_shape, dim:int=-1, unspecified_dims=[], name=None) -> Annotated[Any, TV_XlaSpmdShardToFullShape_T]: + r"""An op used by XLA SPMD partitioner to switch from manual partitioning to + + automatic partitioning. It converts the shard-shaped, manually partitioned input + into full-shaped tensor to be partitioned automatically with the same sharding + used by manual partitioning. The conversion can happen partially in subgroups, + by specifying the dim attribute, where only that dim will be converted. + + Args: + input: A `Tensor`. + manual_sharding: A `string`. + full_shape: A `tf.TensorShape` or list of `ints`. 
+ dim: An optional `int`. Defaults to `-1`. + unspecified_dims: An optional list of `ints`. Defaults to `[]`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `input`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSpmdShardToFullShape", name, input, "manual_sharding", + manual_sharding, "full_shape", full_shape, "dim", dim, + "unspecified_dims", unspecified_dims) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_spmd_shard_to_full_shape( + (input, manual_sharding, full_shape, dim, unspecified_dims, name,), + None) + if _result is not NotImplemented: + return _result + return xla_spmd_shard_to_full_shape_eager_fallback( + input, manual_sharding=manual_sharding, full_shape=full_shape, + dim=dim, unspecified_dims=unspecified_dims, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_spmd_shard_to_full_shape, (), dict(input=input, + manual_sharding=manual_sharding, + full_shape=full_shape, + dim=dim, + unspecified_dims=unspecified_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_spmd_shard_to_full_shape( + (input, manual_sharding, full_shape, dim, unspecified_dims, name,), + None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ manual_sharding = _execute.make_str(manual_sharding, "manual_sharding") + full_shape = _execute.make_shape(full_shape, "full_shape") + if dim is None: + dim = -1 + dim = _execute.make_int(dim, "dim") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_spmd_shard_to_full_shape' Op, not %r." % unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSpmdShardToFullShape", input=input, + manual_sharding=manual_sharding, + full_shape=full_shape, dim=dim, + unspecified_dims=unspecified_dims, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_spmd_shard_to_full_shape, (), dict(input=input, + manual_sharding=manual_sharding, + full_shape=full_shape, + dim=dim, + unspecified_dims=unspecified_dims, + name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op._get_attr_type("T"), "manual_sharding", + _op.get_attr("manual_sharding"), "full_shape", + _op.get_attr("full_shape"), "dim", _op._get_attr_int("dim"), + "unspecified_dims", _op.get_attr("unspecified_dims")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSpmdShardToFullShape", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +XlaSpmdShardToFullShape = tf_export("raw_ops.XlaSpmdShardToFullShape")(_ops.to_raw_op(xla_spmd_shard_to_full_shape)) +_dispatcher_for_xla_spmd_shard_to_full_shape = xla_spmd_shard_to_full_shape._tf_type_based_dispatcher.Dispatch + + +def xla_spmd_shard_to_full_shape_eager_fallback(input: Annotated[Any, TV_XlaSpmdShardToFullShape_T], manual_sharding: str, full_shape, dim: int, unspecified_dims, name, ctx) -> Annotated[Any, 
TV_XlaSpmdShardToFullShape_T]: + manual_sharding = _execute.make_str(manual_sharding, "manual_sharding") + full_shape = _execute.make_shape(full_shape, "full_shape") + if dim is None: + dim = -1 + dim = _execute.make_int(dim, "dim") + if unspecified_dims is None: + unspecified_dims = [] + if not isinstance(unspecified_dims, (list, tuple)): + raise TypeError( + "Expected list for 'unspecified_dims' argument to " + "'xla_spmd_shard_to_full_shape' Op, not %r." % unspecified_dims) + unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims] + _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, []) + _inputs_flat = [input] + _attrs = ("T", _attr_T, "manual_sharding", manual_sharding, "full_shape", + full_shape, "dim", dim, "unspecified_dims", unspecified_dims) + _result = _execute.execute(b"XlaSpmdShardToFullShape", 1, + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSpmdShardToFullShape", _inputs_flat, _attrs, _result) + _result, = _result + return _result + +_XlaSvdOutput = collections.namedtuple( + "XlaSvd", + ["s", "u", "v"]) + + +TV_XlaSvd_T = TypeVar("TV_XlaSvd_T", _atypes.BFloat16, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_svd') +def xla_svd(a: Annotated[Any, TV_XlaSvd_T], max_iter: int, epsilon: float, precision_config: str, name=None): + r"""Computes the eigen decomposition of a batch of self-adjoint matrices + + (Note: Only real inputs are supported). 
+ + Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in + tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]). + + Args: + a: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`. + the input tensor. + max_iter: An `int`. + maximum number of sweep update, i.e., the whole lower triangular + part or upper triangular part based on parameter lower. Heuristically, it has + been argued that approximately log(min (M, N)) sweeps are needed in practice + (Ref: Golub & van Loan "Matrix Computation"). + epsilon: A `float`. the tolerance ratio. + precision_config: A `string`. a serialized xla::PrecisionConfig proto. + name: A name for the operation (optional). + + Returns: + A tuple of `Tensor` objects (s, u, v). + + s: A `Tensor`. Has the same type as `a`. Singular values. The values are sorted in reverse order of magnitude, so + s[..., 0] is the largest value, s[..., 1] is the second largest, etc. + u: A `Tensor`. Has the same type as `a`. Left singular vectors. + v: A `Tensor`. Has the same type as `a`. Right singular vectors. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaSvd", name, a, "max_iter", max_iter, "epsilon", epsilon, + "precision_config", precision_config) + _result = _XlaSvdOutput._make(_result) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_svd( + (a, max_iter, epsilon, precision_config, name,), None) + if _result is not NotImplemented: + return _result + return xla_svd_eager_fallback( + a, max_iter=max_iter, epsilon=epsilon, + precision_config=precision_config, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_svd, (), dict(a=a, max_iter=max_iter, epsilon=epsilon, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_svd( + (a, max_iter, epsilon, precision_config, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ max_iter = _execute.make_int(max_iter, "max_iter") + epsilon = _execute.make_float(epsilon, "epsilon") + precision_config = _execute.make_str(precision_config, "precision_config") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaSvd", a=a, max_iter=max_iter, epsilon=epsilon, + precision_config=precision_config, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_svd, (), dict(a=a, max_iter=max_iter, epsilon=epsilon, + precision_config=precision_config, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("max_iter", _op._get_attr_int("max_iter"), "epsilon", + _op.get_attr("epsilon"), "precision_config", + _op.get_attr("precision_config"), "T", _op._get_attr_type("T")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaSvd", _inputs_flat, _attrs, _result) + _result = _XlaSvdOutput._make(_result) + return _result + +XlaSvd = tf_export("raw_ops.XlaSvd")(_ops.to_raw_op(xla_svd)) +_dispatcher_for_xla_svd = xla_svd._tf_type_based_dispatcher.Dispatch + + +def xla_svd_eager_fallback(a: Annotated[Any, TV_XlaSvd_T], max_iter: int, epsilon: float, precision_config: str, name, ctx): + max_iter = _execute.make_int(max_iter, "max_iter") + epsilon = _execute.make_float(epsilon, "epsilon") + precision_config = _execute.make_str(precision_config, "precision_config") + _attr_T, (a,) = _execute.args_to_matching_eager([a], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]) + _inputs_flat = [a] + _attrs = ("max_iter", max_iter, "epsilon", epsilon, "precision_config", + precision_config, "T", _attr_T) + _result = _execute.execute(b"XlaSvd", 3, 
inputs=_inputs_flat, attrs=_attrs, + ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaSvd", _inputs_flat, _attrs, _result) + _result = _XlaSvdOutput._make(_result) + return _result + + +TV_XlaVariadicReduce_T = TypeVar("TV_XlaVariadicReduce_T", _atypes.BFloat16, _atypes.Bool, _atypes.Complex128, _atypes.Complex64, _atypes.Float32, _atypes.Float64, _atypes.Half, _atypes.Int16, _atypes.Int32, _atypes.Int64, _atypes.Int8, _atypes.QInt16, _atypes.QInt32, _atypes.QInt8, _atypes.QUInt16, _atypes.QUInt8, _atypes.UInt16, _atypes.UInt32, _atypes.UInt64, _atypes.UInt8) + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_variadic_reduce') +def xla_variadic_reduce(input: Annotated[List[Any], TV_XlaVariadicReduce_T], init_value: Annotated[List[Any], TV_XlaVariadicReduce_T], dimensions_to_reduce, reducer, name=None): + r"""Wraps the variadic XLA Reduce operator. + + Semantics are documented at + https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + + This version is limited to operands of the same dtype. + XlaVariadicReduceV2 is a version that supports heterogeneous operands. + + Args: + input: A list of at least 1 `Tensor` objects with the same type in: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`. + the input tensor(s) + init_value: A list with the same length as `input` of `Tensor` objects with the same type as `input`. + scalar initial value(s) for the reduction + dimensions_to_reduce: A list of `ints`. + dimension numbers over which to reduce + reducer: A function decorated with @Defun. a reducer function to apply + name: A name for the operation (optional). + + Returns: + A list with the same length as `input` of `Tensor` objects with the same type as `input`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaVariadicReduce", name, input, init_value, + "dimensions_to_reduce", dimensions_to_reduce, "reducer", reducer) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_variadic_reduce( + (input, init_value, dimensions_to_reduce, reducer, name,), None) + if _result is not NotImplemented: + return _result + return xla_variadic_reduce_eager_fallback( + input, init_value, dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_reduce, (), dict(input=input, init_value=init_value, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_variadic_reduce( + (input, init_value, dimensions_to_reduce, reducer, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. + if not isinstance(input, (list, tuple)): + raise TypeError( + "Expected list for 'input' argument to " + "'xla_variadic_reduce' Op, not %r." % input) + _attr_N = len(input) + if not isinstance(init_value, (list, tuple)): + raise TypeError( + "Expected list for 'init_value' argument to " + "'xla_variadic_reduce' Op, not %r." % init_value) + if len(init_value) != _attr_N: + raise ValueError( + "List argument 'init_value' to 'xla_variadic_reduce' Op with length %d " + "must match length %d of argument 'input'." 
% + (len(init_value), _attr_N)) + if not isinstance(dimensions_to_reduce, (list, tuple)): + raise TypeError( + "Expected list for 'dimensions_to_reduce' argument to " + "'xla_variadic_reduce' Op, not %r." % dimensions_to_reduce) + dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaVariadicReduce", input=input, init_value=init_value, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_reduce, (), dict(input=input, init_value=init_value, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), + "dimensions_to_reduce", _op.get_attr("dimensions_to_reduce"), + "reducer", _op.get_attr("reducer")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaVariadicReduce", _inputs_flat, _attrs, _result) + return _result + +XlaVariadicReduce = tf_export("raw_ops.XlaVariadicReduce")(_ops.to_raw_op(xla_variadic_reduce)) +_dispatcher_for_xla_variadic_reduce = xla_variadic_reduce._tf_type_based_dispatcher.Dispatch + + +def xla_variadic_reduce_eager_fallback(input: Annotated[List[Any], TV_XlaVariadicReduce_T], init_value: Annotated[List[Any], TV_XlaVariadicReduce_T], dimensions_to_reduce, reducer, name, ctx): + if not isinstance(input, (list, tuple)): + raise TypeError( + "Expected list for 'input' argument to " + "'xla_variadic_reduce' Op, not %r." % input) + _attr_N = len(input) + if not isinstance(init_value, (list, tuple)): + raise TypeError( + "Expected list for 'init_value' argument to " + "'xla_variadic_reduce' Op, not %r." 
% init_value) + if len(init_value) != _attr_N: + raise ValueError( + "List argument 'init_value' to 'xla_variadic_reduce' Op with length %d " + "must match length %d of argument 'input'." % + (len(init_value), _attr_N)) + if not isinstance(dimensions_to_reduce, (list, tuple)): + raise TypeError( + "Expected list for 'dimensions_to_reduce' argument to " + "'xla_variadic_reduce' Op, not %r." % dimensions_to_reduce) + dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce] + _attr_T, _inputs_T = _execute.args_to_matching_eager(list(input) + list(init_value), ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ]) + _inputs_T = [_inputs_T[:_attr_N]] + _inputs_T[_attr_N:] + _inputs_T = _inputs_T[:1] + [_inputs_T[1:]] + (input, init_value) = _inputs_T + _inputs_flat = list(input) + list(init_value) + _attrs = ("N", _attr_N, "T", _attr_T, "dimensions_to_reduce", + dimensions_to_reduce, "reducer", reducer) + _result = _execute.execute(b"XlaVariadicReduce", _attr_N, + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaVariadicReduce", _inputs_flat, _attrs, _result) + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_variadic_reduce_v2') +def xla_variadic_reduce_v2(inputs, init_values, dimensions_to_reduce, reducer, name=None): + r"""Wraps the variadic XLA Reduce operator. + + Semantics are documented at + https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. + + This is an expanded version of XlaVariadicReduce, with support for + operands of different dtypes, and improved shape inference. 
+ + Args: + inputs: A list of `Tensor` objects. the input tensor(s) + init_values: A list of `Tensor` objects. Must have the same type as `inputs`. + scalar initial value(s) for the reduction + dimensions_to_reduce: A list of `ints`. + dimension numbers over which to reduce + reducer: A function decorated with @Defun. a reducer function to apply + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects. Has the same type as `inputs`. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaVariadicReduceV2", name, inputs, init_values, + "dimensions_to_reduce", dimensions_to_reduce, "reducer", reducer) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_variadic_reduce_v2( + (inputs, init_values, dimensions_to_reduce, reducer, name,), None) + if _result is not NotImplemented: + return _result + return xla_variadic_reduce_v2_eager_fallback( + inputs, init_values, dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_reduce_v2, (), dict(inputs=inputs, + init_values=init_values, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_variadic_reduce_v2( + (inputs, init_values, dimensions_to_reduce, reducer, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ if not isinstance(dimensions_to_reduce, (list, tuple)): + raise TypeError( + "Expected list for 'dimensions_to_reduce' argument to " + "'xla_variadic_reduce_v2' Op, not %r." % dimensions_to_reduce) + dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce] + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaVariadicReduceV2", inputs=inputs, init_values=init_values, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_reduce_v2, (), dict(inputs=inputs, + init_values=init_values, + dimensions_to_reduce=dimensions_to_reduce, + reducer=reducer, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op.get_attr("T"), "dimensions_to_reduce", + _op.get_attr("dimensions_to_reduce"), "reducer", + _op.get_attr("reducer")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaVariadicReduceV2", _inputs_flat, _attrs, _result) + return _result + +XlaVariadicReduceV2 = tf_export("raw_ops.XlaVariadicReduceV2")(_ops.to_raw_op(xla_variadic_reduce_v2)) +_dispatcher_for_xla_variadic_reduce_v2 = xla_variadic_reduce_v2._tf_type_based_dispatcher.Dispatch + + +def xla_variadic_reduce_v2_eager_fallback(inputs, init_values, dimensions_to_reduce, reducer, name, ctx): + if not isinstance(dimensions_to_reduce, (list, tuple)): + raise TypeError( + "Expected list for 'dimensions_to_reduce' argument to " + "'xla_variadic_reduce_v2' Op, not %r." 
% dimensions_to_reduce) + dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce] + _attr_T, (inputs, init_values) = _execute.args_to_mixed_eager_tensors((inputs, init_values), ctx) + _inputs_flat = list(inputs) + list(init_values) + _attrs = ("T", _attr_T, "dimensions_to_reduce", dimensions_to_reduce, + "reducer", reducer) + _result = _execute.execute(b"XlaVariadicReduceV2", len(inputs), + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaVariadicReduceV2", _inputs_flat, _attrs, _result) + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_variadic_sort') +def xla_variadic_sort(inputs, dimension: Annotated[Any, _atypes.Int32], comparator, is_stable: bool, name=None): + r"""Wraps the XLA Sort operator, documented at + + https://www.tensorflow.org/performance/xla/operation_semantics#sort + . + + Sorts one or more tensors, with support for custom comparator, dimension, and + is_stable attributes. + + Args: + inputs: A list of `Tensor` objects. + A list of `Tensor` of identical shape but possibly different types. + dimension: A `Tensor` of type `int32`. + The dimension along which to sort. Must be a compile-time constant. + comparator: A function decorated with @Defun. + A comparator function to apply to 2*N scalars and returning a + boolean. N is the number of sort inputs. If you want to sort in ascending + order then the comparator should perform a less-than comparison. + is_stable: A `bool`. Whether to use stable sort. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects. Has the same type as `inputs`. + A list of `Tensor` of same shape and types as the `input`. 
+ """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaVariadicSort", name, inputs, dimension, "comparator", + comparator, "is_stable", is_stable) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_variadic_sort( + (inputs, dimension, comparator, is_stable, name,), None) + if _result is not NotImplemented: + return _result + return xla_variadic_sort_eager_fallback( + inputs, dimension, comparator=comparator, is_stable=is_stable, + name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_sort, (), dict(inputs=inputs, dimension=dimension, + comparator=comparator, + is_stable=is_stable, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_variadic_sort( + (inputs, dimension, comparator, is_stable, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ is_stable = _execute.make_bool(is_stable, "is_stable") + try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaVariadicSort", inputs=inputs, dimension=dimension, + comparator=comparator, is_stable=is_stable, + name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_variadic_sort, (), dict(inputs=inputs, dimension=dimension, + comparator=comparator, + is_stable=is_stable, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if _execute.must_record_gradient(): + _attrs = ("T", _op.get_attr("T"), "comparator", + _op.get_attr("comparator"), "is_stable", + _op._get_attr_bool("is_stable")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaVariadicSort", _inputs_flat, _attrs, _result) + return _result + +XlaVariadicSort = tf_export("raw_ops.XlaVariadicSort")(_ops.to_raw_op(xla_variadic_sort)) +_dispatcher_for_xla_variadic_sort = xla_variadic_sort._tf_type_based_dispatcher.Dispatch + + +def xla_variadic_sort_eager_fallback(inputs, dimension: Annotated[Any, _atypes.Int32], comparator, is_stable: bool, name, ctx): + is_stable = _execute.make_bool(is_stable, "is_stable") + _attr_T, inputs = _execute.convert_to_mixed_eager_tensors(inputs, ctx) + dimension = _ops.convert_to_tensor(dimension, _dtypes.int32) + _inputs_flat = list(inputs) + [dimension] + _attrs = ("T", _attr_T, "comparator", comparator, "is_stable", is_stable) + _result = _execute.execute(b"XlaVariadicSort", len(inputs), + inputs=_inputs_flat, attrs=_attrs, ctx=ctx, + name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaVariadicSort", _inputs_flat, _attrs, _result) + return _result + + +@_dispatch.add_fallback_dispatch_list +@_dispatch.add_type_based_api_dispatcher +@tf_export('xla_while') +def xla_while(input, cond, body, name=None): + r"""output = input; While (Cond(output)) { output = Body(output) } + + Args: + input: A list of `Tensor` objects. 
+ A list of input tensors whose types are T. + cond: A function decorated with @Defun. + A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. + body: A function decorated with @Defun. + A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. + name: A name for the operation (optional). + + Returns: + A list of `Tensor` objects. Has the same type as `input`. + A list of output tensors whose types are T. + """ + _ctx = _context._context or _context.context() + tld = _ctx._thread_local_data + if tld.is_eager: + try: + _result = pywrap_tfe.TFE_Py_FastPathExecute( + _ctx, "XlaWhile", name, input, "cond", cond, "body", body) + return _result + except _core._NotOkStatusException as e: + _ops.raise_from_not_ok_status(e, name) + except _core._FallbackException: + pass + try: + _result = _dispatcher_for_xla_while( + (input, cond, body, name,), None) + if _result is not NotImplemented: + return _result + return xla_while_eager_fallback( + input, cond=cond, body=body, name=name, ctx=_ctx) + except _core._SymbolicException: + pass # Add nodes to the TensorFlow graph. + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_while, (), dict(input=input, cond=cond, body=body, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + else: + _result = _dispatcher_for_xla_while( + (input, cond, body, name,), None) + if _result is not NotImplemented: + return _result + # Add nodes to the TensorFlow graph. 
+ try: + _, _, _op, _outputs = _op_def_library._apply_op_helper( + "XlaWhile", input=input, cond=cond, body=body, name=name) + except (TypeError, ValueError): + _result = _dispatch.dispatch( + xla_while, (), dict(input=input, cond=cond, body=body, name=name) + ) + if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: + return _result + raise + _result = _outputs[:] + if not _result: + return _op + if _execute.must_record_gradient(): + _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body", + _op.get_attr("body")) + _inputs_flat = _op.inputs + _execute.record_gradient( + "XlaWhile", _inputs_flat, _attrs, _result) + return _result + +XlaWhile = tf_export("raw_ops.XlaWhile")(_ops.to_raw_op(xla_while)) +_dispatcher_for_xla_while = xla_while._tf_type_based_dispatcher.Dispatch + + +def xla_while_eager_fallback(input, cond, body, name, ctx): + _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx) + _inputs_flat = list(input) + _attrs = ("T", _attr_T, "cond", cond, "body", body) + _result = _execute.execute(b"XlaWhile", len(input), inputs=_inputs_flat, + attrs=_attrs, ctx=ctx, name=name) + if _execute.must_record_gradient(): + _execute.record_gradient( + "XlaWhile", _inputs_flat, _attrs, _result) + return _result + diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2de7c65e7f9493be6e5078fb91f81bc49fc38c5c Binary files 
/dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/xla.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/xla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a4c2371e446b1960e301b1f1a7c83791cc65a45 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/__pycache__/xla.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/xla.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/xla.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8b19e4088c4b29bb1cb0d824bb5726725d9783 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/python/xla.py @@ -0,0 +1,726 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental library that exposes XLA operations directly in TensorFlow. 
+ +It is sometimes useful to be able to build HLO programs directly from +TensorFlow. This file provides Tensorflow operators that mirror the semantics of +HLO operators as closely as possible. + +Note: Most of the operators defined in this module are used by the jax2tf +converter (see go/jax2tf for details) and are used in SavedModel produced +by jax2tf. Hence, we need to maintain backwards compatibility for these +operators. Please reach out to the JAX team if you want to make changes. +""" + +from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import random_ops_util +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.numpy_ops import np_utils + +# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing +# ops include: +# infeed/outfeed (available via tf.contrib.tpu) +# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) +# conditional +# gather/scatter +# collapse + +# This file reuses builtin names (following XLA's names, so we can call things +# like xla.max), so we capture the builtin versions here. +# pylint: disable=redefined-builtin +_max = max +_min = min +_slice = slice # pylint: disable=invalid-name + +constant = constant_op.constant + +# Unary operators. + +# For most arithmetic operators there is a TensorFlow operator +# that exactly corresponds to each XLA operator. 
Rather than defining +# XLA-specific variants, we reuse the corresponding TensorFlow operator. +# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 +# wrap every HLO operator, because that would allow us to be confident that the +# semantics match. + + +def _unary_op(fn): + """Wrapper that restricts `fn` to have the correct signature.""" + + def unary_op_wrapper(x, name=None): + return fn(x, name=name) + + return unary_op_wrapper + + +abs = _unary_op(math_ops.abs) +# TODO(phawkins): implement clz. +conj = _unary_op(math_ops.conj) +cos = _unary_op(math_ops.cos) +ceil = _unary_op(math_ops.ceil) +digamma = _unary_op(math_ops.digamma) +erf = _unary_op(math_ops.erf) +erfc = _unary_op(math_ops.erfc) +erfinv = _unary_op(math_ops.erfinv) +ndtri = _unary_op(math_ops.ndtri) +exp = _unary_op(math_ops.exp) +expm1 = _unary_op(math_ops.expm1) +floor = _unary_op(math_ops.floor) +imag = _unary_op(math_ops.imag) +is_finite = _unary_op(math_ops.is_finite) +lgamma = _unary_op(math_ops.lgamma) +log = _unary_op(math_ops.log) +log1p = _unary_op(math_ops.log1p) +logical_not = _unary_op(math_ops.logical_not) +neg = _unary_op(math_ops.neg) +real = _unary_op(math_ops.real) +# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for +# numbers halfway between two integers. +round = _unary_op(math_ops.round) +sin = _unary_op(math_ops.sin) +sign = _unary_op(math_ops.sign) +tan = _unary_op(math_ops.tan) +tanh = _unary_op(math_ops.tanh) + +# Bessel +bessel_i0e = _unary_op(special_math_ops.bessel_i0e) +bessel_i1e = _unary_op(special_math_ops.bessel_i1e) + +# Binary operators + +# The main difference between TensorFlow and XLA binary ops is the broadcasting +# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA +# requires an explicit specification of which dimensions to broadcast if the +# arguments have different ranks. 
def _broadcasting_binary_op(fn):
    """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.

    XLA requires the broadcast dimensions to be listed explicitly
    (`broadcast_dims`), unlike TensorFlow's implicit Numpy-style broadcasting.
    """

    def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
        """Inner wrapper function."""
        broadcast_dims = broadcast_dims or []
        broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
        # Rather than relying on having static shape information in the
        # TensorFlow graph, we use an XlaBroadcastHelper op that can compute
        # the correct shapes at JIT compilation time.
        x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
        return fn(x, y, name=name)

    return broadcasting_binary_op_wrapper


# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
    """Performs an integer right logical shift irrespective of input type."""
    assert y.dtype == x.dtype
    dtype = x.dtype
    signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
    if signed:
        # Casting to the unsigned counterpart forces a zero-fill (logical)
        # shift; the result is cast back to the original dtype below.
        unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
        x = math_ops.cast(x, unsigned_dtype)
        y = math_ops.cast(y, unsigned_dtype)
    output = bitwise_ops.right_shift(x, y, name=name)
    if signed:
        output = math_ops.cast(output, dtype)
    return output


def _shift_right_arithmetic_helper(x, y, name=None):
    """Performs an integer right arithmetic shift irrespective of input type."""
    assert y.dtype == x.dtype
    dtype = x.dtype
    unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
    if unsigned:
        # Casting to the signed counterpart forces a sign-extending
        # (arithmetic) shift; the result is cast back below.
        signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
        x = math_ops.cast(x, signed_dtype)
        y = math_ops.cast(y, signed_dtype)
    output = bitwise_ops.right_shift(x, y, name=name)
    if unsigned:
        output = math_ops.cast(output, dtype)
    return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
    """Wrapper that restricts `fn` to have the correct signature."""

    def binary_op_wrapper(x, y, name=None):
        return fn(x, y, name=name)

    return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
    """Prepends `dims` to the shape of `x` and broadcasts to the result."""
    x = ops.convert_to_tensor(x)
    shape = array_ops.concat(
        [constant_op.constant(dims), array_ops.shape(x)], axis=0
    )
    return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
    """Clamps `x` elementwise to the range [`a`, `b`]."""
    return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def conv(
    lhs,
    rhs,
    window_strides,
    padding,
    lhs_dilation,
    rhs_dilation,
    dimension_numbers,
    feature_group_count=1,
    precision_config=None,
    preferred_element_type=None,
    name=None,
    use_v2=False,
    batch_group_count=1,
):
    """Wraps the XLA ConvGeneralDilated operator.

    ConvGeneralDilated is the most general form of XLA convolution and is
    documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

    Args:
      lhs: the input tensor
      rhs: the kernel tensor
      window_strides: the inter-window strides
      padding: the padding to apply at the start and end of each input
        dimensions
      lhs_dilation: dilation to apply between input elements
      rhs_dilation: dilation to apply between kernel elements
      dimension_numbers: a `ConvolutionDimensionNumbers` proto.
      feature_group_count: number of feature groups for grouped convolution.
      precision_config: a `xla.PrecisionConfig` proto.
      preferred_element_type: the result `dtype`.
      name: an optional name for the operator.
      use_v2: an optional request to use the XlaConvV2 op even if not necessary.
      batch_group_count: number of batch groups or grouped filters.

    Returns:
      A tensor representing the output of the convolution.
    """
    precision_config_proto = ""
    if precision_config:
        precision_config_proto = precision_config.SerializeToString()
    # V2 of the op is required whenever a result dtype is requested, the
    # operand dtypes differ, or batch groups are used; V1 supports none of
    # these.
    needs_v2 = (
        preferred_element_type
        or (lhs.dtype != rhs.dtype)
        or batch_group_count > 1
    )
    if preferred_element_type is None:
        preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    if needs_v2 or use_v2:
        return gen_xla_ops.xla_conv_v2(
            lhs,
            rhs,
            window_strides=window_strides,
            padding=padding,
            lhs_dilation=lhs_dilation,
            rhs_dilation=rhs_dilation,
            feature_group_count=feature_group_count,
            batch_group_count=batch_group_count,
            dimension_numbers=dimension_numbers.SerializeToString(),
            precision_config=precision_config_proto,
            preferred_element_type=preferred_element_type,
            name=name,
        )
    return gen_xla_ops.xla_conv(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        name=name,
    )


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
    """Contracts the last axis of `lhs` with the first axis of `rhs`."""
    return math_ops.tensordot(lhs, rhs, axes=1, name=name)


DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig


def dot_general(
    lhs,
    rhs,
    dimension_numbers,
    precision_config=None,
    preferred_element_type=None,
    name=None,
    use_v2=False,
):
    """Wraps the XLA DotGeneral operator.

    Args:
      lhs: the left-hand operand.
      rhs: the right-hand operand.
      dimension_numbers: a `DotDimensionNumbers` proto.
      precision_config: a `xla.PrecisionConfig` proto, or None.
      preferred_element_type: the result `dtype`, or None to infer it from the
        operand dtypes.
      name: an optional name for the operator.
      use_v2: an optional request to use the XlaDotV2 op even if not necessary.

    Returns:
      A tensor representing the output of the operator.
    """
    precision_config_proto = ""
    if precision_config:
        precision_config_proto = precision_config.SerializeToString()
    # V2 is required whenever a result dtype is requested or operand dtypes
    # differ.
    needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
    if preferred_element_type is None:
        preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    if needs_v2 or use_v2:
        return gen_xla_ops.xla_dot_v2(
            lhs,
            rhs,
            dimension_numbers=dimension_numbers.SerializeToString(),
            precision_config=precision_config_proto,
            preferred_element_type=preferred_element_type,
            name=name,
        )
    return gen_xla_ops.xla_dot(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        name=name,
    )


def self_adjoint_eig(a, lower, max_iter, epsilon):
    """Wraps the XlaSelfAdjointEig op."""
    return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
    """Wraps the XlaSvd op."""
    precision_config_proto = ""
    if precision_config:
        precision_config_proto = precision_config.SerializeToString()
    return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then
# remove the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
    """Samples a normal distribution; the output dtype follows `mu`."""
    mu = ops.convert_to_tensor(mu)
    return random_ops.random_normal(
        dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name
    )


def random_uniform(minval, maxval, dims, name=None):
    """Samples a uniform distribution; the output dtype follows `minval`."""
    minval = ops.convert_to_tensor(minval)
    return random_ops.random_uniform(
        dims, minval, maxval, dtype=minval.dtype, name=name
    )


def rng_bit_generator(algorithm, initial_state, shape, dtype):
    """Stateless PRNG bit generator.

    Wraps the XLA RngBitGenerator operator, documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

    Args:
      algorithm: The PRNG algorithm to use, one of tf.random.Algorithm.{PHILOX,
        THREEFRY, AUTO_SELECT}.
      initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
        should be a u64[2] and for PHILOX a u64[3].
      shape: The output shape of the generated data.
      dtype: The type of the tensor.

    Returns:
      a tuple with a new state and generated data of the given shape.
    """
    alg_int = random_ops_util.convert_alg_to_int(algorithm)
    return gen_xla_ops.xla_rng_bit_generator(
        alg_int, initial_state, shape, dtype=dtype
    )


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

ops.no_gradient("XlaVariadicReduce")


def reduce_window(
    operand,
    init,
    reducer,
    window_dimensions,
    window_strides=None,
    base_dilations=None,
    window_dilations=None,
    padding=None,
    name=None,
):
    """Wraps the XLA ReduceWindow operator.

    ReduceWindow is documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

    Args:
      operand: the input tensor
      init: a scalar tensor representing the initial value for the reduction
      reducer: a reduction function that combines a pair of scalars.
      window_dimensions: shape of the window, as a list of integers
      window_strides: inter-window strides, as a list of integers. Optional; if
        omitted, defaults to strides of 1.
      base_dilations: dilation to apply to the input, as a list of integers.
        Optional; if omitted, defaults to dilations of 1.
      window_dilations: dilation to apply between window elements, as a list of
        integers. Optional; if omitted, defaults to dilations of 1.
      padding: padding to apply to 'operand'. List of (low, high) pairs of
        integers that specify the padding to apply before and after each
        dimension. Optional; if omitted, defaults to no padding.
      name: the operator name, or None.

    Returns:
      A tensor that represents the output of the reduce_window operator.
    """
    # All omitted per-dimension parameters default to the identity for their
    # role (stride/dilation of 1, padding of 0).
    window_strides = window_strides or [1] * len(window_dimensions)
    base_dilations = base_dilations or [1] * len(window_dimensions)
    window_dilations = window_dilations or [1] * len(window_dimensions)
    padding = padding or [(0, 0)] * len(window_dimensions)
    return gen_xla_ops.xla_reduce_window(
        input=operand,
        init_value=init,
        window_dimensions=window_dimensions,
        window_strides=window_strides,
        base_dilations=base_dilations,
        window_dilations=window_dilations,
        padding=padding,
        computation=reducer,
        name=name,
    )


replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to Xla compiler,
# returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3)  # Tells xla the constraint that p <= 3.
#   return t[:p]             # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound

# Make a static dimension into a xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes
# the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#   # Tells xla the valid size of the array is 3.
#   dim = 0
#   p = xla_set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6)  # xla knows only the first 3 elements are
#                               # valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size

# Inverse of xla_set_dynamic_dimension_size. Make an xla bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
    """Reshapes `x` to `new_sizes`, optionally permuting axes first.

    If `dimensions` is given, `x` is transposed by that permutation before the
    reshape.
    """
    if dimensions is not None:
        x = array_ops.transpose(x, dimensions)
    x = array_ops.reshape(x, new_sizes, name=name)
    return x


def select(condition, x, y, name=None):
    """Elementwise select: `x` where `condition` is true, else `y`."""
    return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
    """Per-dimension strided slice of `x` from `start_dims` to `limit_dims`."""
    # NOTE(review): `_slice` is presumably the Python builtin `slice` captured
    # earlier in the file before this definition shadows it — confirm against
    # the file header (not visible in this chunk).
    spec = [
        _slice(start, limit, stride)
        for (start, limit, stride) in zip(start_dims, limit_dims, strides)
    ]
    return x[tuple(spec)]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
    """Gradient for XlaSharding op."""
    # The gradient keeps the same sharding as the forward op, so the sharding
    # attribute is copied onto the gradient op as well.
    sharding_attr = op.get_attr("sharding")
    grad_sharding = gen_xla_ops.xla_sharding(
        grad,
        sharding=sharding_attr,
        unspecified_dims=op.get_attr("unspecified_dims"),
    )
    # pylint: disable=protected-access
    grad_sharding.op._set_attr(
        "_XlaSharding", attr_value_pb2.AttrValue(s=sharding_attr)
    )
    return [grad_sharding]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
    """Gradient: the inverse shard-to-full reshape with matching attributes."""
    s2f = gen_xla_ops.xla_spmd_shard_to_full_shape(
        grad,
        manual_sharding=op.get_attr("manual_sharding"),
        full_shape=op.inputs[0].shape.as_list(),
        dim=op.get_attr("dim"),
        unspecified_dims=op.get_attr("unspecified_dims"),
    )
    return [s2f]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
    """Gradient: the inverse full-to-shard reshape with matching attributes."""
    f2s = gen_xla_ops.xla_spmd_full_to_shard_shape(
        grad,
        manual_sharding=op.get_attr("manual_sharding"),
        dim=op.get_attr("dim"),
        unspecified_dims=op.get_attr("unspecified_dims"),
    )
    return [f2s]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call


def custom_call_v2(
    call_target_name,
    operands,
    result_specs,
    backend_config=None,
    has_side_effect=None,
    name=None,
):
    """Emits an HLO `CustomCall` operation with multiple outputs.

    See `CustomCall` specification at
    https://tensorflow.org/xla/operation_semantics#customcall,
    and `mhlo.custom_call` specification at
    https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

    Args:
      call_target_name: Name of the user function. The function signature must
        conform to version 3 of the API, see
        `API_VERSION_STATUS_RETURNING_UNIFIED`. All operands and results
        assumed to be in the default layout.
      operands: A sequence of tensors with possibly different types.
      result_specs: A sequence of tensor specs for all results.
      backend_config: A string that encodes a metadata for the backend. Empty
        string by default.
      has_side_effect: Indicates whether the custom call has side effects.
        `False` by default.
      name: Optional name of the operation.

    Returns:
      A tuple of output tensors.
    """
    return gen_xla_ops.xla_custom_call_v2(
        operands=operands,
        call_target_name=call_target_name,
        backend_config="" if backend_config is None else backend_config,
        has_side_effect=False if has_side_effect is None else has_side_effect,
        result_dtypes=tuple(spec.dtype for spec in result_specs),
        result_shapes=tuple(spec.shape for spec in result_specs),
        name=name,
    )


# pylint: disable=g-doc-args
# pylint: disable=g-doc-return-or-yield
def call_module(
    args,
    *,
    version=4,
    module,
    Tout,
    Sout,
    platforms=(),
    function_list=(),
    has_token_input_output=False,
    disabled_checks=(),
):
    """See documentation for the XlaCallModule op.

    https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_ops.cc+xlacallmodule&type=code
    """
    res = gen_xla_ops.xla_call_module(
        args,
        version=version,
        module=module,
        dim_args_spec=(),
        Tout=Tout,
        Sout=Sout,
        platforms=platforms,
        function_list=function_list,
        has_token_input_output=has_token_input_output,
        disabled_checks=disabled_checks,
    )
    # Since XLACallModule op is stateful, zero return function will return the
    # TF op under tf.function. It creates trouble for downstream codes.
    # Here we force it return empty tuple to work around it.
    # TODO(johnqiangzhang): Figure out a better way to handle control
    # dependency.
    if isinstance(res, ops.Operation):
        res = ()
    return res


def call_module_maximum_supported_version():
    """Maximum version of XlaCallModule op supported.

    See versioning details documentation for the XlaCallModule op at:
    https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code
    """
    return 9

# pylint: enable=g-doc-args
# pylint: enable=g-doc-return-or-yield


def call_module_disable_check_platform():
    # For use with xla_call_module.disabled_checks.
    return "platform"


def gather(
    operand,
    start_indices,
    dimension_numbers,
    slice_sizes,
    indices_are_sorted=False,
    name=None,
):
    """Wraps the XlaGather op; `dimension_numbers` is serialized to the op."""
    return gen_xla_ops.xla_gather(
        operand,
        start_indices,
        slice_sizes=slice_sizes,
        dimension_numbers=dimension_numbers.SerializeToString(),
        indices_are_sorted=indices_are_sorted,
        name=name,
    )


def scatter(
    operand,
    scatter_indices,
    updates,
    update_computation,
    dimension_numbers,
    indices_are_sorted=False,
    name=None,
):
    """Wraps the XlaScatter op; `dimension_numbers` is serialized to the op."""
    return gen_xla_ops.xla_scatter(
        operand,
        scatter_indices,
        updates,
        update_computation=update_computation,
        dimension_numbers=dimension_numbers.SerializeToString(),
        indices_are_sorted=indices_are_sorted,
        name=name,
    )


def optimization_barrier(*args):
    """Wraps the XlaOptimizationBarrier op over all positional arguments."""
    return gen_xla_ops.xla_optimization_barrier(args)


def reduce_precision(operand, exponent_bits, mantissa_bits):
    """Wraps the XlaReducePrecision op."""
    return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)
diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/tf2xla_pb2.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/tf2xla_pb2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5be7d74bd5adc6a11330569cf0a6e5f4c0b3a5af
--- /dev/null
+++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/compiler/tf2xla/tf2xla_pb2.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: tensorflow/compiler/tf2xla/tf2xla.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2 +from tensorflow.core.framework import types_pb2 as tensorflow_dot_core_dot_framework_dot_types__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'tensorflow/compiler/tf2xla/tf2xla.proto\x12\x11tensorflow.tf2xla\x1a,tensorflow/core/framework/tensor_shape.proto\x1a%tensorflow/core/framework/types.proto\"3\n\x08TensorId\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x14\n\x0coutput_index\x18\x02 \x01(\x03\"\x8e\x01\n\x04\x46\x65\x65\x64\x12\'\n\x02id\x18\x01 \x01(\x0b\x32\x1b.tensorflow.tf2xla.TensorId\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\"\x8f\x01\n\x05\x46\x65tch\x12\'\n\x02id\x18\x01 \x01(\x0b\x32\x1b.tensorflow.tf2xla.TensorId\x12\x0c\n\x04name\x18\x02 \x01(\t\x12+\n\x05shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\"\x8e\x01\n\x08Variable\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12+\n\x05shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x10\n\x08readonly\x18\x05 \x01(\x08\"\x87\x01\n\x06\x43onfig\x12%\n\x04\x66\x65\x65\x64\x18\x01 \x03(\x0b\x32\x17.tensorflow.tf2xla.Feed\x12\'\n\x05\x66\x65tch\x18\x02 \x03(\x0b\x32\x18.tensorflow.tf2xla.Fetch\x12-\n\x08variable\x18\x03 
\x03(\x0b\x32\x1b.tensorflow.tf2xla.VariableB*\n\x15org.tensorflow.tf2xlaB\x0cTf2XlaProtosP\x01\xf8\x01\x01\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.tf2xla.tf2xla_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\025org.tensorflow.tf2xlaB\014Tf2XlaProtosP\001\370\001\001' + _TENSORID._serialized_start=147 + _TENSORID._serialized_end=198 + _FEED._serialized_start=201 + _FEED._serialized_end=343 + _FETCH._serialized_start=346 + _FETCH._serialized_end=489 + _VARIABLE._serialized_start=492 + _VARIABLE._serialized_end=634 + _CONFIG._serialized_start=637 + _CONFIG._serialized_end=772 +# @@protoc_insertion_point(module_scope)