diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6799782dbf47c42c13c6bd1c57cfa3c39bffbab2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_grad.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_grad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3a5db395cfab0c27e717b3284cc953681bb23ac Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_grad.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44bfd1b1c3c04d999a6621a09b1a4e50385e066d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops_stack.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops_stack.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bcab9c06a7f28a0a6eb15ac60e553c6cc844ac6 Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/array_ops_stack.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/autograph_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/autograph_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c764e4ab3d54a61e02697069707cf99b5b5be3 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/autograph_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/batch_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/batch_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56960861b6a373540a7be84f43d0000ade8ef5bb Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/batch_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bincount_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bincount_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e7d5edcf562e2a334a05ad962a3472821bdf2c8 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bincount_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bitwise_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bitwise_ops.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..c60b97641ebc82461a3b3f6348585a5a4ed319bd Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/bitwise_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/boosted_trees_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/boosted_trees_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c88d417b27ed360d1df98f166b78371d6bf76df6 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/boosted_trees_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/candidate_sampling_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/candidate_sampling_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b91cde1193764924b53608749d9aa277857bfd4d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/candidate_sampling_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/check_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/check_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c1b08b2d19df65296fda5600a32588116a69504 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/check_ops.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clip_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clip_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dbe50b5721193effcc8a4f947beb1678ef4e576 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clip_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clustering_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clustering_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90b3a9d4783538c57e91d88bf1af1090814ec153 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/clustering_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/collective_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/collective_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa06d897a2435dc689fbc6b3e48e66282c1b7e81 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/collective_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/composite_tensor_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/composite_tensor_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a03e5bf1903265ff73941f8cbf06356a903d7a6b Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/composite_tensor_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3eed4ae810c6d379e2c2d0e22fb04ad2e4dea4b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond_v2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond_v2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54b4d7650660e15202bb3c02b9b9323994718da1 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cond_v2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/confusion_matrix.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/confusion_matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bfd285b1cc81996100654aaeb603183dda21a5e Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/confusion_matrix.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_assert.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_assert.cpython-310.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..045ea7a212080a6da0595171428426974e39194e Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_assert.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_case.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_case.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d22c37e1f37f823c259cfa3cd5385cf880c7f959 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_case.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_grad.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_grad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db65ae8f184aa99eba98b22146a9deffb7ffe29b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_grad.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df86972e1bc6a075f962af94886e25b3bd035498 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_ops.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_state.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49d59e4667d49772b2cfef66da09e5f7c7b91d5e Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_state.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_switch_case.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_switch_case.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892f16fdc8991991d986da588cbb47554f1fa337 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_switch_case.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0921abc1d8764786b3854ff5a1e296e90b77c9ae Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util_v2.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util_v2.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8831790c3612904070c22bed49d1abb65fe8c435 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_util_v2.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_func_graphs.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_func_graphs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76f6192dbd1de5b54e9ba603d47a2e772001b8ef Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_func_graphs.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_toggles.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_toggles.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..130611c1bd5b06ef95b6da0fdd04c9633cccc169 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/control_flow_v2_toggles.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/critical_section_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/critical_section_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc7c607fdfbe8ef21f46ae89b6120e9eccf9557 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/critical_section_ops.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/ctc_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/ctc_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ffda25cfdf0824dd42070018dc960c3d9d73ae2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/ctc_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cudnn_rnn_grad.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cudnn_rnn_grad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cea74b7f6049069fbc52d45db8ed9b890e372585 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/cudnn_rnn_grad.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/custom_gradient.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/custom_gradient.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8ee4c749c1d9997d26f64620e949be696c7c599 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/custom_gradient.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_grad.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_grad.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65926827eac0ddccf5e69ae6c087c229f01c3b48 Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_grad.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..779b9e82cea20b8938fe0e9803e8557e2e83e866 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/data_flow_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/default_gradient.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/default_gradient.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53f3bf3674741bc616f719f60f470716538cbe93 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/default_gradient.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/embedding_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/embedding_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb5fd52fe401e3f18f72ef04211211559ebebcf2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/embedding_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/filesystem_ops.cpython-310.pyc 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/filesystem_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54f89f5dd87abc6a78a67f227ce34620f318f6fa Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/filesystem_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/functional_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/functional_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..135ae2180a9f54694b7edb8f8f89ddc985830aa5 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/functional_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_array_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_array_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dd5d799554185a3feef8c0cc178f829eb468374 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_array_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_audio_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_audio_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a3b8bb105ffadd681df5cfe2c9c03b1c3bf7e2a Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_audio_ops.cpython-310.pyc differ 
diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_batch_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_batch_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef36d83a9446052968c720d1330f607fa116d3d9 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_batch_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_bitwise_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_bitwise_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03662523284b3fe4bb6a1ae1ad7f05ceac4c5500 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_bitwise_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_boosted_trees_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_boosted_trees_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1a7208f66dbcb71b8da4e4e866e29c2d2c9d4eb Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_boosted_trees_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_candidate_sampling_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_candidate_sampling_ops.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..0b9c0e83b366bb1a4478baea40817d168b8e6c7a Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_candidate_sampling_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_checkpoint_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_checkpoint_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7b21fe9cf65b81fadf823edd047026e85372934 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_checkpoint_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_clustering_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_clustering_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9914c35508ca8eb50fd5fc675b95648ace8f262 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_clustering_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_collective_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_collective_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c51e9bbe1805d15d435d665bc183d1e9bfd6937 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_collective_ops.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_composite_tensor_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_composite_tensor_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..befa9ce53c057f7b8df0daca84d7f8bf797325ba Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_composite_tensor_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_control_flow_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_control_flow_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3c30532d9aeef7f3245b4f748c8b765b7337f8e Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_control_flow_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_count_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_count_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc1cfe5e3d465cc2bd17cf614f707325d40ce1d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_count_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_ctc_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_ctc_ops.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3f74700a2315291bfc86744f29df2fe630cd4fac Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_ctc_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_cudnn_rnn_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_cudnn_rnn_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b353485391d6c8412b59a65bcc36b8c2976a2a2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_cudnn_rnn_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_data_flow_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_data_flow_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f278ceebb63c506e668be42d47b0a27d55e194b3 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_data_flow_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_dataset_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_dataset_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d369c7b033f8e33ea2156c2da91f6750e77339d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_dataset_ops.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_debug_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_debug_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbbf4e289696fb0f0d34ca421196d976f68d7fd9 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_debug_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_decode_proto_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_decode_proto_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f00c392dcd298aec0a47711b9a8b1604a7c73d30 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/__pycache__/gen_decode_proto_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23c447a5abc80dcb377a6060896b933eada64539 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for pfor, for_loop, jacobian.""" diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de580c5b5585c4672f5dfe47f838c6401699ad9b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/control_flow_ops.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/control_flow_ops.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c8dcb2b1b6658a262f1d2e5a7d69d295b15103d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/control_flow_ops.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/gradients.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/gradients.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7c6b7b0e5b596227eea9cd1caaa50c05a55525e2 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/gradients.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/pfor.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/pfor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e8044ee96f0a19c8bafedb4dde7e25185480b5 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/pfor.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/test_util.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/test_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4173a748b6be089f7a51ac486e134eddbef7a29 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/__pycache__/test_util.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e65e4fdd1c1a2c9f9d21ca2a4c8113a3c7cf6814 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -0,0 +1,582 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""for_loop and pfor ops.""" +# pylint: disable=g-direct-tensorflow-import + +import functools + +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.autograph.core import ag_ctx as autograph_ctx +from tensorflow.python.autograph.impl import api as autograph +from tensorflow.python.framework import composite_tensor +from tensorflow.python.framework import indexed_slices +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import type_spec +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import cond +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import while_loop +from tensorflow.python.ops.parallel_for.pfor import PFor +from tensorflow.python.ops.parallel_for.pfor import PForConfig +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect +from tensorflow.python.util import variable_utils +from 
tensorflow.python.util.tf_export import tf_export + + +def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None): + """Runs `loop_fn` `iters` times and stacks the outputs. + + + Runs `loop_fn` `iters` times, with input values from 0 to `iters - 1`, and + stacks corresponding outputs of the different runs. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and returns a possibly nested structure of tensor + objects. The shape of these outputs should not depend on the input. + loop_fn_dtypes: dtypes for the outputs of `loop_fn`. + iters: Number of iterations for which to run `loop_fn`. + parallel_iterations: The number of iterations that can be dispatched in + parallel. This knob can be used to control the total memory usage. + + Returns: + Returns a nested structure of stacked output tensor objects with the same + nested structure as the output of `loop_fn`. + """ + + flat_loop_fn_dtypes = nest.flatten(loop_fn_dtypes) + is_none_list = [] + + def while_body(i, *ta_list): + """Body of while loop.""" + fn_conv = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx()) + fn_output = nest.flatten(fn_conv(i)) + if len(fn_output) != len(flat_loop_fn_dtypes): + raise ValueError( + f"Number of expected outputs {len(flat_loop_fn_dtypes)}, does not " + f"match the number of actual outputs {len(fn_output)} from loop_fn: " + f"{loop_fn} with output {fn_output}.") + outputs = [] + del is_none_list[:] + is_none_list.extend(x is None for x in fn_output) + for out, ta in zip(fn_output, ta_list): + # TODO(agarwal): support returning Operation objects from loop_fn. + if out is not None: + # out may be a ref tensor, wrap it in identity to get a non-ref tensor. 
+ ta = ta.write(i, out) + outputs.append(ta) + return tuple([i + 1] + outputs) + + if parallel_iterations is not None: + extra_args = {"parallel_iterations": parallel_iterations} + else: + extra_args = {} + ta_list = while_loop.while_loop(lambda i, *ta: i < iters, while_body, [0] + [ + tensor_array_ops.TensorArray(dtype.base_dtype, iters) + for dtype in flat_loop_fn_dtypes + ], **extra_args)[1:] + + # TODO(rachelim): enable this for sparse tensors + + output = [ + None if is_none else ta.stack() + for ta, is_none in zip(ta_list, is_none_list) + ] + assert len(output) in (0, len(flat_loop_fn_dtypes)) + if not output: + # This may happen for the case where iters == 0. + # Pack a list of empty tensors with the proper ranks to match pfor output on 0 iters + loop_var = array_ops.placeholder_with_default(0, shape=[]) + try: + loop_fn_out = loop_fn(loop_var) + out_shapes = [ + [0] + ops.convert_to_tensor(x).shape + for x in nest.flatten(loop_fn_out) + ] + output = [ + array_ops.zeros(out_shapes[i], dt) + for i, dt in enumerate(flat_loop_fn_dtypes) + ] + except Exception: + output = [array_ops.zeros([0])] + return nest.pack_sequence_as(loop_fn_dtypes, output) + + +def _flatten_first_two_dims(x): + """Flattens the first two dimensions of x into a single dimension.""" + old_shape = array_ops.shape(x) + new_shape = array_ops.concat([[old_shape[0] * old_shape[1]], old_shape[2:]], + axis=0) + return array_ops.reshape(x, new_shape) + + +PFOR_CONFIG_ARG = "pfor_config" + + +def _is_under_xla_context(): + """Check if we are currently inside an XLA compile context.""" + g = ops.get_default_graph() + while g is not None: + control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access + while control_flow_context is not None: + if control_flow_context.IsXLAContext(): + return True + else: + control_flow_context = control_flow_context.outer_context + # If g is a FuncGraph, get its outer_graph. 
+ g = getattr(g, "outer_graph", None) + return False + + +def pfor(loop_fn, + iters, + fallback_to_while_loop=True, + parallel_iterations=None, + warn=False): + """Equivalent to running `loop_fn` `iters` times and stacking the outputs. + + `pfor` has functionality similar to `for_loop`, i.e. running `loop_fn` `iters` + times, with input from 0 to `iters - 1`, and stacking corresponding output of + each iteration. However the implementation does not use a `tf.while_loop`. + Instead it adds new operations to the graph that collectively compute the same + value as what running `loop_fn` in a loop would compute. + + + This is an experimental feature and currently has a lot of limitations: + - There should be no data dependency between the different iterations. For + example, a future iteration should not depend on a value or side-effect of + a previous iteration. + - Stateful kernels may mostly not be supported since these often imply a + data dependency or ordering of the iterations. We do support a limited set + of such stateful kernels though (like RandomFoo, Variable operations like + reads, etc). + - Conversion works only on a limited set of kernels for which a converter + has been registered. + - `loop_fn` has limited support for control flow operations. `tf.cond` in + particular is not supported. + - `loop_fn` should return nested structure of Tensors or Operations. However + if an Operation is returned, it should have zero outputs. + - The shape and dtype of `loop_fn` outputs should not depend on the input + to loop_fn. + + Args: + loop_fn: A function that takes an int32 scalar tf.Tensor object representing + the iteration number, and optionally a keyword argument `pfor_config` set + to a PForConfig object. It returns a possibly nested structure of Tensor + or Operation objects. Note that if setting `parallel_iterations` argument + to something other than None, `loop_fn` may be called more than once + during graph construction. 
So it may need to avoid mutating global state. + iters: Number of iterations for which to run `loop_fn`. + fallback_to_while_loop: If true, on failing to vectorize an operation, pfor + fallbacks to using a `tf.while_loop` to dispatch the iterations. + parallel_iterations: A knob to control how many iterations are vectorized + and dispatched in parallel. The default value of None corresponds to + vectorizing all the iterations. If `parallel_iterations` is smaller than + `iters`, then chunks of at most that many iterations are dispatched in + sequence. This knob can be used to control the total memory usage. + warn: Whether or not to warn when falling back to while loops. + + Returns: + Returns a nested structure of stacked tensor objects with the same nested + structure as the output of `loop_fn`. + Raises: + ValueError: If parallel_iterations is not None and not an integer > 1. + """ + def f(): + return _pfor_impl( + loop_fn, + iters, + fallback_to_while_loop=fallback_to_while_loop, + parallel_iterations=parallel_iterations, + warn=warn) + # Note that we wrap into a tf.function if in eager execution mode or under + # XLA compilation. The latter is so that we don't compile operations like + # tf.placeholder that are created by the loop body. + functions_run_eagerly = None + if context.executing_eagerly() or _is_under_xla_context(): + functions_run_eagerly = def_function.functions_run_eagerly() + if functions_run_eagerly: + logging.warning( + "It looks like tf.function behavior was disabled, perhaps using " + "tf.config.run_functions_eagerly. Vectorization " + "primitives (e.g. tf.vectorized_map) require tf.function to work. 
" + "These primitives will override the disable.") + def_function.run_functions_eagerly(False) + f = def_function.function(f) + + outputs = f() + if functions_run_eagerly is not None: + def_function.run_functions_eagerly(functions_run_eagerly) + return outputs + + +def _should_expand_composite(value): + return (isinstance(value, composite_tensor.CompositeTensor) + # Leave sparse tensors to be converted by `PFor._convert_sparse`. + and not isinstance(value, sparse_tensor.SparseTensor) + and not isinstance(value, indexed_slices.IndexedSlices)) + + +# pylint: disable=protected-access +def _composite_to_tensors(value, is_batched=False): + """Converts a CompositeTensor into a list of stackable tensors.""" + if _should_expand_composite(value): + spec = value._type_spec + if not isinstance(spec, type_spec.BatchableTypeSpec): + raise ValueError(f"CompositeTensor instance {value} returned from " + "parallel_for or vectorized_map loop body must provide " + f"a `BatchableTypeSpec` (saw: {spec}).") + if is_batched: + return spec._to_batched_tensor_list(value) + return spec._to_tensor_list(value) + return value +# pylint: enable=protected-access + + +# pylint: disable=protected-access +def _composite_from_tensors(stacked_tensors, + preconverted_value, + batch_size): + """Converts a list of stacked tensors to a batch CompositeTensor.""" + if _should_expand_composite(preconverted_value): + batch_type_spec = preconverted_value._type_spec._batch(batch_size) + return batch_type_spec._from_compatible_tensor_list(stacked_tensors) + return stacked_tensors +# pylint: enable=protected-access + + +def _loop_fn_has_config(loop_fn): + """Test if `loop_fn` has a `pfor_config` argument.""" + if tf_inspect.isfunction(loop_fn): + argspec = tf_inspect.getargspec(loop_fn) + return PFOR_CONFIG_ARG in argspec.args + elif isinstance(loop_fn, functools.partial): + fn = loop_fn.func + argspec = tf_inspect.getargspec(fn) + return (PFOR_CONFIG_ARG in argspec.args and + PFOR_CONFIG_ARG not in 
loop_fn.keywords) + else: + loop_class = tf_decorator.unwrap(loop_fn)[1] + if not hasattr(loop_class, "__call__"): + raise ValueError("`loop_fn` object did not have a __call__ method") + argspec = tf_inspect.getargspec(loop_class.__call__) + return PFOR_CONFIG_ARG in argspec.args + + +def _pfor_impl(loop_fn, + iters, + fallback_to_while_loop, + parallel_iterations=None, + pfor_config=None, + warn=False): + """Implementation of pfor.""" + assert not context.executing_eagerly() + loop_fn_has_config = _loop_fn_has_config(loop_fn) + existing_ops = set(ops.get_default_graph().get_operations()) + iters_value = tensor_util.constant_value(iters) + # Run the loop body + with ops.name_scope("loop_body"): + loop_var = array_ops.placeholder_with_default(0, shape=[]) + if loop_fn_has_config: + if pfor_config is None: + pfor_config = PForConfig() + pfor_config._set_iters(iters) # pylint: disable=protected-access + loop_fn_outputs = loop_fn(loop_var, **{PFOR_CONFIG_ARG: pfor_config}) + else: + assert pfor_config is None + f = autograph.tf_convert(loop_fn, autograph_ctx.control_status_ctx()) + loop_fn_outputs = f(loop_var) + loop_fn_output_tensors = nest.map_structure(_composite_to_tensors, + loop_fn_outputs) + + # Convert outputs to Tensor if needed. + tmp_loop_fn_outputs = [] + for loop_fn_output in nest.flatten(loop_fn_output_tensors): + if (loop_fn_output is not None and not isinstance( + loop_fn_output, + (ops.Operation, tensor.Tensor, sparse_tensor.SparseTensor))): + if isinstance(loop_fn_output, indexed_slices.IndexedSlices): + logging.warn("Converting %s to a dense representation may make it slow." + " Alternatively, output the indices and values of the" + " IndexedSlices separately, and handle the vectorized" + " outputs directly." 
% loop_fn_output) + loop_fn_output = ops.convert_to_tensor(loop_fn_output) + else: + loop_fn_output = ops.convert_to_tensor(loop_fn_output) + tmp_loop_fn_outputs.append(loop_fn_output) + loop_fn_output_tensors = nest.pack_sequence_as(loop_fn_output_tensors, + tmp_loop_fn_outputs) + + new_ops = set(ops.get_default_graph().get_operations()) - existing_ops + iters = ops.convert_to_tensor(iters) + if parallel_iterations is not None: + if parallel_iterations < 1: + raise ValueError( + "Argument `parallel_iterations` must be None or a positive integer. " + f"Received: {parallel_iterations}.") + if parallel_iterations == 1: + raise ValueError( + "Found `parallel_iterations == 1`. Use `for_loop` instead.") + if iters_value is not None and iters_value < parallel_iterations: + parallel_iterations = None + if parallel_iterations is None: + with ops.name_scope("pfor"): + converter = PFor( + loop_var, + iters, + new_ops, + fallback_to_while_loop=fallback_to_while_loop, + pfor_config=pfor_config, + warn=warn) + flattened_output_tensors = [] + for loop_fn_output in nest.flatten(loop_fn_output_tensors): + output = converter.convert(loop_fn_output) + flattened_output_tensors.append(output) + else: + if pfor_config is not None and pfor_config._has_reductions(): # pylint: disable=protected-access + raise ValueError("Setting `parallel_iterations` currently unsupported if " + "reductions across iterations are performed.") + num_tiled_iterations = iters // parallel_iterations + num_remaining_iterations = iters % parallel_iterations + # TODO(agarwal): Avoid calling loop_fn twice. Generate the loop body inside + # a tf.function and extract the graph from there to vectorize it. 
+ with ops.name_scope("pfor_untiled"): + converter = PFor(loop_var, num_remaining_iterations, new_ops, + fallback_to_while_loop=fallback_to_while_loop, + pfor_config=pfor_config) + remaining_output_tensors = [] + flattened_output_tensors = nest.flatten(loop_fn_output_tensors) + for loop_fn_output in flattened_output_tensors: + output = converter.convert(loop_fn_output) + remaining_output_tensors.append(output) + + with ops.name_scope("pfor_tiled"): + loop_fn_dtypes = [ops.convert_to_tensor(x).dtype + for x in flattened_output_tensors] + + def tiled_loop_body(j): + offset = j * parallel_iterations + num_remaining_iterations + + def tiled_loop_fn(i, pfor_config=None): + if loop_fn_has_config: + loop_fn_outputs = loop_fn(i + offset, pfor_config=pfor_config) + else: + loop_fn_outputs = loop_fn(i + offset) + return nest.flatten( + # Stacking across iterations requires explicit Tensors. + nest.map_structure(_composite_to_tensors, loop_fn_outputs)) + + return _pfor_impl( + tiled_loop_fn, + parallel_iterations, + fallback_to_while_loop=fallback_to_while_loop, + pfor_config=pfor_config) + + tiled_output_tensors = for_loop( + tiled_loop_body, loop_fn_dtypes, + num_tiled_iterations, parallel_iterations=1) + tiled_output_tensors = [ + _flatten_first_two_dims(y) for y in tiled_output_tensors] + + with ops.name_scope("pfor"): + if iters_value is None or iters_value % parallel_iterations: + output_tensors = cond.cond( + math_ops.equal(num_remaining_iterations, 0), + lambda: tiled_output_tensors, + lambda: [array_ops.concat([x, y], axis=0) # pylint: disable=g-long-lambda + for x, y in zip(remaining_output_tensors, + tiled_output_tensors)]) + else: + output_tensors = tiled_output_tensors + flattened_output_tensors = nest.flatten(output_tensors) + + for output, original_output in zip(flattened_output_tensors, + nest.flatten(loop_fn_output_tensors)): + # Restore any shape information lost from tiling. + # TODO(b/174254748): this may not be correct for stacked `variant`s. 
+ output.set_shape( + tensor_shape.TensorShape([iters_value]).concatenate( + original_output.shape)) + return nest.map_structure_up_to( + loop_fn_outputs, + functools.partial(_composite_from_tensors, batch_size=iters_value), + nest.pack_sequence_as(loop_fn_output_tensors, + flattened_output_tensors), + loop_fn_outputs) + + +def _broadcasting_gather(x, i): + """Wrapper for gather that implicitly broadcasts unit dimensions.""" + static_first_dim = tensor_shape.dimension_value(x.shape[0]) + if static_first_dim == 1: + i = 0 + elif static_first_dim is None: + i = array_ops.where_v2(array_ops.shape(x)[0] > 1, i, 0) + result = array_ops.gather(x, i) + return result + + +# pylint: disable=protected-access +def _gather_from_tensor_or_composite(x, i): + """Wrapper for gather that handles CompositeTensors.""" + if _should_expand_composite(x): + spec = x._type_spec + gathered_tensors = [_broadcasting_gather(t, i) + for t in spec._to_batched_tensor_list(x)] + return spec._unbatch()._from_compatible_tensor_list(gathered_tensors) + return _broadcasting_gather(x, i) +# pylint: enable=protected-access + + +@tf_export("vectorized_map") +def vectorized_map(fn, elems, fallback_to_while_loop=True, warn=True): + """Parallel map on the list of tensors unpacked from `elems` on dimension 0. + + This method works similar to `tf.map_fn` but is optimized to run much faster, + possibly with a much larger memory footprint. The speedups are obtained by + vectorization (see [Auto-Vectorizing TensorFlow Graphs: Jacobians, + Auto-Batching and Beyond](https://arxiv.org/pdf/1903.04243.pdf)). The idea + behind vectorization is to semantically launch all the invocations of `fn` in + parallel and fuse corresponding operations across all these invocations. This + fusion is done statically at graph generation time and the generated code is + often similar in performance to a manually fused version. 
+ + Because `tf.vectorized_map` fully parallelizes the batch, this method will + generally be significantly faster than using `tf.map_fn`, especially in eager + mode. However this is an experimental feature and currently has a lot of + limitations: + - There should be no data dependency between the different semantic + invocations of `fn`, i.e. it should be safe to map the elements of the + inputs in any order. + - Stateful kernels may mostly not be supported since these often imply a + data dependency. We do support a limited set of such stateful kernels + though (like RandomFoo, Variable operations like reads, etc). + - `fn` has limited support for control flow operations. + - `fn` should return nested structure of Tensors or Operations. However + if an Operation is returned, it should have zero outputs. + - The shape and dtype of any intermediate or output tensors in the + computation of `fn` should not depend on the input to `fn`. + + Examples: + ```python + def outer_product(a): + return tf.tensordot(a, a, 0) + + batch_size = 100 + a = tf.ones((batch_size, 32, 32)) + c = tf.vectorized_map(outer_product, a) + assert c.shape == (batch_size, 32, 32, 32, 32) + ``` + + ```python + # Computing per-example gradients + + batch_size = 10 + num_features = 32 + layer = tf.keras.layers.Dense(1) + + def model_fn(arg): + with tf.GradientTape() as g: + inp, label = arg + inp = tf.expand_dims(inp, 0) + label = tf.expand_dims(label, 0) + prediction = layer(inp) + loss = tf.nn.l2_loss(label - prediction) + return g.gradient(loss, (layer.kernel, layer.bias)) + + inputs = tf.random.uniform([batch_size, num_features]) + labels = tf.random.uniform([batch_size, 1]) + per_example_gradients = tf.vectorized_map(model_fn, (inputs, labels)) + assert per_example_gradients[0].shape == (batch_size, num_features, 1) + assert per_example_gradients[1].shape == (batch_size, 1) + ``` + + Args: + fn: The callable to be performed. 
It accepts one argument, which will have + the same (possibly nested) structure as `elems`, and returns a possibly + nested structure of Tensors and Operations, which may be different than + the structure of `elems`. + elems: A tensor or (possibly nested) sequence of tensors, each of which will + be unpacked along their first dimension. The nested sequence of the + resulting slices will be mapped over by `fn`. The first dimensions of all + elements must broadcast to a consistent value; equivalently, each + element tensor must have first dimension of either `B` or `1`, for some + common batch size `B >= 1`. + fallback_to_while_loop: If true, on failing to vectorize an operation, + the unsupported op is wrapped in a tf.while_loop to execute the map + iterations. Note that this fallback only happens for unsupported ops and + other parts of `fn` are still vectorized. If false, on encountering an + unsupported op, a ValueError is thrown. Note that the fallbacks can result + in slowdowns since vectorization often yields speedup of one to two orders + of magnitude. + warn: If set to `false`, this will supress any warnings due to operation + conversions in the provided `fn` falling back to while loops. + + Returns: + A tensor or (possibly nested) sequence of tensors. Each tensor packs the + results of applying fn to tensors unpacked from elems along the first + dimension, from first to last. + + Although they are less common as user-visible inputs and outputs, note that + tensors of type `tf.variant` which represent tensor lists (for example from + `tf.raw_ops.TensorListFromTensor`) are vectorized by stacking the list + contents rather than the variant itself, and so the container tensor will + have a scalar shape when returned rather than the usual stacked shape. This + improves the performance of control flow gradient vectorization. + + Raises: + ValueError: If vectorization fails and fallback_to_while_loop is False. 
+ """ + elems = variable_utils.convert_variables_to_tensors(elems) + elems = nest.map_structure(ops.convert_to_tensor, + elems, + expand_composites=True) + + def loop_fn(i): + gathered_elems = nest.map_structure( + lambda x: _gather_from_tensor_or_composite(x, i), elems) + return fn(gathered_elems) + + # Extract batch size from the maximum first dimension of any element. + flat_elems = nest.flatten( + nest.map_structure( + functools.partial(_composite_to_tensors, + is_batched=True), + elems)) + def _get_shape(x): + if x.shape.rank is None: + return None + return x.shape.as_list()[0] + static_first_dims = [_get_shape(elem) for elem in flat_elems] + if any(s is None for s in static_first_dims): + batch_size = math_ops.reduce_max( + [array_ops.shape(elem)[0] for elem in flat_elems]) + else: + batch_size = max(static_first_dims) + + return pfor( + loop_fn, + batch_size, + fallback_to_while_loop=fallback_to_while_loop, + warn=warn) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/gradients.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/gradients.py new file mode 100644 index 0000000000000000000000000000000000000000..da667a5e1bbde54f3e46be28d75d8c58fc700b1e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/gradients.py @@ -0,0 +1,144 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Jacobian ops.""" +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import gradients_impl as gradient_ops +from tensorflow.python.ops.parallel_for import control_flow_ops +from tensorflow.python.util import nest + + +def jacobian(output, inputs, use_pfor=True, parallel_iterations=None): + """Computes jacobian of `output` w.r.t. `inputs`. + + Args: + output: A tensor. + inputs: A tensor or a nested structure of tensor objects. + use_pfor: If true, uses pfor for computing the jacobian. Else uses + tf.while_loop. + parallel_iterations: A knob to control how many iterations and dispatched in + parallel. This knob can be used to control the total memory usage. + + Returns: + A tensor or a nested structure of tensors with the same structure as + `inputs`. Each entry is the jacobian of `output` w.r.t. to the corresponding + value in `inputs`. If output has shape [y_1, ..., y_n] and inputs_i has + shape [x_1, ..., x_m], the corresponding jacobian has shape + [y_1, ..., y_n, x_1, ..., x_m]. Note that in cases where the gradient is + sparse (IndexedSlices), jacobian function currently makes it dense and + returns a Tensor instead. This may change in the future. 
+ """ + flat_inputs = nest.flatten(inputs) + output_tensor_shape = output.shape + output_shape = array_ops.shape(output) + output = array_ops.reshape(output, [-1]) + + def loop_fn(i): + y = array_ops.gather(output, i) + return gradient_ops.gradients(y, flat_inputs) + + try: + output_size = int(output.shape[0]) + except TypeError: + output_size = array_ops.shape(output)[0] + + if use_pfor: + pfor_outputs = control_flow_ops.pfor( + loop_fn, output_size, parallel_iterations=parallel_iterations) + else: + pfor_outputs = control_flow_ops.for_loop( + loop_fn, + [output.dtype] * len(flat_inputs), + output_size, + parallel_iterations=parallel_iterations) + + for i, out in enumerate(pfor_outputs): + if isinstance(out, tensor.Tensor): + new_shape = array_ops.concat( + [output_shape, array_ops.shape(out)[1:]], axis=0) + out = array_ops.reshape(out, new_shape) + out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape)) + pfor_outputs[i] = out + + return nest.pack_sequence_as(inputs, pfor_outputs) + + +def batch_jacobian(output, inp, use_pfor=True, parallel_iterations=None): + """Computes and stacks jacobians of `output[i,...]` w.r.t. `input[i,...]`. + + e.g. + x = tf.constant([[1, 2], [3, 4]], dtype=tf.float32) + y = x * x + jacobian = batch_jacobian(y, x) + # => [[[2, 0], [0, 4]], [[6, 0], [0, 8]]] + + Args: + output: A tensor with shape [b, y1, ..., y_n]. `output[i,...]` should + only depend on `inp[i,...]`. + inp: A tensor with shape [b, x1, ..., x_m] + use_pfor: If true, uses pfor for computing the Jacobian. Else uses a + tf.while_loop. + parallel_iterations: A knob to control how many iterations are vectorized + and dispatched in parallel. The default value of None, when use_pfor is + true, corresponds to vectorizing all the iterations. When use_pfor is + false, the default value of None corresponds to parallel_iterations=10. + This knob can be used to control the total memory usage. 
+ + Returns: + A tensor `t` with shape [b, y_1, ..., y_n, x1, ..., x_m] where `t[i, ...]` + is the jacobian of `output[i, ...]` w.r.t. `inp[i, ...]`, i.e. stacked + per-example jacobians. + + Raises: + ValueError: if first dimension of `output` and `inp` do not match. + """ + output_shape = output.shape + if not output_shape[0].is_compatible_with(inp.shape[0]): + raise ValueError(f"Need first dimension of `output` shape ({output.shape}) " + f"and `inp` shape ({inp.shape}) to match.") + if output_shape.is_fully_defined(): + batch_size = int(output_shape[0]) + output_row_size = output_shape.num_elements() // batch_size + else: + output_shape = array_ops.shape(output) + batch_size = output_shape[0] + output_row_size = array_ops.size(output) // batch_size + inp_shape = array_ops.shape(inp) + # Flatten output to 2-D. + with ops.control_dependencies( + [check_ops.assert_equal(batch_size, inp_shape[0])]): + output = array_ops.reshape(output, [batch_size, output_row_size]) + + def loop_fn(i): + y = array_ops.gather(output, i, axis=1) + return gradient_ops.gradients(y, inp)[0] + + if use_pfor: + pfor_output = control_flow_ops.pfor(loop_fn, output_row_size, + parallel_iterations=parallel_iterations) + else: + pfor_output = control_flow_ops.for_loop( + loop_fn, output.dtype, + output_row_size, + parallel_iterations=parallel_iterations) + if pfor_output is None: + return None + pfor_output = array_ops.reshape(pfor_output, + [output_row_size, batch_size, -1]) + output = array_ops.transpose(pfor_output, [1, 0, 2]) + new_shape = array_ops.concat([output_shape, inp_shape[1:]], axis=0) + return array_ops.reshape(output, new_shape) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/pfor.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/pfor.py new file mode 100644 index 0000000000000000000000000000000000000000..88c9483edd7019368e219e1f1ce6a4a5f57df237 --- /dev/null +++ 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/pfor.py @@ -0,0 +1,5220 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Compiled parallel-for loop.""" +# pylint: disable=missing-docstring,g-direct-tensorflow-import + +import collections +import functools +import string +import sys +import traceback +from typing import List + +from tensorflow.compiler.tf2xla.python import xla +from tensorflow.core.framework import full_type_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import def_function +from tensorflow.python.eager import execute +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import func_graph +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import cond as tf_cond +from tensorflow.python.ops import control_flow_assert +from tensorflow.python.ops import 
control_flow_ops +from tensorflow.python.ops import control_flow_switch_case +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import gen_list_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gen_optional_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_random_ops +from tensorflow.python.ops import gen_sparse_ops +from tensorflow.python.ops import gen_spectral_ops +from tensorflow.python.ops import handle_data_util +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import manip_ops +from tensorflow.python.ops import map_fn +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import while_loop +from tensorflow.python.platform import flags +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import compat +from tensorflow.python.util import nest +from tensorflow.python.util import numpy_compat +from tensorflow.python.util import object_identity + + +# TODO(agarwal): remove flag. 
+flags.DEFINE_bool( + "op_conversion_fallback_to_while_loop", True, + "DEPRECATED: Flag is ignored.") + + +def _variant_handle_data(t): + """Fetches handle data for a variant tensor `t`, or None if unavailable.""" + handle_data = resource_variable_ops.get_eager_safe_handle_data(t) + if not handle_data.is_set: + return None + return handle_data.shape_and_type + + +def _variant_type_id(t): + """Returns the full_type_pb2 type of `t`, or None if it is not available.""" + if t.dtype != dtypes.variant: + return None + shapes_and_types = _variant_handle_data(t) + if shapes_and_types is None or not shapes_and_types: + # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can + # make this an error instead of assuming TensorLists have handle data. + return None # Presumed not a TensorList/Optional + return shapes_and_types[0].type.type_id + + +_INTERNAL_STACKING_TYPE_IDS = ( + full_type_pb2.TFT_ARRAY, + full_type_pb2.TFT_OPTIONAL) + + +def _is_variant_with_internal_stacking(t): + """Identifies variant tensors which pfor always maintains as scalars. + + For these, the pfor tensor is recorded as "stacked" if the content of the + variant tensor (e.g. the elements of a TensorList) are all stacked. + + Args: + t: A tensor to identify. + Returns: + True if `t` is a TensorList/Optional, False not, None if unknown. 
+ """ + type_id = _variant_type_id(t) + return type_id in _INTERNAL_STACKING_TYPE_IDS + + +def _parse_variant_shapes_and_types(t): + """Extracts shape and dtype information from a variant tensor `t`.""" + shapes_and_types = _variant_handle_data(t) + if shapes_and_types is None or not shapes_and_types: + raise ValueError("Required handle data not set for {!r}".format(t)) + if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: + return shapes_and_types + else: + if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET: + return shapes_and_types + else: + raise ValueError( + "Attempted to stack a variant-dtype tensor with no type set ({!r})" + .format(t)) + + +def _rank(t): + """Returns rank as an integer (when statically known) or as a tensor.""" + rank = t.get_shape().rank if isinstance(t, tensor_lib.Tensor) else None + return array_ops.rank(t) if rank is None else rank + + +def _size(t, dtype=None): + """Returns size as an integer (when statically known) or as a tensor.""" + size = ( + t.get_shape().num_elements() if isinstance(t, tensor_lib.Tensor) else None + ) + return array_ops.size(t, out_type=dtype) if size is None else size + + +def _expand_dims(t, axis, num_axes=1): + """Similar to `expand_dims` but supports insertion of multiple axes.""" + if isinstance(num_axes, int): + for _ in range(num_axes): + t = array_ops.expand_dims(t, axis) + else: + shape = array_ops.shape(t) + ones = array_ops.fill( + array_ops.reshape(num_axes, [1]), constant_op.constant(1, shape.dtype) + ) + new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) + t = array_ops.reshape(t, new_shape) + return t + + +def _stack(t, length): + """stacks `t` `length` times.""" + # Note that this stacking may currently be triggered, for example, when a + # loop invariant tensor with dtype variant is input to a while_loop which then + # produces a loop dependent output. 
Simply stacking the variants may not be + # suitable since operations on stacked handles may expect a vectorized version + # of the variant. + if t.dtype == dtypes.variant: + shapes_and_types = _parse_variant_shapes_and_types(t) + if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: + if len(shapes_and_types) != 1: + raise ValueError( + f"Expected handle data of length 1, got {shapes_and_types!r} of " + f"length {len(shapes_and_types)}.") + return wrap( + _stack_tensor_list(t, shapes_and_types[0].dtype, length), + True) + else: + raise ValueError( + "Attempted to stack an unhandled variant-dtype tensor of " + f"type {shapes_and_types[0].type!r} ({t!r}).") + shape = array_ops.shape(t) + ones = array_ops.ones_like(shape) + ones = array_ops.reshape(ones, [-1]) + length = array_ops.reshape(length, [-1]) + length = math_ops.cast(length, shape.dtype) + multiples = array_ops.concat([length, ones], 0) + t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) + return wrap(t, True) + + +# The following stateful ops can be safely called once, and with the same +# signature as the unconverted version, if their inputs are loop invariant. +# TODO(agarwal): implement a strategy for converting Variable reads/writes. The +# plan is to map each read/write in the loop_fn to a corresponding merged +# read/write in the converted graph. Writes need to be mergeable (e.g. +# AssignAdd) to be used in `pfor`. Given a certain read/write order in the +# loop_fn, doing a one-to-one conversion will simulate executing such +# instructions in lock-step across all iterations. +passthrough_stateful_ops = set([ + "VariableV2", + "VarHandleOp", + "VariableShape", + "ReadVariableOp", + "StackV2", + "TensorArrayWriteV3", + "TensorArrayReadV3", + "TensorArraySizeV3", +]) + + +# Ops which we will treat like stateful for the purpose of vectorization. +# Typically this is used to force pfor converters to run for these ops. 
+force_stateful_ops = set([ + # We vectorize this since we need to change the element shape set on the + # list. + "TensorListReserve", +]) + + +def _is_stateful_pfor_op(op): + if isinstance(op, WhileOp): + return op.is_stateful + if op.type == "Const": + # Const didn't have an op_def. + return False + if op.type in passthrough_stateful_ops: + return False + if op.type in force_stateful_ops: + return True + assert hasattr(op, "op_def") and op.op_def is not None, op + return op.op_def.is_stateful + + +# pylint: disable=protected-access +class WhileOp: + """Object for storing state for converting the outputs of a while_loop.""" + + def __init__( + self, + exit_node: tensor_lib.Tensor, + pfor_ops: List[ops.Operation], + fallback_to_while_loop: bool, + pfor_config: "PForConfig", + ): + """Initializer. + + Args: + exit_node: A tensor output from the while_loop. + pfor_ops: list of ops inside the current pfor loop. + fallback_to_while_loop: If True, fallback to while loop when conversion of + an op is not supported + pfor_config: PForConfig object used while constructing loop body. + """ + self._fallback_to_while_loop = fallback_to_while_loop + self._pfor_config = pfor_config + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set(x._id for x in pfor_ops) + assert isinstance(exit_node, tensor_lib.Tensor) + self._while_context = exit_node.op._get_control_flow_context() + assert isinstance(self._while_context, control_flow_ops.WhileContext) + self._context_name = self._while_context.name + self._condition = self._while_context.pivot.op.inputs[0] + # Parts of an external while_loop could be created inside a pfor loop. + # However for the purpose here, we declare such loops to be external. Also + # note that we check if the condition was created inside or outside to + # determine if the while_loop was first created inside or outside. + # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. 
+ self._is_inside_loop = self.op_is_inside_loop(self._condition.op) + if self._is_inside_loop: + for e in self._while_context.loop_exits: + assert self.op_is_inside_loop(e.op) + + # Note the code below tries to reverse engineer an existing while_loop graph + # by assuming the following pattern of nodes. + # + # NextIteration <---- Body <--- Enter + # | ^ + # V ___| Y + # Enter -> Merge -> Switch___ + # ^ | N + # | V + # LoopCond Exit + + # Node that elements in the list below correspond one-to-one with each + # other. i.e. these lists are the same size, and the i_th entry corresponds + # to different Operations/Tensors of a single cycle as illustrated above. + # List of Switch ops (ops.Operation) that feed into an Exit Node. + self._exit_switches = [] + # List of inputs (tensor_lib.Tensor) to NextIteration. + self._body_outputs = [] + # List of list of control inputs of the NextIteration nodes. + self._next_iter_control_inputs = [] + # List of Merge ops (ops.Operation). + self._enter_merges = [] + # List of output (tensor_lib.Tensor) of Exit nodes. + self._outputs = [] + + # List of Enter Tensors. + # There are two types of Enter nodes: + # - The Enter nodes that are used in the `loop_vars` argument to + # `while_loop` (see + # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect + # these Enter nodes immediately below by tracing backwards from the Exit + # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the + # diagram above. This allows us to have a 1:1 correspondence between the + # self._outputs and the first elements in self._enters. + # - The Enter nodes that are used only by the body. They don't appear in the + # `loop_vars` and are not returned from the `while_loop`. In Python code, + # they are usually captured by the body lambda. We collect them below by + # iterating over all the ops in the graph. 
They are appended to the end of + # self._enters or self._direct_enters, and don't correspond to any outputs + # in self._outputs. Note that we keep the resource/variant Enter nodes in + # self._direct_enters and the constructed while_loop's body uses them + # directly as opposed to passing them as loop variables. This is done + # because the while_body cannot partition the resource/variant Tensors, so + # it has to leave them unchanged. + self._enters = [] + self._direct_enters = [] + + for e in self._while_context.loop_exits: + self._outputs.append(e.op.outputs[0]) + switch = e.op.inputs[0].op + assert switch.type == "Switch", switch + self._exit_switches.append(switch) + merge = switch.inputs[0].op + assert merge.type == "Merge", merge + self._enter_merges.append(merge) + enter = merge.inputs[0].op + assert enter.type == "Enter", enter + self._enters.append(enter.outputs[0]) + next_iter = merge.inputs[1].op + assert next_iter.type == "NextIteration", next_iter + self._body_outputs.append(next_iter.inputs[0]) + self._next_iter_control_inputs.append(next_iter.control_inputs) + + # Collect all the Enter nodes that are not part of `loop_vars`, the second + # category described above. + # Also track whether the loop body has any stateful ops. + self._is_stateful = False + for op in ops.get_default_graph().get_operations(): + # TODO(agarwal): make sure this works with nested case. 
+ control_flow_context = op._get_control_flow_context() + if control_flow_context is None: + continue + if control_flow_context.name == self._context_name: + self._is_stateful |= _is_stateful_pfor_op(op) + if op.type == "Enter": + output = op.outputs[0] + if output not in self._enters: + if output.dtype in (dtypes.resource, dtypes.variant): + if output not in self._direct_enters: + self._direct_enters.append(output) + else: + self._enters.append(output) + + def __str__(self) -> str: + """String representation.""" + return "while_loop(%s)" % self.name + + @property + def inputs(self): + """Input to all the Enter nodes.""" + return [x.op.inputs[0] for x in self._enters + self._direct_enters] + + @property + def control_inputs(self): + """Control input to all the Enter nodes.""" + control_inputs = [] + for x in self._enters + self._direct_enters: + control_inputs.extend(x.op.control_inputs) + return control_inputs + + @property + def outputs(self) -> List[tensor_lib.Tensor]: + """Outputs of all the Exit nodes.""" + return self._outputs + + @property + def name(self) -> str: + """Context name for the while loop.""" + return self._context_name + + @property + def is_inside_loop(self) -> bool: + """Returns true if the while_loop was created inside the pfor.""" + return self._is_inside_loop + + def op_is_inside_loop(self, op: ops.Operation) -> bool: + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. 
+ return op._id in self._pfor_op_ids + + @property + def is_stateful(self) -> bool: + return self._is_stateful + + @property + def pfor_converter(self) -> "WhileOp": + """Return a converter for the while loop.""" + return self + + def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, + inputs_stacked): + """Create a PFor object for converting parts of the while_loop. + + Args: + parent_pfor: PFor object being used for converting the while_loop. + indices: int32 Tensor of ids for the iterations that are still active + (i.e. did not exit the while_loop). + cond_stacked: True if the while_loop condition is stacked. + inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note + that these Tensors are a subset of the loop variables for the generated + while_loop. + inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, + indicating if the value is stacked or not. + + Returns: + A PFor instance. The instance is initialized by adding conversion mappings + of nodes that will be external to the conversion that the returned + instance will be used for. e.g. Enter nodes as well as Merge and Switch + outputs are mapped to converted values. + """ + num_outputs = len(self._outputs) + assert len(inputs) == len(self._enters) + assert len(inputs_stacked) == len(self._enters) + loop_var = parent_pfor.loop_var + loop_len = array_ops.size(indices) + pfor = PFor( + loop_var, + loop_len, + pfor_ops=self._pfor_ops, + all_indices=indices, + all_indices_partitioned=cond_stacked, + fallback_to_while_loop=self._fallback_to_while_loop, + pfor_config=self._pfor_config) + # Map all inputs of Enter nodes in self._direct_enters to their converted + # values. + for enter in self._direct_enters: + enter_input = enter.op.inputs[0] + converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( + enter_input) + # Since these are resources / variants, they should be unstacked. 
+ assert not stacked and not is_sparse_stacked, (enter, converted_enter) + pfor._add_conversion(enter, wrap(converted_enter, False)) + + # Map all Enter nodes to the inputs. + for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): + pfor._add_conversion(enter, wrap(inp, stacked)) + # Map outputs of Switch and Merge. + for i in range(num_outputs): + wrapped_inp = wrap(inputs[i], inputs_stacked[i]) + merge = self._enter_merges[i] + pfor._add_conversion(merge.outputs[0], wrapped_inp) + # Note that second output of Merge is typically not used, except possibly + # as a control dependency. To avoid trying to output the correct value, we + # employ a hack here. We output a dummy invalid value with an incorrect + # dtype. This will allow control dependency to work but if using it as an + # input, it should typically lead to errors during graph construction due + # to dtype mismatch. + # TODO(agarwal): Check in the original graph to see if there are any + # consumers of this Tensor that use it as an input. + pfor._add_conversion(merge.outputs[1], + wrap(constant_op.constant(-1.0), False)) + switch = self._exit_switches[i] + # Don't need to worry about switch.output[0] which will feed to Exit node. + pfor._add_conversion(switch.outputs[1], wrapped_inp) + return pfor + + def _convert_enter(self, parent_pfor: "PFor", enter): + """Converts an Enter node.""" + inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) + control_inputs = [] + for x in enter.op.control_inputs: + converted = parent_pfor._convert_helper(x) + if not isinstance(converted, ops.Operation): + converted = converted.t + control_inputs.append(converted) + if control_inputs: + with ops.control_dependencies(control_inputs): + inp = array_ops.identity(inp) + return inp, stacked + + def _maybe_stacked(self, cache, inp): + """Heuristic to figure out if the converting inp leads to a stacked value. + + + Args: + cache: map from Tensor to boolean indicating stacked/unstacked. 
+ inp: input Tensor. + + Returns: + True if `inp` could get stacked. If the function returns False, the + converted value should be guaranteed to be unstacked. If returning True, + it may or may not be stacked. + """ + if inp in cache: + return cache[inp] + if not self.op_is_inside_loop(inp.op): + return False + op = inp.op + output = False + if op.type in [ + "OnesLike", + "Shape", + "Rank", + "ShapeN", + "ZerosLike", + "TensorArrayV3", + "TensorArraySizeV3", + ]: + output = False + elif _is_stateful_pfor_op(op): + # This may be fairly aggressive. + output = True + elif op.type == "Exit": + # This may be fairly aggressive. + output = True + else: + for t in op.inputs: + if self._maybe_stacked(cache, t): + output = True + break + cache[inp] = output + return output + + def _create_init_values(self, pfor_input: "_PforInput"): + """Create arguments passed to converted while_loop.""" + with ops.name_scope("while_init"): + loop_len_vector = pfor_input.pfor.loop_len_vector + loop_len = loop_len_vector[0] + num_outputs = len(self._outputs) + + inputs = [] + maybe_stacked_cache = {} + # Convert all the Enters. Need to do this before checking for stacking + # below. + for i, enter in enumerate(self._enters): + inp, stacked = self._convert_enter(pfor_input.pfor, enter) + inputs.append(inp) + maybe_stacked_cache[enter] = stacked + # Since this enter node is part of the `loop_vars`, it corresponds to an + # output and its preceding switch. We mark this switch's output the same + # stackness, to act at the base case for the logic below. Below, we will + # be going through the body figuring out which inputs might need to be + # stacked and which inputs can safely remain unstacked. + if i < num_outputs: + maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked + + # Shape invariants for init_values corresponding to self._enters. 
+ input_shape_invariants = [] + # TensorArrays for outputs of converted while loop + output_tas = [] + # Shape invariants for output TensorArrays. + ta_shape_invariants = [] + # List of booleans indicating stackness of inputs, i.e. tensors + # corresponding to self._enters. + inputs_stacked = [] + for i, inp in enumerate(inputs): + enter = self._enters[i] + inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) + # Note that even when an input is unstacked, the body could make it + # stacked. we use a heuristic below to figure out if body may be making + # it stacked. + if i < num_outputs: + body_output = self._body_outputs[i] + if enter.op in self._pfor_ops: + body_output_stacked = self._maybe_stacked(maybe_stacked_cache, + body_output) + else: + # If constructed outside of pfor loop, then the output would not be + # stacked. + body_output_stacked = False + if body_output_stacked and not inp_stacked: + inp = _stack(inp, loop_len_vector).t + inputs[i] = inp + inp_stacked = True + # TODO(agarwal): other attributes for the TensorArray ? + output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) + ta_shape_invariants.append(tensor_shape.TensorShape(None)) + + inputs_stacked.append(inp_stacked) + input_shape_invariants.append(tensor_shape.TensorShape(None)) + + # See documentation for __call__ for the structure of init_values. + init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas + # TODO(agarwal): try stricter shape invariants + shape_invariants = ( + [tensor_shape.TensorShape(None), + tensor_shape.TensorShape(None)] + input_shape_invariants + + ta_shape_invariants) + + return init_values, inputs_stacked, shape_invariants + + def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): + """Handles case when condition is unstacked. + + Note that all iterations end together. So we don't need to partition the + inputs. When all iterations are done, we write the inputs to the + TensorArrays. 
Note that we only write to index 0 of output_tas. Since all + iterations end together, they can all be output together. + """ + not_all_done = array_ops.reshape(conditions, []) + new_output_tas = [] + # pylint: disable=cell-var-from-loop + for i, out_ta in enumerate(output_tas): + inp = inputs[i] + new_output_tas.append( + tf_cond.cond(not_all_done, lambda: out_ta, + lambda: out_ta.write(0, inp))) + # pylint: enable=cell-var-from-loop + return not_all_done, indices, inputs, new_output_tas + + def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, + output_tas): + num_outputs = len(self._outputs) + # Compute if all iterations are done. + not_all_done = math_ops.reduce_any(conditions) + conditions_int = math_ops.cast(conditions, dtypes.int32) + # Partition the indices. + done_indices, new_indices = data_flow_ops.dynamic_partition( + indices, conditions_int, 2) + + new_inputs = [] + new_output_tas = [] + for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): + # Partition the inputs. + if stacked: + done_inp, new_inp = data_flow_ops.dynamic_partition( + inp, conditions_int, 2) + else: + # TODO(agarwal): avoid this stacking. See TODO earlier in + # _process_cond_unstacked. + done_inp = _stack(inp, [array_ops.size(done_indices)]).t + new_inp = inp + new_inputs.append(new_inp) + # For iterations that are done, write them to TensorArrays. + if i < num_outputs: + out_ta = output_tas[i] + # Note that done_indices can be empty. done_inp should also be empty in + # that case. + new_output_tas.append(out_ta.scatter(done_indices, done_inp)) + return not_all_done, new_indices, new_inputs, new_output_tas + + def _process_body( + self, + pfor_input: "_PforInput", + inputs_stacked, + new_indices, + cond_stacked, + new_inputs, + not_all_done, + ): + """Convert the body function.""" + + def true_fn(control_inputs, body_pfor, body_output, stacked): + """Converts the body function for all but last iteration. + + This essentially converts body_output. 
Additionally, it needs to handle + any control dependencies on the NextIteration node. So it creates another + Identity node with the converted dependencies. + """ + converted_control_inp = [] + for x in control_inputs: + for t in x.outputs: + converted_control_inp.append(body_pfor._convert_helper(t).t) + if stacked: + # Note convert always does the stacking. + output = body_pfor.convert(body_output) + else: + output, convert_stacked, _ = body_pfor._convert_helper(body_output) + assert convert_stacked == stacked, body_output + with ops.control_dependencies(converted_control_inp): + return array_ops.identity(output) + + body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked, + new_inputs, inputs_stacked) + new_outputs = [] + + for i, (body_output, + stacked) in enumerate(zip(self._body_outputs, inputs_stacked)): + control_inp = self._next_iter_control_inputs[i] + out_dtype = body_output.dtype + # Note that we want to run the body only if not all pfor iterations are + # done. If all are done, we return empty tensors since these values will + # not be used. Notice that the value returned by the loop is based on + # TensorArrays and not directly on these returned values. + # pylint: disable=cell-var-from-loop + new_output = tf_cond.cond( + not_all_done, + lambda: true_fn(control_inp, body_pfor, body_output, stacked), + lambda: constant_op.constant([], dtype=out_dtype)) + # pylint: enable=cell-var-from-loop + new_outputs.append(new_output) + return new_outputs + + def __call__(self, pfor_input: "_PforInput"): + """Converter for the while_loop. + + The conversion of a while_loop is another while_loop. + + The arguments to this converted while_loop are as follows: + not_all_done: Boolean scalar Tensor indicating if all the pfor iterations + are done. + indices: int32 1-D Tensor storing the id of the iterations that are not + done. + args: Remaining arguments. 
These can be divided into 3 categories: + - First set of arguments are the tensors that correspond to the initial + elements of self._enters. The elements that appear in original while + loop's `loop_vars`. + - The second set of arguments are the tensors that correspond to the + remaining elements of self._enters. These are the tensors that directly + enter the original while loop body. + - Finally, the last set of arguments are TensorArrays. These TensorArrays + correspond to the outputs of the original while_loop, i.e. to the + elements in self._outputs. Each TensorArray has `PFor.loop_len` + elements, i.e. the number of pfor iterations. At the end, the i'th + element of each TensorArray will contain the output computed by the + i'th iteration of pfor. Note that elements can be written into these + tensors arrays in any order, depending on when the corresponding pfor + iteration is done. + If the original while_loop had `k` tensors in its `loop_vars` and its body + directly captured `m` tensors, the `args` will contain `2 * k + m` values. + + In each iteration, the while_loop body recomputes the condition for all + active pfor iterations to see which of them are now done. It then partitions + all the inputs and passes them along to the converted body. Values for all + the iterations that are done are written to TensorArrays indexed by the pfor + iteration number. When all iterations are done, the TensorArrays are stacked + to get the final value. + + Args: + pfor_input: A PForInput object corresponding to the output of any Exit + node from this while loop. + + Returns: + List of converted outputs. + """ + # Create init_values that will be passed to the while_loop. + init_values, inputs_stacked, shape_invariants = self._create_init_values( + pfor_input) + # Note that we use a list as a hack since we need the nested function body + # to set the value of cond_is_stacked. python2.x doesn't support nonlocal + # variables. 
+ cond_is_stacked = [None] + + def cond(not_all_done, *_): + return not_all_done + + def body(not_all_done, indices, *args): + # See documentation for __call__ for the structure of *args. + num_enters = len(self._enters) + inputs = args[:num_enters] + output_tas = args[num_enters:] + # TODO(agarwal): see which outputs have consumers and only populate the + # TensorArrays corresponding to those. Or do those paths get trimmed out + # from inside the while_loop body? + assert len(inputs) >= len(output_tas) + assert len(inputs) == len(inputs_stacked) + + # Convert condition + with ops.name_scope("while_cond"): + # Note that we set cond_stacked to True here. At this point we don't + # know if it could be loop invariant, hence the conservative value is + # to assume stacked. + cond_pfor = self._init_pfor( + pfor_input.pfor, + indices, + cond_stacked=True, + inputs=inputs, + inputs_stacked=inputs_stacked) + conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) + cond_is_stacked[0] = cond_stacked + + # Recompute the new condition, write outputs of done iterations, and + # partition the inputs if needed. + if not cond_stacked: + (not_all_done, new_indices, new_inputs, + new_output_tas) = self._process_cond_unstacked(conditions, indices, + inputs, output_tas) + else: + (not_all_done, new_indices, new_inputs, + new_output_tas) = self._process_cond_stacked(conditions, indices, + inputs, inputs_stacked, + output_tas) + + # Convert body + with ops.name_scope("while_body"): + # Compute the outputs from the body. + new_outputs = self._process_body(pfor_input, inputs_stacked, + new_indices, cond_stacked, new_inputs, + not_all_done) + + # Note that the first num_outputs new values of inputs are computed using + # the body. Rest of them were direct Enters into the condition/body and + # the partitioning done earlier is sufficient to give the new value. 
+ num_outputs = len(self._outputs) + new_args = ([not_all_done, new_indices] + new_outputs + + list(new_inputs[num_outputs:]) + new_output_tas) + return tuple(new_args) + + while_outputs = while_loop.while_loop( + cond, body, init_values, shape_invariants=shape_invariants) + output_tas = while_outputs[-len(self._outputs):] + outputs = [] + assert cond_is_stacked[0] is not None + for inp_stacked, ta in zip(inputs_stacked, output_tas): + if cond_is_stacked[0]: + outputs.append(wrap(ta.stack(), True)) + else: + # Note that if while_loop condition is unstacked, all iterations exit at + # the same time and we wrote those outputs in index 0 of the tensor + # array. + outputs.append(wrap(ta.read(0), inp_stacked)) + return outputs + + +class ConversionNotImplementedError(Exception): + pass + + +class _PforInput: + """Input object passed to registered pfor converters.""" + + __slots__ = ["pfor", "_op", "_inputs"] + + def __init__(self, pfor: "PFor", op: ops.Operation, inputs): + """Creates a _PforInput object. + + Args: + pfor: PFor converter object. + op: the Operation object that is being converted. + inputs: list of WrappedTensor objects representing converted values of the + inputs of `op`. + """ + self.pfor = pfor + self._op = op + self._inputs = inputs + + def stack_inputs(self, stack_indices=None, tile_variants=False): + """Stacks unstacked inputs at `stack_indices`. + + Args: + stack_indices: indices of inputs at which stacking is done. If None, + stacking is done at all indices. + tile_variants: If True, affected indices which have a variant dtype will + be tiled after this operation to match the expected shape of a + vectorized tensor. Variants generally need to be un-tiled when they are + inputs to operations and tiled when returned. 
+ """ + if stack_indices is None: + stack_indices = range(len(self._inputs)) + length = self.pfor.loop_len_vector + for i in stack_indices: + inp = self._inputs[i] + is_variant = inp.t.dtype == dtypes.variant + if not inp.is_stacked: + self._inputs[i] = _stack(inp.t, length) + if tile_variants and is_variant: + self._inputs[i] = wrap( + _tile_variant_with_length(self._inputs[i].t, length), True) + elif not tile_variants and is_variant: + self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True) + + def expanddim_inputs_for_broadcast(self): + """Reshapes stacked inputs to prepare them for broadcast. + + Since stacked inputs have an extra leading dimension, automatic broadcasting + rules could incorrectly try to expand dimensions before that leading + dimension. To avoid that, we reshape these stacked inputs to the maximum + rank they will need to be broadcasted to. + + IMPORTANT: This function is heavily optimized for statically known ranks + because it's on the critical path of some huge training graphs. 
+ """ + if len(self._inputs) < 2: + return + + ranks = [ + _rank(inp.t) if inp.is_stacked else (_rank(inp.t) + 1) + for inp in self._inputs + ] + if all(isinstance(rank, int) for rank in ranks): + max_rank = max(ranks) + else: + max_rank = functools.reduce(math_ops.maximum, ranks) + + for i, inp in enumerate(self._inputs): + if not inp.is_stacked: + continue + if isinstance(max_rank, int) and ranks[i] == max_rank: + continue + self._inputs[i] = wrap(_expand_dims(inp.t, 1, max_rank - ranks[i]), True) + + @property + def inputs(self): + return self._inputs + + @property + def num_inputs(self): + return len(self._inputs) + + def input(self, index): + assert len(self._inputs) > index, (index, self._inputs) + return self._inputs[index] + + def stacked_input(self, index): + t, is_stacked, _ = self.input(index) + if not is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ConversionNotImplementedError( + f"Input {input_name} of op '{op_type}' expected to be not loop " + "invariant.") + return t + + def unstacked_input(self, index): + t, is_stacked, _ = self.input(index) + if is_stacked: + op_type = self.op_type + op_def = getattr(self._op, "op_def", None) + if op_def is None: + input_name = "at index %d" % index + else: + input_name = "\"%s\"" % op_def.input_arg[index].name + raise ConversionNotImplementedError( + f"Input {input_name} of op '{op_type}' expected to be loop " + "invariant.") + return t + + @property + def op(self) -> ops.Operation: + return self._op + + @property + def op_type(self): + return self._op.type + + def get_attr(self, attr): + return self._op.get_attr(attr) + + @property + def outputs(self): + return self._op.outputs + + def output(self, index): + assert index < len(self._op.outputs) + return self._op.outputs[index] + + +_pfor_converter_registry = {} + + +class RegisterPFor: + """Utility to 
register converters for pfor. + + Usage: + @RegisterPFor(foo_op_type) + def _foo_converter(pfor_input: _PforInput): + ... + + The above will register conversion function `_foo_converter` for handling + conversion of `foo_op_type`. These converters are called during vectorization + of a `pfor` loop body. For each operation node in this loop body, + the vectorization process will call the converter corresponding to the + operation type of the node. + + During conversion, the registered function will be called with a single + argument `pfor_input`, of type `PForInput`, which will contain state needed + for the conversion. When the converter is called for a node, all its inputs + should already have been converted and these converted values are stored in + `pfor_input.inputs`. This registered function should output a list of + WrappedTensor objects with the same length as the number of outputs of the + node being converted. If the node had zero outputs, then it should return an + ops.Operation object. These new sets of nodes should implement the + functionality of running that operation for the number of iterations specified + by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each + iteration are picked from `pfor_inputs.inputs()`. + + One tricky aspect of the conversion process is keeping track of, and + leveraging loop invariance of computation. Each converted input is a + WrappedTensor which indicates whether the input was loop invariant or not. If + the converted value is loop invariant, its rank should match the rank of the + corresponding tensor in the loop body, else its rank is larger by 1. The + converter should look at the loop invariance of the inputs and generate new + nodes based on that. Note that the converter will not be called if all inputs + are loop invariant and the operation is not stateful. The converter should + determine if its own output is loop invariant and `wrap` its output + accordingly. 
+ + Example: + + Here, the converter is trying to convert a Reshape node in the loop body. This + node will have two inputs: the tensor to reshape, and the new shape. The + example here only handles the case where the shape is loop invariant. + + @RegisterPFor("Reshape") + def _convert_reshape(pfor_input: _PforInput): + # We assume that input is not loop invariant. Call to `stacked_input` + # asserts that and returns the converted value. This value will have a rank + # larger by 1 compared to the rank of the input in the loop body. + t = pfor_input.stacked_input(0) + + # We assume that shape input is loop invariant. Call to `unstacked_input` + # asserts that and returns the converted value. + shape = pfor_input.unstacked_input(1) + + # We compute `new_shape` by prepending the number of iterations to the + # original shape. + new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], + axis=0) + + # The vectorized output involves reshaping the converted input `t` using + # `new_shape`. + new_output = array_ops.reshape(t, new_shape) + + # The converted output is marked as not loop invariant using the call to + # wrap. + return wrap(new_output, True) + """ + + def __init__(self, op_type): + """Creates an object to register a converter for op with type `op_type`.""" + self.op_type = op_type + + def __call__(self, converter): + name = self.op_type + assert name not in _pfor_converter_registry, "Re-registering %s " % name + _pfor_converter_registry[name] = converter + return converter + + +class RegisterPForWithArgs(RegisterPFor): + """Utility to register converters for pfor. + + Usage: + @RegisteRPFor(foo_op_type, foo=value, ....) + def _foo_converter(pfor_input, foo=None, ....): + ... + + See RegisterPFor for details on the conversion function. + `RegisterPForWithArgs` allows binding extra arguments to the + conversion function at registration time. 
+ """ + + def __init__(self, op_type, *args, **kw_args): + super(RegisterPForWithArgs, self).__init__(op_type) + self._args = args + self._kw_args = kw_args + + def __call__(self, converter): + + def _f(pfor_input: _PforInput): + return converter(pfor_input, self.op_type, *self._args, **self._kw_args) + + super(RegisterPForWithArgs, self).__call__(_f) + return converter + + +# TODO(agarwal): call raw_ops instead of calling these low level routines. +def _create_op(op_type, inputs, op_dtypes, attrs=None): + """Utility to create an op.""" + op = ops.get_default_graph().create_op( + op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) + flat_attrs = [] + # The tape expects an alternating flat list of names and attribute values. + for a in attrs: + flat_attrs.append(str(a)) + flat_attrs.append(op.get_attr(str(a))) + execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:]) + return op + + +WrappedTensor = collections.namedtuple("WrappedTensor", + ["t", "is_stacked", "is_sparse_stacked"]) +"""Wrapper around the result of a Tensor conversion. + +The additional fields are useful for keeping track of the conversion state as +data flows through the ops in the loop body. For every op whose output is a +Tensor, its converter should return either a WrappedTensor or a list of +WrappedTensors. + +Args: + t: The converted tensor + is_stacked: True if the tensor is stacked, i.e. represents the results of all + the iterations of the loop, where each row i of the tensor corresponds to + that op's output on iteration i of the loop. False if the tensor is not + stacked, i.e. represents the result of the op on of a single iteration of + the loop, where the result does not vary between iterations. + is_sparse_stacked: True if the tensor corresponds to a component tensor + (indices, values, or dense_shape) of a sparse tensor, and has been logically + stacked via a sparse conversion. 
+""" + + +def wrap(tensor, is_stacked=True, is_sparse_stacked=False): + """Helper to create a WrappedTensor object.""" + assert isinstance(is_stacked, bool) + assert isinstance(is_sparse_stacked, bool) + assert isinstance(tensor, tensor_lib.Tensor), type(tensor) + assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " + "stacked via a sparse " + "conversion, it must also be " + "stacked.") + return WrappedTensor(tensor, is_stacked, is_sparse_stacked) + + +def _wrap_and_tile_variants(tensor, length): + if tensor.dtype == dtypes.variant: + tensor = _tile_variant_with_length(tensor, length) + return wrap(tensor) + + +def _fallback_converter(pfor_input: _PforInput, root_cause="", warn=False): + msg = ("Using a while_loop for converting " + f"{pfor_input.op_type} cause {root_cause}") + if warn: + logging.warning(msg) + else: + logging.debug(msg) + output_dtypes = [x.dtype for x in pfor_input.outputs] + iter_vec = pfor_input.pfor.loop_len_vector + # Use constant value if available, so that output shapes are static. + iter_vec_value = tensor_util.constant_value(iter_vec) + if iter_vec_value is not None: + iters = iter_vec_value[0].item() + else: + iters = iter_vec[0] + + def while_body(i, *ta_list): + """Body of while loop.""" + inputs = [ + x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs + ] + op_outputs = _create_op( + pfor_input.op_type, + inputs, + output_dtypes, + attrs=pfor_input.op.node_def.attr).outputs + + outputs = [] + # TODO(agarwal): Add tf.debugging asserts to check that the shapes across + # the different iterations are the same. 
+ for out, ta in zip(op_outputs, ta_list): + assert isinstance(out, tensor_lib.Tensor) + outputs.append(ta.write(i, out)) + return tuple([i + 1] + outputs) + + ta_list = while_loop.while_loop( + lambda i, *ta: i < iters, while_body, [0] + + [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes + ])[1:] + return tuple([wrap(ta.stack(), True) for ta in ta_list]) + + +class PForConfig: + """A configuration object used to communicate with loop body function.""" + + def __init__(self): + # This may be set to the number of iterations. + self._maybe_iters = None + # Map from reduction node, created by `reduce`, to the bundle of reduction + # function and arguments. + self._reduce_map = {} + + def _has_reductions(self): + """True if some reductions where performed by loop body.""" + return len(self._reduce_map) + + def _set_iters(self, iters): + """Set number of pfor iterations.""" + if isinstance(iters, tensor_lib.Tensor): + iters = tensor_util.constant_value(iters) + self._maybe_iters = iters + + def reduce(self, fn, *args): + """Performs reduction `fn` on `args` vectorized across pfor iterations. + + Note that `fn` is traced once inside the loop function context. Hence any + captures or side-effects will happen in that context. Call to the traced + version of `fn` happens during the construction of the vectorized code. + + Note that this currently may not work inside a control flow construct. + Args: + fn: a reduction function. It will be called with arguments that have the + same structure as *args but with individual values whose rank may be + higher by 1 since they represent loop invariant vectorized versions of + the corresponding Tensors in *args. + *args: unvectorized Tensors. + + Returns: + The result of running `fn` on the vectorized versions of `*args`. These + outputs will be available as loop invariant values to all the iterations. + """ + assert not context.executing_eagerly() + # Creates a concrete function that will be used for reduction. 
+ tensor_specs = [] + for arg in args: + if not isinstance(arg, tensor_lib.Tensor): + raise ValueError(f"Got a non-Tensor argument {arg} in reduce.") + batched_shape = tensor_shape.TensorShape([self._maybe_iters + ]).concatenate(arg.shape) + tensor_specs.append( + tensor_lib.TensorSpec(shape=batched_shape, dtype=arg.dtype)) + concrete_function = def_function.function(fn).get_concrete_function( + *tensor_specs) + + # Creates PlaceholderWithDefault and IdentityN nodes corresponding the + # reduction. + pl_outputs = [] + with ops.control_dependencies(args): + for output in concrete_function.outputs: + if not isinstance(output, tensor_lib.Tensor): + raise ValueError(f"Got a non-Tensor output {output} while running " + "reduce.") + # Note that we use placeholder_with_default just to make XLA happy since + # it does not like placeholder ops. + if output.shape.is_fully_defined(): + dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype) + pl_outputs.append( + array_ops.placeholder_with_default(dummy, shape=output.shape)) + else: + # TODO(agarwal): support case when under XLA and output.shape is not + # fully defined. + pl_outputs.append( + array_ops.placeholder(output.dtype, shape=output.shape)) + + reduction_op = array_ops.identity_n(pl_outputs)[0].op + self._reduce_map[reduction_op] = (concrete_function, args) + if len(reduction_op.outputs) == 1: + return reduction_op.outputs[0] + else: + return tuple(reduction_op.outputs) + + # TODO(agarwal): handle reductions inside control flow constructs. + def reduce_concat(self, x): + """Performs a concat reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has rank one higher than `x`. The value is the vectorized + version of `x`, i.e. stacking the value of `x` across different pfor + iterations. 
+ """ + return self.reduce(lambda y: y, x) + + def reduce_mean(self, x): + """Performs a mean reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has same rank as `x`. The value is the mean of the values + of `x` across the pfor iterations. + """ + return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x) + + def reduce_sum(self, x): + """Performs a sum reduction on `x` across pfor iterations. + + Note that this currently may not work inside a control flow construct. + Args: + x: an unvectorized Tensor. + + Returns: + A Tensor that has same rank as `x`. The value is the sum of the values + of `x` across the pfor iterations. + """ + return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x) + + def _lookup_reduction(self, t): + """Lookups Tensor `t` in the reduction maps.""" + assert isinstance(t, tensor_lib.Tensor), t + return self._reduce_map.get(t.op) + + +class PFor: + """Implementation of rewrite of parallel-for loops. + + This class takes a DAG or a set of DAGs representing the body of a + parallel-for loop, and adds new operations to the graph that implements + functionality equivalent to running that loop body for a specified number of + iterations. This new set of nodes may or may not use a tensorflow loop + construct. + + The process of conversion does not delete or change any existing operations. + It only adds operations that efficiently implement the equivalent + functionality. We refer to the added ops as "converted ops". + + The conversion process uses a simple greedy heuristic. It walks the loop body + and tries to express the functionality of running each node in a loop with a + new set of nodes. When converting an op several cases are possible: + - The op is not inside the loop body. Hence it can be used as is. + - The op does not depend on the iteration number and is stateless. 
In this + case, it can be used as is. + - The op is not stateful, and depends on iteration number only through control + dependencies. In this case, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is not stateful, and all its inputs are loop invariant. In this + case, similar to above, we can create a single op with same inputs and + attributes, but with "converted" control dependencies. + - The op is stateful or at least one of the inputs is not loop invariant. In + this case, we run the registered converter for that op to create a set of + converted ops. All nodes in the set will have converted control dependencies + corresponding to control dependencies of the original op. If the op returned + multiple outputs, "converted outputs" could be produced by different ops in + this set. + """ + + def __init__(self, + loop_var, + loop_len, + pfor_ops, + fallback_to_while_loop, + all_indices=None, + all_indices_partitioned=False, + pfor_config=None, + warn=False): + """Creates an object to rewrite a parallel-for loop. + + Args: + loop_var: Tensor output of a Placeholder operation. The value should + be an int32 scalar representing the loop iteration number. + loop_len: A scalar or scalar Tensor representing the number of iterations + the loop is run for. + pfor_ops: List of all ops inside the loop body. + fallback_to_while_loop: If True, on failure to vectorize an op, a while + loop is used to sequentially execute that op. + all_indices: If not None, an int32 vector with size `loop_len` + representing the iteration ids that are still active. These values + should be unique and sorted. However they may not be contiguous. This is + typically the case when inside a control flow construct which has + partitioned the indices of the iterations that are being converted. 
+ all_indices_partitioned: If True, this object is being constructed from a + control flow construct where not all the pfor iterations are guaranteed + to be active. + pfor_config: PForConfig object used while constructing the loop body. + warn: Whether or not to warn on while loop conversions. + """ + assert isinstance(loop_var, tensor_lib.Tensor) + assert loop_var.op.type == "PlaceholderWithDefault" + self._loop_var = loop_var + loop_len_value = tensor_util.constant_value(loop_len) + if loop_len_value is not None: + loop_len = loop_len_value + self._loop_len_vector = ops.convert_to_tensor([loop_len]) + else: + self._loop_len_vector = array_ops.reshape(loop_len, [1]) + self._all_indices_partitioned = all_indices_partitioned + if all_indices_partitioned: + assert all_indices is not None + self.all_indices = ( + math_ops.range(loop_len) if all_indices is None else all_indices) + + self._conversion_map = object_identity.ObjectIdentityDictionary() + self._conversion_map[loop_var] = wrap(self.all_indices, True) + self._pfor_ops = set(pfor_ops) + self._pfor_op_ids = set(x._id for x in pfor_ops) + self._fallback_to_while_loop = fallback_to_while_loop + self._warn = warn + self._pfor_config = pfor_config + + def op_is_inside_loop(self, op): + """True if op was created inside the pfor loop body.""" + assert isinstance(op, ops.Operation) + # Note that we use self._pfor_op_ids for the check and not self._pfor_ops + # since it appears there tensorflow API could return different python + # objects representing the same Operation node. + return op._id in self._pfor_op_ids + + def _convert_sparse(self, y): + """Returns the converted value corresponding to SparseTensor y. 
+ + For SparseTensors, instead of stacking the component tensors separately, + resulting in component tensors with shapes (N, m, rank), (N, m), and (N, + rank) respectively for indices, values, and dense_shape (where N is the loop + length and m is the number of sparse tensor values per loop iter), we want + to logically stack the SparseTensors, to create a SparseTensor whose + components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) + respectively. + + Here, we try to get the conversion of each component tensor. + If the tensors are stacked via a sparse conversion, return the resulting + SparseTensor composed of the converted components. Otherwise, the component + tensors are either unstacked or stacked naively. In the latter case, we + unstack the component tensors to reform loop_len SparseTensor elements, + then correctly batch them. + + The unstacked tensors must have the same rank. Each dimension of each + SparseTensor will expand to be the largest among all SparseTensor elements + for that dimension. For example, if there are N SparseTensors of rank 3 + being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), + the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). + + Args: + y: A tf.sparse.SparseTensor. + + Returns: + A tf.sparse.SparseTensor that is the converted value corresponding to y. + """ + outputs = [ + self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) + ] + assert all(isinstance(o, WrappedTensor) for o in outputs) + + if all(w.is_sparse_stacked for w in outputs): + return sparse_tensor.SparseTensor(*[w.t for w in outputs]) + + assert not any(w.is_sparse_stacked for w in outputs), ( + "Error converting SparseTensor. All components should be logically " + "stacked, or none.") + + # If component tensors were not sparsely stacked, they are either unstacked + # or stacked without knowledge that they are components of sparse tensors. + # In this case, we have to restack them. 
+ return self._restack_sparse_tensor_logically( + *[self._unwrap_or_tile(w) for w in outputs]) + + def _restack_sparse_tensor_logically(self, indices, values, shape): + sparse_tensor_rank = indices.get_shape().dims[-1].value + if sparse_tensor_rank is not None: + sparse_tensor_rank += 1 + + def fn(args): + res = gen_sparse_ops.serialize_sparse( + args[0], args[1], args[2], out_type=dtypes.variant) + return res + + # Applies a map function to the component tensors to serialize each + # sparse tensor element and batch them all, then deserializes the batch. + # TODO(rachelim): Try to do this without map_fn -- add the right offsets + # to shape and indices tensors instead. + result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant) + return sparse_ops.deserialize_sparse( + result, dtype=values.dtype, rank=sparse_tensor_rank) + + def _unwrap_or_tile(self, wrapped_tensor): + """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" + output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked + if is_stacked: + return output + else: + return _stack(output, self._loop_len_vector).t + + def convert(self, y): + """Returns the converted value corresponding to y. + + Args: + y: A Tensor or a ops.Operation object. If latter, y should not have + any outputs. + + Returns: + If y does not need to be converted, it returns y as is. Else it returns + the "converted value" corresponding to y. 
+ """ + if y is None: + return None + if isinstance(y, sparse_tensor.SparseTensor): + return self._convert_sparse(y) + assert isinstance(y, (tensor_lib.Tensor, ops.Operation)), y + output = self._convert_helper(y) + if isinstance(output, WrappedTensor): + assert isinstance(y, tensor_lib.Tensor) + return self._unwrap_or_tile(output) + else: + assert isinstance(y, ops.Operation) + assert not y.outputs + assert isinstance(output, ops.Operation) + return output + + def _was_converted(self, t): + """True if t is not a conversion of itself.""" + converted_t = self._conversion_map[t] + return converted_t.t is not t + + def _add_conversion(self, old_output, new_output): + assert isinstance( + old_output, (tensor_lib.Tensor, ops.Operation)), old_output + assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output + self._conversion_map[old_output] = new_output + + def _convert_reduction(self, y): + # Handle reductions. + if self._pfor_config is None or isinstance(y, ops.Operation): + return None + reduction = self._pfor_config._lookup_reduction(y) + if reduction is None: + return None + (reduction_fn, reduction_args) = reduction + batched_args = [] + for reduction_arg in reduction_args: + assert isinstance(reduction_arg, tensor_lib.Tensor), reduction_arg + # Tensor being reduced should already be converted due to a control + # dependency on the created placeholder. + # Note that in cases where reduction_arg is in an outer context, one + # needs to locate the corresponding Enter node and use that to lookup + # the conversion. + # TODO(agarwal): handle reductions inside control flow constructs. + assert reduction_arg in self._conversion_map, ( + "Unable to handle reduction of %s, possibly as it was used " + "inside a control flow construct. Note that reductions across " + "pfor iterations are currently not supported inside control flow " + "constructs." 
% reduction_arg) + batched_arg = self._conversion_map[reduction_arg] + batched_args.append(self._unwrap_or_tile(batched_arg)) + outputs = reduction_fn(*batched_args) + return [wrap(output, False) for output in nest.flatten(outputs)] + + def _convert_helper(self, op_or_tensor): + stack = collections.deque([op_or_tensor]) + while stack: + y = stack[0] + if y in self._conversion_map: + assert isinstance(self._conversion_map[y], + (WrappedTensor, ops.Operation)) + stack.popleft() + continue + if isinstance(y, ops.Operation): + assert not y.outputs, ( + "We only support converting Operation objects with no outputs. " + "Got %s", y) + y_op = y + else: + assert isinstance(y, tensor_lib.Tensor), y + y_op = y.op + + is_while_loop = y_op.type == "Exit" + if is_while_loop: + while_op = WhileOp( + y, pfor_ops=self._pfor_ops, + fallback_to_while_loop=self.fallback_to_while_loop, + pfor_config=self._pfor_config) + is_inside_loop = while_op.is_inside_loop + # If all nodes in the while_loop graph were created inside the pfor, we + # treat the whole loop subgraph as a single op (y_op) and try to convert + # it. For while_loops that are created completely or partially outside, + # we treat them as external and should be able to simply return the Exit + # node output as is without needing any conversion. Note that for + # while_loops that are partially constructed inside, we assume they will + # be loop invariant. If that is not the case, it will create runtime + # errors since the converted graph would depend on the self._loop_var + # placeholder. + if is_inside_loop: + y_op = while_op + else: + is_inside_loop = self.op_is_inside_loop(y_op) + + # If this op was not created inside the loop body, we will return as is. + # 1. Convert inputs and control inputs. 
+ + def _add_to_stack(x): + if x not in self._conversion_map: + stack.appendleft(x) + return True + else: + return False + + if is_inside_loop: + added_to_stack = False + for inp in y_op.inputs: + added_to_stack |= _add_to_stack(inp) + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + added_to_stack |= _add_to_stack(t) + else: + added_to_stack |= _add_to_stack(cinp) + if added_to_stack: + continue + + converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] + some_input_converted = any(self._was_converted(x) for x in y_op.inputs) + some_input_stacked = any(x.is_stacked for x in converted_inputs) + + converted_control_ops = set() + some_control_input_converted = False + for cinp in y_op.control_inputs: + if cinp.outputs: + for t in cinp.outputs: + converted_t = self._conversion_map[t] + if self._was_converted(t): + some_control_input_converted = True + converted_control_ops.add(converted_t.t.op) + else: + converted_cinp = self._conversion_map[cinp] + assert isinstance(converted_cinp, ops.Operation) + if converted_cinp != cinp: + some_control_input_converted = True + converted_control_ops.add(converted_cinp) + converted_control_ops = list(converted_control_ops) + is_stateful = _is_stateful_pfor_op(y_op) + else: + converted_inputs = [] + converted_control_ops = [] + logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, + converted_inputs, converted_control_ops) + + # 2. Convert y_op + # If converting a while_loop, we let the while_loop convertor deal with + # putting the control dependencies appropriately. + control_dependencies = [] if is_while_loop else converted_control_ops + with ops.control_dependencies(control_dependencies), ops.name_scope( + y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op): + # Op is a placeholder for a reduction. 
+ reduce_output = self._convert_reduction(y) + if reduce_output is not None: + new_outputs = reduce_output + # None of the inputs and control inputs were converted. + elif ((not is_inside_loop or + (not is_stateful and not some_input_converted and + not some_control_input_converted)) and + y.graph == ops.get_default_graph()): + if y is y_op: + assert not isinstance(y_op, WhileOp) + new_outputs = y_op + else: + new_outputs = [wrap(x, False) for x in y_op.outputs] + elif not (is_stateful or is_while_loop or some_input_stacked): + # All inputs are unstacked or unconverted but some control inputs are + # converted. + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs)) + new_op = _create_op(y_op.type, [x.t for x in converted_inputs], + [x.dtype for x in y_op.outputs], + y_op.node_def.attr) + if y is y_op: + new_outputs = new_op + else: + new_outputs = [] + for old_output, new_output in zip(y_op.outputs, new_op.outputs): + handle_data_util.copy_handle_data(old_output, new_output) + new_outputs.append(wrap(new_output, False)) + else: + # Either some inputs are not loop invariant or op is stateful. + if hasattr(y_op, "pfor_converter"): + converter = y_op.pfor_converter + else: + converter = _pfor_converter_registry.get(y_op.type, None) + if converter is None: + root_cause = "there is no registered converter for this op." 
+ has_variant_outputs = any(x.dtype == dtypes.variant for x in + y_op.outputs) + has_vectorized_variant_inputs = any( + _is_variant_with_internal_stacking(x) for x in + y_op.inputs) + if (self._fallback_to_while_loop and not has_variant_outputs + and not has_vectorized_variant_inputs): + converter = functools.partial( + _fallback_converter, root_cause=root_cause, warn=self._warn) + else: + message = (f"No pfor vectorization defined for {y_op.type}\n" + f"{y_op}\n inputs: {converted_inputs}.") + if not self._fallback_to_while_loop: + message += ("Consider enabling the fallback_to_while_loop " + "option to pfor, which may run slower.") + raise ValueError(message) + # TODO(rachelim): Handle the case where some inputs are sparsely + # stacked. We should only call the converter if it supports handling + # those inputs. + pfor_inputs = _PforInput(self, y_op, converted_inputs) + try: + try: + new_outputs = converter(pfor_inputs) + except ConversionNotImplementedError as e: + has_vectorized_variant_inputs = any( + _is_variant_with_internal_stacking(x) for x in + y_op.inputs) + if (self._fallback_to_while_loop + and not has_vectorized_variant_inputs): + new_outputs = _fallback_converter( + pfor_inputs, root_cause=str(e)) + else: + raise ValueError(str(e)).with_traceback(sys.exc_info()[2]) + except Exception as e: # pylint: disable=broad-except + logging.error( + f"Got error while pfor was converting op {y_op} with inputs " + f"{y_op.inputs[:]}\n, converted inputs {pfor_inputs.inputs}\n" + f"Here are the pfor conversion stack traces: {e}") + original_op = y_op + while isinstance(original_op, ops.Operation): + logging.error( + "%s\ncreated at:\n %s", original_op, + " ".join(traceback.format_list(original_op.traceback))) + original_op = original_op._original_op + raise + + if isinstance(new_outputs, WrappedTensor): + new_outputs = [new_outputs] + assert isinstance(new_outputs, + (list, tuple, ops.Operation)), new_outputs + logging.vlog(2, f"converted {y_op} {new_outputs}") + + 
# Insert into self._conversion_map + if y is y_op: + assert isinstance(new_outputs, ops.Operation) + self._add_conversion(y_op, new_outputs) + else: + assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, + new_outputs) + for old_output, new_output in zip(y_op.outputs, new_outputs): + assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) + assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) + # Set shape for converted output. + output_shape = old_output.shape + if not new_output.is_sparse_stacked: + if new_output.is_stacked: + loop_len = tensor_util.constant_value(self.loop_len_vector) + if loop_len is None: + batch_dim = tensor_shape.TensorShape([None]) + else: + batch_dim = tensor_shape.TensorShape(loop_len) + output_shape = batch_dim.concatenate(output_shape) + if _is_variant_with_internal_stacking(new_output.t): + new_output.t.set_shape([]) + else: + new_output.t.set_shape(output_shape) + self._add_conversion(old_output, new_output) + stack.popleft() + + return self._conversion_map[op_or_tensor] + + @property + def loop_len_vector(self): + """Returns a single element vector whose value is number of iterations.""" + return self._loop_len_vector + + @property + def loop_var(self): + """Returns placeholder loop variable.""" + return self._loop_var + + @property + def pfor_ops(self): + return self._pfor_ops + + @property + def pfor_config(self): + return self._pfor_config + + @property + def all_indices_partitioned(self): + """all_indices_partitioned property. + + Returns: + True if we are inside a control flow construct and not all pfor iterations + may be active. + """ + return self._all_indices_partitioned + + @property + def fallback_to_while_loop(self): + return self._fallback_to_while_loop + + +# The code below defines converters for different operations. Please see comment +# for RegisterPFor to see how converters should be defined. 
# image_ops


@RegisterPFor("AdjustContrastv2")
def _convert_adjust_contrastv2(pfor_input: _PforInput):
  """Vectorizes AdjustContrastv2 by passing the stacked images straight through."""
  images = pfor_input.stacked_input(0)
  contrast_factor = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True)


@RegisterPFor("AdjustHue")
def _convert_adjust_hue(pfor_input: _PforInput):
  """Vectorizes AdjustHue by passing the stacked images straight through."""
  images = pfor_input.stacked_input(0)
  delta = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_hue(images, delta), True)


@RegisterPFor("AdjustSaturation")
def _convert_adjust_saturation(pfor_input: _PforInput):
  """Vectorizes AdjustSaturation by passing the stacked images straight through."""
  images = pfor_input.stacked_input(0)
  scale = pfor_input.unstacked_input(1)
  return wrap(gen_image_ops.adjust_saturation(images, scale), True)


# nn_ops


def _flatten_first_two_dims(x):
  """Merges first two dimensions."""
  old_shape = array_ops.shape(x)
  first_dim = constant_op.constant([-1], dtype=old_shape.dtype)
  new_shape = array_ops.concat([first_dim, old_shape[2:]], axis=0)
  return array_ops.reshape(x, new_shape)


def _unflatten_first_dim(x, first_dim):
  """Splits first dimension into [first_dim, -1]."""
  old_shape = array_ops.shape(x)
  first_dim = math_ops.cast(first_dim, old_shape.dtype)
  second_dim = constant_op.constant([-1], dtype=old_shape.dtype)
  new_shape = array_ops.concat([first_dim, second_dim, old_shape[1:]], axis=0)
  return array_ops.reshape(x, new_shape)


def _inputs_with_flattening(pfor_input: _PforInput, input_indices):
  """Stacks and flattens first dim of inputs at indices `input_indices`."""
  if input_indices is None:
    input_indices = []
  pfor_input.stack_inputs(stack_indices=input_indices)
  inputs = []
  for i in range(pfor_input.num_inputs):
    if i in input_indices:
      inp = pfor_input.stacked_input(i)
      inp = _flatten_first_two_dims(inp)
    else:
      inp = pfor_input.unstacked_input(i)
    inputs.append(inp)
  return inputs


@RegisterPForWithArgs("Conv2D", dims=[0])
@RegisterPForWithArgs("DepthToSpace", dims=[0])
@RegisterPForWithArgs("AvgPool", dims=[0])
@RegisterPForWithArgs("AvgPool3D", dims=[0])
@RegisterPForWithArgs("MaxPool", dims=[0])
@RegisterPForWithArgs("MaxPoolV2", dims=[0])
@RegisterPForWithArgs("MaxPool3D", dims=[0])
@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2])
@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2])
@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1])
@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1])
@RegisterPForWithArgs("SpaceToDepth", dims=[0])
def _convert_flatten_batch(pfor_input: _PforInput, op_type, dims):
  """Vectorizes ops whose kernel already handles a batch dimension.

  Merges the pfor loop dimension into the batch dimension of the inputs at
  `dims`, runs the original kernel once, then splits the loop dimension back
  out of every output.
  """
  del op_type
  inputs = _inputs_with_flattening(pfor_input, dims)
  outputs = _create_op(
      pfor_input.op_type,
      inputs, [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr).outputs
  n = pfor_input.pfor.loop_len_vector
  outputs = [_unflatten_first_dim(x, n) for x in outputs]
  return [wrap(x, True) for x in outputs]


# Cache for _channel_flatten_input, keyed on (graph, tensor ref, data_format).
_channel_flatten_input_cache = {}


@RegisterPFor("BatchToSpaceND")
def _convert_batch_to_space_nd(pfor_input: _PforInput):
  inp = pfor_input.stacked_input(0)
  block_shape = pfor_input.unstacked_input(1)
  crops = pfor_input.unstacked_input(2)

  inp_shape = array_ops.shape(inp)
  n = math_ops.cast(pfor_input.pfor.loop_len_vector, inp_shape.dtype)
  block_shape = math_ops.cast(block_shape, inp_shape.dtype)

  # Reshape and transpose to move the vectorization axis inside the axes that
  # will move to space.
  # Reshape to 4D and transpose
  block_size = math_ops.reduce_prod(block_shape)
  neg_one = constant_op.constant(-1, dtype=inp_shape.dtype)
  new_shape = [n[0], block_size, inp_shape[1] // block_size, neg_one]
  inp = array_ops.reshape(inp, new_shape)

  inp = array_ops.transpose(inp, [1, 0, 2, 3])
  # Reshape back to merge the block, vectorization and batch dimension, and
  # restore the other dimensions.
  new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0)
  inp = array_ops.reshape(inp, new_shape)

  # Call batch_to_space and then split the new batch axis.
  output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops)
  output = _unflatten_first_dim(output, n)
  return wrap(output, True)


@RegisterPFor("SpaceToBatchND")
def _convert_space_to_batch_nd(pfor_input: _PforInput):
  inp = pfor_input.stacked_input(0)
  block_shape = pfor_input.unstacked_input(1)
  paddings = pfor_input.unstacked_input(2)

  inp_shape = array_ops.shape(inp)
  n = math_ops.cast(pfor_input.pfor.loop_len_vector, inp_shape.dtype)
  block_shape = math_ops.cast(block_shape, inp_shape.dtype)

  # Merge the loop dim into the batch dim, run the kernel once, then undo the
  # interleaving that space_to_batch_nd introduced between block and loop dims.
  inp = _flatten_first_two_dims(inp)
  output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings)
  output_shape = array_ops.shape(output)

  block_size = math_ops.reduce_prod(block_shape)
  neg_one = constant_op.constant(-1, dtype=inp_shape.dtype)
  new_shape = [block_size, n[0], neg_one]
  output = array_ops.reshape(output, new_shape)

  output = array_ops.transpose(output, [1, 0, 2])
  new_shape = array_ops.concat(
      [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0)
  output = array_ops.reshape(output, new_shape)
  return wrap(output, True)


def _channel_flatten_input(x, data_format):
  """Merge the stack dimension with the channel dimension.

  If S is pfor's stacking dimension, then,
    - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose
      should be cheap.
    - for SNHWC, we transpose to NHWSC.
  We then merge the S and C dimension.

  Args:
    x: tensor_lib.Tensor to transform.
    data_format: "NCHW" or "NHWC".

  Returns:
    A 3-element tuple with the transformed value, along with the shape for
    reshape and order for transpose required to transform back.
  """

  graph = ops.get_default_graph()
  cache_key = (graph, x.ref(), data_format)
  if cache_key not in _channel_flatten_input_cache:
    x_shape = array_ops.shape(x)
    neg_ones = constant_op.constant([-1], dtype=x_shape.dtype)
    if data_format == b"NCHW":
      order = [1, 0, 2, 3, 4]
      shape = array_ops.concat([x_shape[1:2], neg_ones, x_shape[3:]], axis=0)
      reverse_order = order
    else:
      order = [1, 2, 3, 0, 4]
      shape = array_ops.concat([x_shape[1:4], neg_ones], axis=0)
      reverse_order = [3, 0, 1, 2, 4]
    # Move S dimension next to C dimension.
    x = array_ops.transpose(x, order)
    reverse_shape = array_ops.shape(x)
    # Reshape to merge the S and C dimension.
    x = array_ops.reshape(x, shape)
    outputs = x, reverse_order, reverse_shape
    _channel_flatten_input_cache[cache_key] = outputs
  else:
    outputs = _channel_flatten_input_cache[cache_key]
  return outputs


# Note that with training=True, running FusedBatchNormV3 on individual examples
# is very different from running FusedBatchNormV3 on a batch of those examples.
# This is because, for the latter case, the operation can be considered as first
# computing the mean and variance over all the examples and then using these
# to scale all those examples. This creates a data dependency between these
# different "iterations" since the inputs to the scaling step depends on the
# statistics coming from all these inputs.
# As with other kernels, the conversion here effectively runs the kernel
# independently for each iteration, and returns outputs by stacking outputs from
# each of those iterations.
+@RegisterPFor("FusedBatchNormV3") +def _convert_fused_batch_norm(pfor_input: _PforInput): + is_training = pfor_input.get_attr("is_training") + # When BatchNorm is used with training=False, mean and variance are provided + # externally and used as is by the op. Thus, we can merge the S and N + # dimensions as we do for regular operations. + # When BatchNorm is used with training=True, mean and variance are computed + # for each channel across the batch dimension (first one). If we merge S and N + # dimensions, mean and variances will be computed over a larger set. So, we + # merge the S and C dimensions instead. + if not is_training: + # We return zeros for batch_mean and batch_variance output. Note that CPU + # and GPU seem to have different behavior for those two outputs. CPU outputs + # zero because these values are not used during inference. GPU outputs + # something, probably real means and variances. + inputs = _inputs_with_flattening(pfor_input, [0]) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + n = pfor_input.pfor.loop_len_vector + y = _unflatten_first_dim(y, n) + mean = pfor_input.unstacked_input(3) + zeros = array_ops.zeros_like(mean) + return [wrap(y, True)] + [wrap(zeros, False)] * 5 + + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + # We merge the first dimension with the "C" dimension, run FusedBatchNormV3, + # and then transpose back. + x = pfor_input.stacked_input(0) + x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) + # Note that we stack all the other inputs as well so that they are the same + # size as the new size of the channel dimension. 
+ inputs = [x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(1, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + y = outputs[0] + y = array_ops.reshape(y, reverse_shape) + y = array_ops.transpose(y, reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [y] + outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("FusedBatchNormGradV3") +def _convert_fused_batch_norm_grad(pfor_input: _PforInput): + pfor_input.stack_inputs() + data_format = pfor_input.get_attr("data_format") + y_backprop = pfor_input.stacked_input(0) + y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) + x = pfor_input.stacked_input(1) + x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) + inputs = [y_backprop, x] + [ + array_ops.reshape(pfor_input.stacked_input(i), [-1]) + for i in range(2, pfor_input.num_inputs) + ] + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + x_backprop = outputs[0] + x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) + x_backprop = array_ops.transpose(x_backprop, x_reverse_order) + n = pfor_input.pfor.loop_len_vector + outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] + outputs = [x_backprop] + outputs + return [wrap(output, True) for output in outputs] + + +@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) +@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) +@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) +def _convert_flatten_batch_shape_input( + pfor_input: _PforInput, op_type, flatten_dims, shape_dim): + del op_type + inputs = _inputs_with_flattening(pfor_input, flatten_dims) + n = pfor_input.pfor.loop_len_vector + # Adjust 
the `input_sizes` input. + ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], + dtype=n.dtype) + inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) + outputs = _create_op( + pfor_input.op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + outputs = [_unflatten_first_dim(x, n) for x in outputs] + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Conv2DBackpropFilter") +def _convert_conv2d_backprop_filter(pfor_input: _PforInput): + pfor_input.stack_inputs(stack_indices=[2]) + inputs, inputs_stacked, _ = pfor_input.input(0) + filter_sizes = pfor_input.unstacked_input(1) + grads = pfor_input.stacked_input(2) + strides = pfor_input.get_attr("strides") + padding = pfor_input.get_attr("padding") + use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") + data_format = pfor_input.get_attr("data_format") + dilations = pfor_input.get_attr("dilations") + if inputs_stacked: + # TODO(agarwal): Implement this efficiently. + logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!") + + def while_body(i, ta): + inp_i = inputs[i, ...] + grad_i = grads[i, ...] + output = nn_ops.conv2d_backprop_filter( + inp_i, + filter_sizes, + grad_i, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + return i + 1, ta.write(i, output) + + n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) + _, ta = while_loop.while_loop( + lambda i, ta: i < n, while_body, + (0, tensor_array_ops.TensorArray(inputs.dtype, n))) + output = ta.stack() + return wrap(output, True) + else: + # We merge the stack dimension with the channel dimension of the gradients + # and pretend we had a larger filter (see change to filter_sizes below). + # Once the filter backprop is computed, we reshape and transpose back + # appropriately. 
+ grads, _, _ = _channel_flatten_input(grads, data_format) + n = pfor_input.pfor.loop_len_vector + old_filter_sizes = filter_sizes + filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) + output = nn_ops.conv2d_backprop_filter( + inputs, + filter_sizes, + grads, + strides=strides, + padding=padding, + use_cudnn_on_gpu=use_cudnn_on_gpu, + data_format=data_format, + dilations=dilations) + new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) + output = array_ops.reshape(output, new_filter_shape) + output = array_ops.transpose(output, [3, 0, 1, 2, 4]) + return wrap(output, True) + + +def _flatten_with_inner_dim(x, dim, x_rank): + """Merges the first dim with the specified dim.""" + shape = array_ops.shape(x) + x = array_ops.transpose(x, + list(range(1, dim)) + [0] + list(range(dim, x_rank))) + + if dim < x_rank - 1: + new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]] + else: + new_shape_pieces = [shape[1:dim], [-1]] + new_shape = array_ops.concat(new_shape_pieces, axis=0) + return array_ops.reshape(x, new_shape) + + +def _unflatten_with_inner_dim(x, dim, x_rank, stack_size): + """Undoes _flatten_with_inner_dim.""" + shape = array_ops.shape(x) + if dim < x_rank - 1: + new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]] + else: + new_shape_pieces = [shape[:dim], [stack_size], [-1]] + new_shape = array_ops.concat(new_shape_pieces, axis=0) + x = array_ops.reshape(x, new_shape) + dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1)) + return array_ops.transpose(x, dims_permutation) + + +@RegisterPFor("DepthwiseConv2dNative") +def _convert_depthwise_conv2d_native(pfor_input: _PforInput): + # Kernel can be vectorized, so folding to batch dimension does not work. We + # instead fold into the channel dimension because it is parallel. 
+ stack_size = pfor_input.pfor.loop_len_vector[0] + data_format = pfor_input.get_attr("data_format") + c_dim = 1 if data_format == b"NCHW" else 3 + t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) + kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) + conv = _create_op( + "DepthwiseConv2dNative", [t, kernel], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs[0] + return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True) + + +@RegisterPFor("DepthwiseConv2dNativeBackpropInput") +def _convert_depthwise_conv2d_native_backprop_input(pfor_input: _PforInput): + stack_size = pfor_input.pfor.loop_len_vector[0] + input_sizes = pfor_input.unstacked_input(0) + data_format = pfor_input.get_attr("data_format") + c_dim = 1 if data_format == b"NCHW" else 3 + input_sizes_mutipliers = [ + constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size] + ] + if c_dim < 3: + input_sizes_mutipliers += [ + constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32) + ] + input_sizes *= array_ops.concat(input_sizes_mutipliers, axis=0) + kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) + out_backprop = _flatten_with_inner_dim( + pfor_input.stacked_input(2), c_dim + 1, 5) + result = _create_op( + "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs[0] + return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True) + + +@RegisterPFor("DepthwiseConv2dNativeBackpropFilter") +def _convert_depthwise_conv2d_native_backprop_filter(pfor_input: _PforInput): + stack_size = pfor_input.pfor.loop_len_vector[0] + data_format = pfor_input.get_attr("data_format") + c_dim = 1 if data_format == b"NCHW" else 3 + inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) + filter_sizes = pfor_input.unstacked_input(1) + filter_sizes_multipliers = [ + 
constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size], + constant_op.constant([1], dtype=dtypes.int32) + ] + filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0) + out_backprop = _flatten_with_inner_dim( + pfor_input.stacked_input(2), c_dim + 1, 5) + result = _create_op( + "DepthwiseConv2dNativeBackpropFilter", + [inputs, filter_sizes, out_backprop], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs[0] + return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True) + + +@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) +@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) +def _convert_softmax(pfor_input: _PforInput, op_type, op_func): + del op_type + return wrap(op_func(pfor_input.stacked_input(0)), True) + + +# array_ops + + +@RegisterPForWithArgs("Identity", array_ops.identity) +@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) +@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) +@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) +@RegisterPForWithArgs("_EagerConst", array_ops.identity) +def _convert_identity(pfor_input: _PforInput, op_type, op_func): + del op_type + return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) + + +@RegisterPFor("IdentityN") +def _convert_identity_n(pfor_input: _PforInput): + outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) + return [ + wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) + ] + + +@RegisterPFor("Reshape") +def _convert_reshape(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + shape = pfor_input.unstacked_input(1) + n = math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype) + new_shape = array_ops.concat([n, shape], axis=0) + return wrap(array_ops.reshape(t, new_shape), True) + + +@RegisterPFor("TopK") +@RegisterPFor("TopKV2") +def _convert_top_k(pfor_input: _PforInput): + outputs = _create_op( + op_type=pfor_input.op_type, + inputs=[x.t for x in 
pfor_input.inputs], + op_dtypes=[x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Fill") +def _convert_fill(pfor_input: _PforInput): + dims = pfor_input.unstacked_input(0) + value = pfor_input.stacked_input(1) + # Expand the rank of `value` + new_shape = array_ops.concat( + [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)], + axis=0) + value = array_ops.reshape(value, new_shape) + # Compute the new output shape + new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0) + # Broadcast + return wrap(array_ops.broadcast_to(value, new_dims), True) + + +@RegisterPFor("BroadcastTo") +def _convert_broadcast_to(pfor_input: _PforInput): + shape = pfor_input.unstacked_input(1) + n = math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype) + new_shape = array_ops.concat([n, shape], axis=0) + new_rank = _size(new_shape, dtypes.int32) + + t = pfor_input.stacked_input(0) + t = _expand_dims(t, 1, new_rank - _rank(t)) + return wrap(array_ops.broadcast_to(t, new_shape), True) + + +@RegisterPFor("ExpandDims") +def _convert_expanddims(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + dim = pfor_input.unstacked_input(1) + dim += math_ops.cast(dim >= 0, dim.dtype) + return wrap(array_ops.expand_dims(t, axis=dim), True) + + +@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) +@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) +def _convert_searchsorted(pfor_input: _PforInput, _, op_func): + pfor_input.stack_inputs() + sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) + values = _flatten_first_two_dims(pfor_input.stacked_input(1)) + out_type = pfor_input.get_attr("out_type") + output = op_func(sorted_inputs, values, out_type) + return wrap( + _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) + + +@RegisterPFor("MatrixBandPart") +def _convert_matrix_band_part(pfor_input: _PforInput): 
+ t = pfor_input.stacked_input(0) + num_lower = pfor_input.unstacked_input(1) + num_upper = pfor_input.unstacked_input(2) + return wrap( + array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), + True) + + +@RegisterPFor("MatrixSetDiag") +def _convert_matrix_set_diag(pfor_input: _PforInput): + pfor_input.stack_inputs() + t = pfor_input.stacked_input(0) + diag = pfor_input.stacked_input(1) + return wrap(array_ops.matrix_set_diag(t, diag), True) + + +# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. +# The input orders defined in the OpKernel and the actual python API are +# different (for compatibility with V1), so we cannot use _convert_identity. +# v2 is not compatible with v3 and is never exposed on the public API. +@RegisterPFor("MatrixDiagV2") +@RegisterPFor("MatrixDiagV3") +def _convert_matrix_diag_v2(pfor_input: _PforInput): + params = { + "diagonal": pfor_input.stacked_input(0), + "k": pfor_input.unstacked_input(1), + "num_rows": pfor_input.unstacked_input(2), + "num_cols": pfor_input.unstacked_input(3), + "padding_value": pfor_input.unstacked_input(4) + } + if pfor_input.op_type == "MatrixDiagV2": + return wrap(array_ops.matrix_diag_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_diag(**params), True) + + +@RegisterPFor("Diag") +def _convert_diag(pfor_input: _PforInput): + diag = pfor_input.stacked_input(0) + if diag.shape.ndims == 2: + # We can use matrix_diag. + return wrap(array_ops.matrix_diag(diag), True) + else: + # It is not clear if we can do better than a while loop here with existing + # kernels. 
+ return _fallback_converter(pfor_input, warn=False) + + +# See notes for MatrixDiagV2 +@RegisterPFor("MatrixDiagPartV2") +@RegisterPFor("MatrixDiagPartV3") +def _convert_matrix_diag_part_v2(pfor_input: _PforInput): + params = { + "input": pfor_input.stacked_input(0), + "k": pfor_input.unstacked_input(1), + "padding_value": pfor_input.unstacked_input(2) + } + if pfor_input.op_type == "MatrixDiagPartV2": + return wrap(array_ops.matrix_diag_part_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_diag_part(**params), True) + + +# See notes for MatrixDiagV2 +@RegisterPFor("MatrixSetDiagV2") +@RegisterPFor("MatrixSetDiagV3") +def _convert_matrix_set_diag_v2(pfor_input: _PforInput): + pfor_input.stack_inputs([0, 1]) + params = { + "input": pfor_input.stacked_input(0), + "diagonal": pfor_input.stacked_input(1), + "k": pfor_input.unstacked_input(2) + } + if pfor_input.op_type == "MatrixSetDiagV2": + return wrap(array_ops.matrix_set_diag_v2(**params), True) + params["align"] = pfor_input.get_attr("align") + return wrap(array_ops.matrix_set_diag(**params), True) + + +@RegisterPFor("DiagPart") +def _convert_diag_part(pfor_input: _PforInput): + inp = pfor_input.stacked_input(0) + if inp.shape.ndims == 3: + # We can use matrix_diag_part. + return wrap(array_ops.matrix_diag_part(inp), True) + else: + # It is not clear if we can do better than a while loop here with existing + # kernels. 
+ return _fallback_converter(pfor_input, warn=False) + + +@RegisterPFor("OneHot") +def _convert_one_hot(pfor_input: _PforInput): + indices = pfor_input.stacked_input(0) + depth = pfor_input.unstacked_input(1) + on_value = pfor_input.unstacked_input(2) + off_value = pfor_input.unstacked_input(3) + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + return wrap( + array_ops.one_hot(indices, depth, on_value, off_value, axis), True) + + +@RegisterPFor("Slice") +def _convert_slice(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + begin, begin_stacked, _ = pfor_input.input(1) + size = pfor_input.unstacked_input(2) + if not begin_stacked: + begin = array_ops.concat([[0], begin], axis=0) + size = array_ops.concat([[-1], size], axis=0) + return wrap(array_ops.slice(t, begin, size), True) + else: + # Handle negative sizes. + # + # If the `begin` entry corresponding to a negative `size` is loop-variant, + # the output would be ragged. This case is not supported. But `size` having + # some negative values and some loop-variant `begin`s is OK (and it's hard + # to tell the difference statically). + t_shape = array_ops.shape(t) + size = math_ops.cast(size, t_shape.dtype) + begin = math_ops.cast(begin, t_shape.dtype) + n = math_ops.cast(pfor_input.pfor.loop_len_vector, t_shape.dtype) + original_unstacked_shape = _stack(t_shape[1:], n).t + broadcast_size = _stack(size, n).t + result_shape = array_ops.where( + math_ops.less(broadcast_size, 0), + original_unstacked_shape - begin + broadcast_size + 1, broadcast_size) + result_shape = math_ops.cast(math_ops.reduce_max(result_shape, axis=0), + dtypes.int64) + + # Now we enumerate points in the sliced region for each pfor iteration and + # gather them. + cumsize = math_ops.cumprod(result_shape, exclusive=True, reverse=True) + result_num_elements = math_ops.reduce_prod(result_shape) + # Offsets are loop-variant. 
We first compute loop-invariant gather + # coordinates, then broadcast-add the loop-variant `begin` offsets. + result_base_coordinates = ( + math_ops.range(result_num_elements, dtype=dtypes.int64)[:, None] + // cumsize[None, :]) % result_shape[None, :] + result_coordinates = ( + begin[:, None, :] + + math_ops.cast(result_base_coordinates, begin.dtype)[None, :, :]) + result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, + batch_dims=1) + result_stacked_shape = array_ops.concat( + [math_ops.cast(pfor_input.pfor.loop_len_vector, result_shape.dtype), + result_shape], + axis=0) + return wrap(array_ops.reshape(result_flat, result_stacked_shape), True) + + +@RegisterPFor("Tile") +def _convert_tile(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + multiples = pfor_input.unstacked_input(1) + multiples = array_ops.concat([[1], multiples], 0) + return wrap(array_ops.tile(t, multiples), True) + + +@RegisterPFor("Pack") +def _convert_pack(pfor_input: _PforInput): + pfor_input.stack_inputs() + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + return wrap( + array_ops_stack.stack([x.t for x in pfor_input.inputs], axis=axis), True) + + +@RegisterPFor("Unpack") +def _convert_unpack(pfor_input: _PforInput): + value = pfor_input.stacked_input(0) + axis = pfor_input.get_attr("axis") + if axis >= 0: + axis += 1 + num = pfor_input.get_attr("num") + return [wrap(x, True) for x + in array_ops_stack.unstack(value, axis=axis, num=num)] + + +@RegisterPFor("Pad") +def _convert_pad(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + paddings = pfor_input.unstacked_input(1) + paddings = array_ops.concat([[[0, 0]], paddings], 0) + return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) + + +@RegisterPFor("PadV2") +def _convert_pad_v2(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + paddings = pfor_input.unstacked_input(1) + paddings = array_ops.concat([[[0, 0]], paddings], 0) + return wrap(array_ops.pad_v2(t, paddings, 
mode="CONSTANT"), True) + + +@RegisterPFor("Split") +def _convert_split(pfor_input: _PforInput): + split_dim = pfor_input.unstacked_input(0) + t = pfor_input.stacked_input(1) + num_split = pfor_input.get_attr("num_split") + split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) + return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] + + +@RegisterPFor("SplitV") +def _convert_split_v(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + splits = pfor_input.unstacked_input(1) + split_dim = pfor_input.unstacked_input(2) + split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) + return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] + + +@RegisterPFor("Squeeze") +def _convert_squeeze(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + squeeze_dims = pfor_input.get_attr("squeeze_dims") + squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] + return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) + + +@RegisterPFor("ReverseV2") +def _convert_reverse(pfor_input: _PforInput): + value = pfor_input.stacked_input(0) + axis = pfor_input.unstacked_input(1) + new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis) + return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True) + + +@RegisterPForWithArgs("Transpose", gen_array_ops.transpose) +@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose) +def _convert_transpose(pfor_input: _PforInput, _, op_func): + t = pfor_input.stacked_input(0) + perm = pfor_input.unstacked_input(1) + new_perm = array_ops.concat([[0], perm + 1], axis=0) + return wrap(op_func(t, new_perm), True) + + +@RegisterPFor("ZerosLike") +def _convert_zeros_like(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + shape = array_ops.shape(t)[1:] + return wrap(array_ops.zeros(shape, dtype=t.dtype), False) + + +@RegisterPFor("OnesLike") +def _convert_ones_like(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + shape = array_ops.shape(t)[1:] + 
return wrap(array_ops.ones(shape, dtype=t.dtype), False) + + +@RegisterPFor("Gather") +@RegisterPFor("GatherV2") +def _convert_gather(pfor_input: _PforInput): + param, param_stacked, _ = pfor_input.input(0) + indices, indices_stacked, _ = pfor_input.input(1) + batch_dims = pfor_input.get_attr("batch_dims") + + op_type = pfor_input.op_type + if op_type == "Gather": + validate_indices = pfor_input.get_attr("validate_indices") + axis = 0 + else: + validate_indices = None + # Assume we will never have a Tensor with rank > 2**32. + axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) + axis_value = tensor_util.constant_value(axis) + if axis_value is not None: + axis = axis_value + if indices_stacked and not param_stacked: + if indices is pfor_input.pfor.all_indices and axis == 0: + param_shape0 = tensor_shape.dimension_value(param.shape[0]) + indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) + if param_shape0 is not None and indices_shape0 == param_shape0: + # Note that with loops and conditionals, indices may not be contiguous. + # However they will be sorted and unique. So if the shape matches, then + # it must be picking up all the rows of param. + return wrap(param, True) + + if batch_dims != 0: + # Convert `batch_dims` to its positive equivalent if necessary. + batch_dims_pos = batch_dims + if batch_dims < 0: + batch_dims_pos += array_ops.rank(indices) + # In order to maintain + # indices.shape[:batch_dims] == params.shape[:batch_dims] + # with stacked indices, we move the first dimension of `indices` to the + # `batch_dims + 1`th position. The (non-batch) index dimensions will be + # inserted into the shape of `output` at the `axis` dimension, which is + # then transposed to the front (below). 
+ order = array_ops.concat([ + math_ops.range(1, batch_dims_pos + 1), + [0], + math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0) + indices = array_ops.transpose(indices, order) + + output = array_ops.gather( + param, indices, validate_indices=validate_indices, axis=axis, + batch_dims=batch_dims) + if axis != 0: + axis = smart_cond.smart_cond(axis < 0, + lambda: axis + array_ops.rank(param), + lambda: ops.convert_to_tensor(axis)) + order = array_ops.concat( + [[axis], + math_ops.range(axis), + math_ops.range(axis + 1, array_ops.rank(output))], + axis=0) + output = smart_cond.smart_cond( + math_ops.equal(axis, 0), lambda: output, + lambda: array_ops.transpose(output, order)) + return wrap(output, True) + if param_stacked: + pfor_input.stack_inputs(stack_indices=[1]) + indices = pfor_input.stacked_input(1) + if isinstance(axis, tensor_lib.Tensor): + axis = array_ops.where(axis >= 0, axis + 1, axis) + else: + axis = axis + 1 if axis >= 0 else axis + batch_dims = batch_dims + 1 if batch_dims >= 0 else batch_dims + output = array_ops.gather(param, indices, axis=axis, batch_dims=batch_dims) + return wrap(output, True) + + +@RegisterPFor("GatherNd") +def _convert_gather_nd(pfor_input: _PforInput): + # TODO(jmenick): Add support for unstacked params. 
+ pfor_input.stack_inputs(stack_indices=[1]) + params = pfor_input.stacked_input(0) + indices = pfor_input.stacked_input(1) + stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) + return wrap(stacked_result, True) + + +@RegisterPFor("ConcatV2") +def _convert_concatv2(pfor_input: _PforInput): + n = pfor_input.num_inputs + pfor_input.stack_inputs(stack_indices=range(n - 1)) + axis = pfor_input.unstacked_input(n - 1) + axis += math_ops.cast(axis >= 0, axis.dtype) + return wrap( + array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), + True) + + +@RegisterPFor("StridedSlice") +def _convert_strided_slice(pfor_input: _PforInput): + inp = pfor_input.stacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask = pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice( + inp, + begin, + end, + strides, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +@RegisterPFor("StridedSliceGrad") +def _convert_strided_slice_grad(pfor_input: _PforInput): + shape = pfor_input.unstacked_input(0) + begin = pfor_input.unstacked_input(1) + end = pfor_input.unstacked_input(2) + strides = pfor_input.unstacked_input(3) + dy = pfor_input.stacked_input(4) + begin_mask = pfor_input.get_attr("begin_mask") + end_mask = pfor_input.get_attr("end_mask") + ellipsis_mask 
= pfor_input.get_attr("ellipsis_mask") + new_axis_mask = pfor_input.get_attr("new_axis_mask") + shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") + + shape = array_ops.concat( + [math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype), shape], + axis=0) + begin = array_ops.concat([[0], begin], axis=0) + end = array_ops.concat([[0], end], axis=0) + strides = array_ops.concat([[1], strides], axis=0) + begin_mask = begin_mask << 1 | 1 + end_mask = end_mask << 1 | 1 + ellipsis_mask <<= 1 + new_axis_mask <<= 1 + shrink_axis_mask <<= 1 + return wrap( + array_ops.strided_slice_grad( + shape, + begin, + end, + strides, + dy, + begin_mask=begin_mask, + end_mask=end_mask, + ellipsis_mask=ellipsis_mask, + new_axis_mask=new_axis_mask, + shrink_axis_mask=shrink_axis_mask), True) + + +@RegisterPFor("CheckNumerics") +def _convert_check_numerics(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + message = pfor_input.get_attr("message") + return wrap(gen_array_ops.check_numerics(t, message), True) + + +@RegisterPFor("EnsureShape") +def _convert_ensure_shape(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + shape = tensor_shape.TensorShape(pfor_input.get_attr("shape")) + return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True) + + +# manip_ops + + +@RegisterPFor("Roll") +def _convert_roll(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + shift, shift_stacked, _ = pfor_input.input(1) + axis = pfor_input.unstacked_input(2) + if not shift_stacked: + return wrap(manip_ops.roll(t, shift, axis + 1), True) + else: + # `axis` and `shift` may both be vectors, with repeated axes summing the + # corresponding `shift`s. We scatter shifts into a dense array of shape + # [loop_len, num_unstacked_axes] indicating the offset for each axis. 
+ num_unstacked_axes = math_ops.cast(array_ops.rank(t), dtypes.int64) - 1 + axis = math_ops.cast(array_ops.reshape(axis, [-1]), dtypes.int64) + loop_len = math_ops.cast(pfor_input.pfor.loop_len_vector[0], dtypes.int64) + shift = math_ops.cast(array_ops.reshape(shift, [loop_len, -1]), + dtypes.int64) + axis_segment_ids = ( + math_ops.range(loop_len, dtype=dtypes.int64)[:, None] + * num_unstacked_axes + axis[None, :]) + axis_offsets = array_ops.reshape( + math_ops.unsorted_segment_sum( + data=shift, segment_ids=axis_segment_ids, + num_segments=loop_len * num_unstacked_axes), + [loop_len, num_unstacked_axes]) + + # Determine the coordinates in the input array of each result and gather + # them. + unstacked_shape = array_ops.shape(t, out_type=dtypes.int64)[1:] + cumsize = math_ops.cumprod(unstacked_shape, exclusive=True, reverse=True) + num_unstacked_elements = math_ops.reduce_prod(unstacked_shape) + result_coordinates = ( + (math_ops.range(num_unstacked_elements, + dtype=dtypes.int64)[None, :, None] + // cumsize[None, None, :] - axis_offsets[:, None, :]) + % unstacked_shape[None, None, :]) + result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, + batch_dims=1) + return wrap(array_ops.reshape(result_flat, array_ops.shape(t)), + True) + +# math_ops + + +@RegisterPFor("MatMul") +def _convert_matmul(pfor_input: _PforInput): + # TODO(agarwal): Check if tiling is faster than two transposes. 
+ a, a_stacked, _ = pfor_input.input(0) + b, b_stacked, _ = pfor_input.input(1) + tr_a = pfor_input.get_attr("transpose_a") + tr_b = pfor_input.get_attr("transpose_b") + if a_stacked and b_stacked: + output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) + return output + elif a_stacked: + if tr_a: + a = array_ops.transpose(a, [0, 2, 1]) + if a.shape.is_fully_defined(): + x, y, z = a.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(a), 3) + ] + a = array_ops.reshape(a, [x * y, z]) + prod = math_ops.matmul(a, b, transpose_b=tr_b) + return wrap(array_ops.reshape(prod, [x, y, -1]), True) + else: + assert b_stacked + if tr_b: + perm = [2, 0, 1] + b = array_ops.transpose(b, perm) + else: + # As an optimization, if one of the first two dimensions is 1, then we can + # reshape instead of transpose. + # TODO(agarwal): This check can be done inside Transpose kernel. + b_shape = array_ops.shape(b) + min_dim = math_ops.minimum(b_shape[0], b_shape[1]) + perm = array_ops.where( + math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2]) + new_shape = array_ops_stack.stack([b_shape[1], b_shape[0], b_shape[2]]) + b = array_ops.transpose(b, perm) + b = array_ops.reshape(b, new_shape) + + if b.shape.is_fully_defined(): + x, y, z = b.shape + else: + x, y, z = [ + array_ops.reshape(i, []) + for i in array_ops.split(array_ops.shape(b), 3) + ] + b = array_ops.reshape(b, [x, y * z]) + prod = math_ops.matmul(a, b, transpose_a=tr_a) + prod = array_ops.reshape(prod, [-1, y, z]) + prod = array_ops.transpose(prod, [1, 0, 2]) + return wrap(prod, True) + + +# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window +# is met. +@RegisterPFor("BatchMatMul") +def _convert_batch_mat_mul(pfor_input: _PforInput): + # TODO(agarwal): There may be a more efficient way to do this instead of + # stacking the inputs. 
+ pfor_input.stack_inputs() + x = pfor_input.stacked_input(0) + y = pfor_input.stacked_input(1) + adj_x = pfor_input.get_attr("adj_x") + adj_y = pfor_input.get_attr("adj_y") + + x = _flatten_first_two_dims(x) + y = _flatten_first_two_dims(y) + output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) + output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) + return wrap(output, True) + + +@RegisterPFor("BatchMatMulV2") +def _convert_batch_mat_mul_v2(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + x = pfor_input.input(0)[0] + y = pfor_input.input(1)[0] + adj_x = pfor_input.get_attr("adj_x") + adj_y = pfor_input.get_attr("adj_y") + + output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) + return wrap(output, True) + + +@RegisterPForWithArgs("Sum", math_ops.reduce_sum) +@RegisterPForWithArgs("Prod", math_ops.reduce_prod) +@RegisterPForWithArgs("Max", math_ops.reduce_max) +@RegisterPForWithArgs("Min", math_ops.reduce_min) +@RegisterPForWithArgs("Mean", math_ops.reduce_mean) +@RegisterPForWithArgs("All", math_ops.reduce_all) +@RegisterPForWithArgs("Any", math_ops.reduce_any) +def _convert_reduction(pfor_input: _PforInput, _, op_func): + t = pfor_input.stacked_input(0) + indices = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. 
+ indices += math_ops.cast(indices >= 0, indices.dtype) + keep_dims = pfor_input.get_attr("keep_dims") + return wrap(op_func(t, indices, keepdims=keep_dims), True) + + +@RegisterPForWithArgs("ArgMax", math_ops.argmax) +@RegisterPForWithArgs("ArgMin", math_ops.argmin) +def _convert_argmax_argmin(pfor_input: _PforInput, _, op_func): + t = pfor_input.stacked_input(0) + dimension = pfor_input.unstacked_input(1) + dimension += math_ops.cast(dimension >= 0, dimension.dtype) + output_type = pfor_input.get_attr("output_type") + return wrap(op_func(t, axis=dimension, output_type=output_type), True) + + +@RegisterPFor("Bucketize") +def _convert_bucketize(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + boundaries = pfor_input.get_attr("boundaries") + return wrap(math_ops.bucketize(t, boundaries), True) + + +@RegisterPFor("ClipByValue") +def _convert_clip_by_value(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + clip_value_min = pfor_input.unstacked_input(1) + clip_value_max = pfor_input.unstacked_input(2) + return wrap(gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max), + True) + + +@RegisterPForWithArgs("Cumsum", math_ops.cumsum) +@RegisterPForWithArgs("Cumprod", math_ops.cumprod) +def _convert_cumfoo(pfor_input: _PforInput, _, op_func): + t = pfor_input.stacked_input(0) + axis = pfor_input.unstacked_input(1) + # Shift positive indices by one to account for the extra dimension. + axis += math_ops.cast(axis >= 0, axis.dtype) + exclusive = pfor_input.get_attr("exclusive") + reverse = pfor_input.get_attr("reverse") + return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) + + +@RegisterPFor("BiasAdd") +def _convert_biasadd(pfor_input: _PforInput): + t, t_stacked, _ = pfor_input.input(0) + bias, bias_stacked, _ = pfor_input.input(1) + data_format = pfor_input.get_attr("data_format").decode() + if bias_stacked: + # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 
+ pfor_input.expanddim_inputs_for_broadcast() + t, _, _ = pfor_input.input(0) + bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) + if compat.as_bytes(data_format) == b"NCHW": + b_shape = array_ops.shape(bias) + new_b_shape = array_ops.concat( + [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) + bias = array_ops.reshape(bias, new_b_shape) + return wrap(math_ops.add(t, bias), True) + else: + assert t_stacked, "At least one input to BiasAdd should be loop variant." + if compat.as_bytes(data_format) == b"NCHW": + shape = array_ops.shape(t) + flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) + t = array_ops.reshape(t, flattened_shape) + t = nn_ops.bias_add(t, bias, data_format="NCHW") + t = array_ops.reshape(t, shape) + return wrap(t, True) + return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) + + +@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum) +@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max) +@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min) +@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod) +def _convert_unsortedsegmentsum(pfor_input: _PforInput, _, op_func): + pfor_input.stack_inputs([0, 1]) + data = pfor_input.stacked_input(0) + segment_ids = pfor_input.stacked_input(1) + # TODO(agarwal): handle stacked? 
+ num_segments = pfor_input.unstacked_input(2) + if segment_ids.dtype != num_segments.dtype: + segment_ids = math_ops.cast(segment_ids, dtypes.int64) + num_segments = math_ops.cast(num_segments, dtypes.int64) + dtype = segment_ids.dtype + segment_shape = array_ops.shape(segment_ids, out_type=dtype) + n = segment_shape[0] + ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] + segment_offset = num_segments * math_ops.range(n, dtype=dtype) + segment_offset = array_ops.reshape(segment_offset, + array_ops.concat([[n], ones], axis=0)) + segment_ids = array_ops.where( + segment_ids >= 0, segment_ids + segment_offset, segment_ids + ) + num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( + n, dtypes.int64) + output = op_func(data, segment_ids, num_segments) + new_output_shape = array_ops.concat( + [[n, -1], array_ops.shape(output)[1:]], axis=0) + output = array_ops.reshape(output, new_output_shape) + return wrap(output, True) + + +def _flatten_array_with_offset(ids, offset_delta, num_rows): + """Flattens a rank 2 tensor, adding an offset to each row.""" + # Note that if `ids` is rank 1, it is broadcast to rank 2. 
+ offset_delta = math_ops.cast(offset_delta, ids.dtype) + n = math_ops.cast(num_rows, dtype=ids.dtype) + offsets = math_ops.range( + start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) + offsets = array_ops.expand_dims(offsets, -1) + ids += offsets + return array_ops.reshape(ids, [-1]) + + +@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) +@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) +@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) +@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", + math_ops.sparse_segment_sum_v2) +@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", + math_ops.sparse_segment_mean_v2) +@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", + math_ops.sparse_segment_sqrt_n_v2) +def _convert_sparse_segment(pfor_input: _PforInput, _, op_func): + _, segment_ids_stacked, _ = pfor_input.input(2) + if segment_ids_stacked: + pfor_input.stack_inputs([1]) + data, data_stacked, _ = pfor_input.input(0) + indices, _, _ = pfor_input.input(1) + num_inputs = len(pfor_input.inputs) + assert num_inputs in (3, 4) + if num_inputs == 3: + # `segment_ids` needs to be unstacked since otherwise output sizes could + # differ across pfor iterations. 
+ segment_ids = pfor_input.unstacked_input(2) + num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) + else: + segment_ids, _, _ = pfor_input.input(2) + num_segments = pfor_input.unstacked_input(3) + + n = pfor_input.pfor.loop_len_vector[0] + if data_stacked: + indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) + data = _flatten_first_two_dims(data) + else: + indices = array_ops.reshape(indices, [-1]) + segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) + + if num_inputs == 3: + num_segments = None + else: + num_segments *= n + output = op_func(data, indices, segment_ids, num_segments=num_segments) + output = _unflatten_first_dim(output, [n]) + return wrap(output, True) + + +@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad) +@RegisterPForWithArgs("SparseSegmentMeanGrad", + math_ops.sparse_segment_mean_grad) +@RegisterPForWithArgs("SparseSegmentSqrtNGrad", + math_ops.sparse_segment_sqrt_n_grad) +def _convert_sparse_segment_grad(pfor_input: _PforInput, _, op_func): + grad = pfor_input.stacked_input(0) + indices = pfor_input.unstacked_input(1) + segment_ids = pfor_input.unstacked_input(2) + dim0 = pfor_input.unstacked_input(3) + + n = pfor_input.pfor.loop_len_vector[0] + indices = _flatten_array_with_offset(indices, dim0, n) + num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) + segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) + grad = _flatten_first_two_dims(grad) + dim0 *= n + output = op_func(grad, indices, segment_ids, dim0) + output = _unflatten_first_dim(output, [n]) + return wrap(output, True) + + +@RegisterPFor("Cast") +def _convert_cast(pfor_input: _PforInput): + inp = pfor_input.stacked_input(0) + dtype = pfor_input.get_attr("DstT") + return wrap(math_ops.cast(inp, dtype), True) + + +@RegisterPFor("Abs") +@RegisterPFor("Acos") +@RegisterPFor("Acosh") +@RegisterPFor("Add") +@RegisterPFor("AddV2") +@RegisterPFor("Angle") 
+@RegisterPFor("Asin") +@RegisterPFor("Asinh") +@RegisterPFor("Atan") +@RegisterPFor("Atan2") +@RegisterPFor("Atanh") +@RegisterPFor("BesselI0") +@RegisterPFor("BesselI1") +@RegisterPFor("BesselI0e") +@RegisterPFor("BesselI1e") +@RegisterPFor("BesselK0") +@RegisterPFor("BesselK1") +@RegisterPFor("BesselK0e") +@RegisterPFor("BesselK1e") +@RegisterPFor("BesselJ0") +@RegisterPFor("BesselJ1") +@RegisterPFor("BesselY0") +@RegisterPFor("BesselY1") +@RegisterPFor("BitwiseAnd") +@RegisterPFor("BitwiseOr") +@RegisterPFor("BitwiseXor") +@RegisterPFor("Ceil") +@RegisterPFor("Complex") +@RegisterPFor("ComplexAbs") +@RegisterPFor("Conj") +@RegisterPFor("Cos") +@RegisterPFor("Cosh") +@RegisterPFor("Dawsn") +@RegisterPFor("Digamma") +@RegisterPFor("Div") +@RegisterPFor("DivNoNan") +@RegisterPFor("Elu") +@RegisterPFor("Erf") +@RegisterPFor("Erfc") +@RegisterPFor("Erfinv") +@RegisterPFor("Exp") +@RegisterPFor("Expint") +@RegisterPFor("Expm1") +@RegisterPFor("Floor") +@RegisterPFor("FloorDiv") +@RegisterPFor("FloorMod") +@RegisterPFor("FresnelCos") +@RegisterPFor("FresnelSin") +@RegisterPFor("Greater") +@RegisterPFor("GreaterEqual") +@RegisterPFor("Igamma") +@RegisterPFor("IgammaGradA") +@RegisterPFor("Igammac") +@RegisterPFor("Imag") +@RegisterPFor("Inv") +@RegisterPFor("Invert") +@RegisterPFor("IsFinite") +@RegisterPFor("IsInf") +@RegisterPFor("IsNan") +@RegisterPFor("LeftShift") +@RegisterPFor("Less") +@RegisterPFor("LessEqual") +@RegisterPFor("Lgamma") +@RegisterPFor("Log") +@RegisterPFor("Log1p") +@RegisterPFor("LogicalAnd") +@RegisterPFor("LogicalNot") +@RegisterPFor("LogicalOr") +@RegisterPFor("LogicalXor") +@RegisterPFor("Maximum") +@RegisterPFor("Minimum") +@RegisterPFor("Mod") +@RegisterPFor("Mul") +@RegisterPFor("MulNoNan") +@RegisterPFor("Ndtri") +@RegisterPFor("Neg") +@RegisterPFor("Polygamma") +@RegisterPFor("Pow") +@RegisterPFor("Real") +@RegisterPFor("RealDiv") +@RegisterPFor("Reciprocal") +@RegisterPFor("Relu") +@RegisterPFor("Relu6") +@RegisterPFor("RightShift") 
+@RegisterPFor("Rint") +@RegisterPFor("Round") +@RegisterPFor("Rsqrt") +@RegisterPFor("Selu") +@RegisterPFor("Sigmoid") +@RegisterPFor("Sign") +@RegisterPFor("Sin") +@RegisterPFor("Sinh") +@RegisterPFor("Softplus") +@RegisterPFor("Softsign") +@RegisterPFor("Spence") +@RegisterPFor("Sqrt") +@RegisterPFor("Square") +@RegisterPFor("SquaredDifference") +@RegisterPFor("Sub") +@RegisterPFor("Tan") +@RegisterPFor("Tanh") +@RegisterPFor("TruncateDiv") +@RegisterPFor("TruncateMod") +@RegisterPFor("Xdivy") +@RegisterPFor("Xlogy") +@RegisterPFor("Xlog1py") +@RegisterPFor("Zeta") +def _convert_cwise(pfor_input: _PforInput): + if pfor_input.num_inputs > 1: + pfor_input.expanddim_inputs_for_broadcast() + + out = _create_op( + pfor_input.op_type, [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + assert len(out) == 1 + out = out[0] + + op_output = wrap(out, True) + return op_output + + +@RegisterPFor("XlaSharding") +def _convert_xla_sharding(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + sharding = pfor_input.get_attr("sharding") + return wrap(xla.sharding(t, sharding=sharding), True) + + +@RegisterPFor("LeakyRelu") +def _convert_leaky_relu(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + alpha = pfor_input.get_attr("alpha") + return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) + + +@RegisterPFor("Equal") +def _convert_equal(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + x = pfor_input.input(0)[0] + y = pfor_input.input(1)[0] + incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") + return wrap(gen_math_ops.equal( + x, y, incompatible_shape_error=incompatible_shape_error), True) + + +@RegisterPFor("NotEqual") +def _convert_not_equal(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + x = pfor_input.input(0)[0] + y = pfor_input.input(1)[0] + incompatible_shape_error = 
pfor_input.get_attr("incompatible_shape_error") + return wrap(gen_math_ops.not_equal( + x, y, incompatible_shape_error=incompatible_shape_error), True) + + +@RegisterPFor("ApproximateEqual") +def _convert_approximate_equal(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + x = pfor_input.input(0)[0] + y = pfor_input.input(1)[0] + tolerance = pfor_input.get_attr("tolerance") + return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) + + +@RegisterPFor("Shape") +def _convert_shape(pfor_input: _PforInput): + out_type = pfor_input.get_attr("out_type") + return wrap( + array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], + False) + + +@RegisterPFor("ShapeN") +def _convert_shape_n(pfor_input: _PforInput): + out_type = pfor_input.get_attr("out_type") + shapes = [ + array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( + x, out_type=out_type) for x, stacked, _ in pfor_input.inputs + ] + return [wrap(x, False) for x in shapes] + + +@RegisterPFor("Size") +def _convert_size(pfor_input: _PforInput): + out_type = pfor_input.get_attr("out_type") + n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) + return wrap( + array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, + False) + + +@RegisterPFor("Rank") +def _convert_rank(pfor_input: _PforInput): + return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) + + +@RegisterPFor("AddN") +def _convert_addn(pfor_input: _PforInput): + # AddN does not support broadcasting. 
+ pfor_input.stack_inputs(tile_variants=False) + return _wrap_and_tile_variants( + math_ops.add_n([x.t for x in pfor_input.inputs]), + pfor_input.pfor.loop_len_vector) + + +@RegisterPFor("Cross") +def _convert_cross(pfor_input: _PforInput): + pfor_input.stack_inputs() + a = pfor_input.stacked_input(0) + b = pfor_input.stacked_input(1) + return wrap(math_ops.cross(a, b), True) + + +@RegisterPFor("BiasAddGrad") +def _convert_biasaddgrad(pfor_input: _PforInput): + grad = pfor_input.stacked_input(0) + fmt = pfor_input.get_attr("data_format") + if fmt == b"NCHW": + output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) + else: + grad_shape = array_ops.shape(grad) + last_dim_shape = grad_shape[-1] + first_dim_shape = grad_shape[0] + output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) + output = math_ops.reduce_sum(output, axis=[1], keepdims=False) + return wrap(output, True) + + +# Some required ops are not exposed under the tf namespace. Hence relying on +# _create_op to create them. +@RegisterPForWithArgs("EluGrad") +@RegisterPForWithArgs("LeakyReluGrad") +@RegisterPForWithArgs("ReciprocalGrad") +@RegisterPForWithArgs("Relu6Grad") +@RegisterPForWithArgs("ReluGrad") +@RegisterPForWithArgs("RsqrtGrad") +@RegisterPForWithArgs("SeluGrad") +@RegisterPForWithArgs("SigmoidGrad") +@RegisterPForWithArgs("SoftplusGrad") +@RegisterPForWithArgs("SoftsignGrad") +@RegisterPForWithArgs("SqrtGrad") +@RegisterPForWithArgs("TanhGrad") +def _convert_grads(pfor_input: _PforInput, op_type, *args, **kw_args): + del args + del kw_args + # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we + # have to use tiling here. 
+ pfor_input.stack_inputs() + outputs = _create_op( + op_type, [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("Select") +def _convert_select(pfor_input: _PforInput): + pfor_input.stack_inputs() + cond = pfor_input.stacked_input(0) + t = pfor_input.stacked_input(1) + e = pfor_input.stacked_input(2) + cond_rank = array_ops.rank(cond) + cond, t, e = smart_cond.smart_cond( + cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), + lambda: [cond, t, e]) + outputs = _create_op( + pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + n = pfor_input.pfor.loop_len_vector + out = smart_cond.smart_cond(cond_rank > 1, + lambda: _unflatten_first_dim(outputs[0], n), + lambda: outputs[0]) + return [wrap(out, True) for x in outputs] + + +@RegisterPFor("SelectV2") +def _convert_selectv2(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + cond = pfor_input.input(0)[0] + t = pfor_input.input(1)[0] + e = pfor_input.input(2)[0] + out = array_ops.where_v2(cond, t, e) + return wrap(out, True) + + +# random_ops + + +def _transpose_dim_to_front(x, dim): + rank = array_ops.rank(x) + return array_ops.transpose( + x, + perm=array_ops.concat( + [[dim], math_ops.range(0, dim), + math_ops.range(dim + 1, rank)], + axis=0)) + + +@RegisterPForWithArgs("RandomUniform") +@RegisterPForWithArgs("RandomUniformInt") +@RegisterPForWithArgs("RandomStandardNormal") +@RegisterPForWithArgs("TruncatedNormal") +def _convert_random(pfor_input: _PforInput, op_type, *args, **kw_args): + del args + del kw_args + inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] + # inputs[0] is "shape" + n = math_ops.cast(pfor_input.pfor.loop_len_vector, inputs[0].dtype) + inputs[0] = array_ops.concat([n, inputs[0]], axis=0) + # TODO(b/222761732): Turn this warning back 
on when legacy RNGs are + # deprecated. + # logging.warning( + # "Note that %s inside pfor op may not give same output as " + # "inside a sequential loop.", op_type) + outputs = _create_op( + op_type, + inputs, [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("RandomGamma") +@RegisterPFor("RandomPoissonV2") +def _convert_random_with_param(pfor_input: _PforInput): + shape = pfor_input.unstacked_input(0) + # param is lam (Poisson rate) or alpha (Gamma shape). + param, param_stacked, _ = pfor_input.input(1) + # TODO(b/222761732): Turn this warning back on when legacy RNGs are + # deprecated. + # logging.warning( + # "Note that %s inside pfor op may not give same output as " + # "inside a sequential loop.", pfor_input.op_type) + + if param_stacked: + samples = _create_op( + pfor_input.op_type, + inputs=[shape, param], + op_dtypes=[x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs[0] + loop_dim = array_ops.shape(shape)[0] + stacked_samples = _transpose_dim_to_front(samples, loop_dim) + else: + n = math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype) + shape = array_ops.concat([n, shape], axis=0) + stacked_samples = _create_op( + pfor_input.op_type, + inputs=[shape, param], + op_dtypes=[x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs[0] + + return wrap(stacked_samples, True) + + +@RegisterPFor("Multinomial") +def _convert_multinomial(pfor_input: _PforInput): + logits, logits_stacked, _ = pfor_input.input(0) + num_samples = pfor_input.unstacked_input(1) + seed = pfor_input.get_attr("seed") + seed2 = pfor_input.get_attr("seed2") + output_dtype = pfor_input.get_attr("output_dtype") + # TODO(b/222761732): Turn this warning back on when legacy RNGs are + # deprecated. 
+ # logging.warning( + # "Note that Multinomial inside pfor op may not give same output as " + # "inside a sequential loop.") + + n = pfor_input.pfor.loop_len_vector[0] + if logits_stacked: + flattened_logits = _flatten_first_two_dims(logits) + samples = gen_random_ops.multinomial( + flattened_logits, + num_samples, + seed=seed, + seed2=seed2, + output_dtype=output_dtype) + stacked_samples = _unflatten_first_dim(samples, [n]) + else: + samples = gen_random_ops.multinomial( + logits, + num_samples * n, + seed=seed, + seed2=seed2, + output_dtype=output_dtype) + stacked_samples = array_ops.transpose( + array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2]) + + return wrap(stacked_samples, True) + + +@RegisterPFor("StatelessMultinomial") +@RegisterPFor("StatelessParameterizedTruncatedNormal") +@RegisterPFor("StatelessRandomBinomial") +@RegisterPFor("StatelessRandomGammaV2") +@RegisterPFor("StatelessRandomNormal") +@RegisterPFor("StatelessRandomPoisson") +@RegisterPFor("StatelessRandomUniform") +@RegisterPFor("StatelessRandomUniformInt") +@RegisterPFor("StatelessRandomUniformFullInt") +@RegisterPFor("StatelessTruncatedNormal") +def _convert_stateless_multinomial(pfor_input: _PforInput): + # Unlike stateful random ops, for stateless ones we want better + # reproducibility based on seed. Hence we don't want to use a similar strategy + # as used for stateful ones where we generate a possibly different set of + # random numbers under vectorization. + # Unfortunately, the kernels currently are not necessarily setup to do this + # efficiently and hence we fallback to a sequential loop for vectorization. + return _fallback_converter(pfor_input, warn=False) + + +# linalg_ops + + +@RegisterPForWithArgs("XlaEinsum") +@RegisterPForWithArgs("Einsum") +def _convert_einsum(pfor_input: _PforInput, op_type): + # Einsum may have either 1 or 2 inputs. + inputs, input_stacked, _ = zip(*[ + pfor_input.input(i) + for i in range(pfor_input.num_inputs)]) + + # Parse the einsum equation. 
+ equation = pfor_input.get_attr("equation").decode("utf-8") + input_expr, output_expr = equation.split("->") + input_exprs = input_expr.split(",") + + # Pick a placeholder symbol to use for the new axis. + chosen_symbol = None + for s in string.ascii_letters: + if s in equation: + continue + else: + chosen_symbol = s + break + + if chosen_symbol is None: + raise ValueError("Could not figure out what symbol to use for new axis.") + + assert any(input_stacked) + for i in range(len(inputs)): + if input_stacked[i]: + input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i]) + output_expr = "{}{}".format(chosen_symbol, output_expr) + + new_equation = "{}->{}".format(",".join(input_exprs), output_expr) + + if op_type == "XlaEinsum": + if len(inputs) == 1: + result = xla.einsum(equation=new_equation, a=inputs[0]) + else: + result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1]) + else: + assert op_type == "Einsum" + result = special_math_ops.einsum(new_equation, *inputs) + + return wrap(result, True) + + +@RegisterPFor("Cholesky") +def _convert_cholesky(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + return wrap(linalg_ops.cholesky(t), True) + + +@RegisterPFor("LogMatrixDeterminant") +def _convert_log_matrix_determinant(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] + + +@RegisterPFor("MatrixInverse") +def _convert_matrix_inverse(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + adjoint = pfor_input.get_attr("adjoint") + return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) + + +@RegisterPFor("MatrixSolve") +def _convert_matrix_solve(pfor_input: _PforInput): + pfor_input.stack_inputs() + matrix = pfor_input.stacked_input(0) + rhs = pfor_input.stacked_input(1) + adjoint = pfor_input.get_attr("adjoint") + output = gen_linalg_ops.matrix_solve( + matrix, rhs, adjoint=adjoint) + return wrap(output, True) + + 
+@RegisterPFor("MatrixTriangularSolve") +def _convert_matrix_triangular_solve(pfor_input: _PforInput): + pfor_input.expanddim_inputs_for_broadcast() + matrix = pfor_input.input(0)[0] + rhs = pfor_input.input(1)[0] + lower = pfor_input.get_attr("lower") + adjoint = pfor_input.get_attr("adjoint") + output = linalg_ops.matrix_triangular_solve( + matrix, rhs, lower=lower, adjoint=adjoint) + return wrap(output, True) + + +@RegisterPFor("SelfAdjointEigV2") +def _convert_self_adjoint_eig(pfor_input: _PforInput): + t = pfor_input.stacked_input(0) + compute_v = pfor_input.get_attr("compute_v") + e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) + # If compute_v is False, v will have shape [0]. + return wrap(e, True), wrap(v, compute_v) + + +# logging_ops + + +@RegisterPFor("Assert") +def _convert_assert(pfor_input: _PforInput): + cond, cond_stacked, _ = pfor_input.input(0) + if cond_stacked: + cond = math_ops.reduce_all(cond) + + data_list = [x.t for x in pfor_input.inputs][1:] + return _create_op( + "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) + + +@RegisterPFor("Print") +def _convert_print(pfor_input: _PforInput): + # Note that we don't stack all the inputs. Hence unstacked values are printed + # once here vs multiple times in a while_loop. + pfor_input.stack_inputs([0]) + outputs = _create_op( + "Print", [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr).outputs + return [wrap(x, True) for x in outputs] + + +@RegisterPFor("PrintV2") +def _convert_print_v2(pfor_input: _PforInput): + # Print the full input Tensor(s), including the batch dimension if stacked. 
+ return _create_op( + "PrintV2", [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr) + + +@RegisterPFor("StringFormat") +def _convert_string_format(pfor_input: _PforInput): + # Format using the full input Tensor(s), including the batch dimension if + # stacked. + op = _create_op( + "StringFormat", [x.t for x in pfor_input.inputs], + [x.dtype for x in pfor_input.outputs], + attrs=pfor_input.op.node_def.attr) + return [wrap(output, False) for output in op.outputs] + + +# data_flow_ops + +# TensorArray conversion is tricky since we don't support arrays of +# TensorArrays. For converting them, we consider two distinct cases: +# +# 1. The array is constructed outside the pfor call, and read/written inside the +# loop. +# This is an easier case since we don't need to make an array of TensorArrays. +# A correctness requirement is that these parallel iterations shouldn't attempt +# to write to the same location. Hence at conversion time we disallow indices to +# be loop-invariant as that would guarantee a collision. Even if the indices are +# not loop-invariant, they could conflict and that shall trigger runtime errors. +# +# 2. The array is constructed and used entirely inside each pfor iteration. +# For simplicity, here we require that the indices used for write/scatter are +# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in +# different pfor iterations. We consider two sub_cases: +# +# 2a Elements written to the array are "stacked" +# To simulate multiple TensorArrays, we may increase the dimension of each +# element of the array. i.e. the i_th row of the j_th entry of the converted +# TensorArray corresponds to the j_th entry of the TensorArray in the i_th +# pfor iteration. +# +# 2b Elements written to the array are "unstacked" +# In this case we don't increase the dimensions to avoid redundant tiling. Each +# iteration is trying to write the same value. 
So we convert that to a single +# write. +# +# Here are some tricks used to implement the above: +# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of +# trying to trace whether future writes are stacked or unstacked in order to set +# this attr, we set it to correspond to unknown shape. +# - We use the "flow" output of the different ops to track whether the array +# elements are stacked or unstacked. If a stacked write/scatter is done, we make +# the flow stacked as well. +# - We use some heuristic traversal of the graph to track whether the +# TensorArray handle was created inside or outside the pfor loop. + + +@RegisterPFor("TensorArrayV3") +def _convert_tensor_array_v3(pfor_input: _PforInput): + size = pfor_input.unstacked_input(0) + dtype = pfor_input.get_attr("dtype") + dynamic_size = pfor_input.get_attr("dynamic_size") + clear_after_read = pfor_input.get_attr("clear_after_read") + identical_element_shapes = pfor_input.get_attr("identical_element_shapes") + tensor_array_name = pfor_input.get_attr("tensor_array_name") + handle, flow = data_flow_ops.tensor_array_v3( + size, + dtype=dtype, + # We don't set element shape since we don't know if writes are stacked or + # not yet. + element_shape=None, + dynamic_size=dynamic_size, + clear_after_read=clear_after_read, + identical_element_shapes=identical_element_shapes, + tensor_array_name=tensor_array_name) + # Note we keep flow unstacked for now since we don't know if writes will be + # stacked or not. 
+ return wrap(handle, False), wrap(flow, False) + + +@RegisterPFor("TensorArraySizeV3") +def _convert_tensor_array_size_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + size = data_flow_ops.tensor_array_size_v3(handle, flow) + return wrap(size, False) + + +def _handle_inside_pfor(pfor_input: _PforInput, handle): + """Returns True if handle was created inside the pfor loop.""" + # We use some heuristic to find the original TensorArray creation op. + # The logic should handle the common cases (except cond based subgraphs). + # In theory the user could perform different operations on the handle (like + # Reshape, stack multiple handles, etc) which could break this logic. + # TODO(agarwal): handle Switch/Merge. + while handle.op.type in ("Enter", "Identity"): + handle = handle.op.inputs[0] + if handle.op.type not in [ + "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape" + ]: + raise ValueError(f"Unable to find source for handle {handle}.") + else: + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +def _unstack_flow(value): + # TODO(agarwal): consider looking if this is a Tile op then get its input. + # This may avoid running the Tile operations. + return array_ops.gather(value, 0) + + +@RegisterPFor("TensorArrayReadV3") +def _convert_tensor_array_read_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + dtype = pfor_input.get_attr("dtype") + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside_pfor: + # Note that if we are inside a control flow construct inside the pfor, and + # only some of the iterations are doing the read (i.e. 
+ # `all_indices_partitioned` is True), then the read operation should only + # return values for the currently active pfor iterations (`all_indices` + # below). Hence, whenever the returned value is stacked (i.e. `flow` is + # stacked), we may need to do an extra gather after reading the values. Also + # note that if `is_inside` is false, then values in the tensor array are + # unstacked. So the check is only needed in this branch. + all_indices = pfor_input.pfor.all_indices + all_indices_partitioned = pfor_input.pfor.all_indices_partitioned + # Note: flow_stacked indicates if values in the TensorArray are stacked or + # not. + if index_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayReadV3 was called on a TensorArray whose" + " values are not loop-invariant, and the read indices were also" + " not loop invariant. This is currently unsupported.") + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + return wrap(value, True) + value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) + if flow_stacked and all_indices_partitioned: + value = array_ops.gather(value, all_indices) + return wrap(value, flow_stacked) + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on index_stacked. 
+ if index_stacked: + value = data_flow_ops.tensor_array_gather_v3( + handle, index, flow, dtype=dtype) + else: + value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) + return wrap(value, index_stacked) + + +@RegisterPFor("TensorArrayWriteV3") +def _convert_tensor_array_write_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + index, index_stacked, _ = pfor_input.input(1) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + if value_stacked and pfor_input.pfor.all_indices_partitioned: + # Looks like we are in a control flow in a pfor where not all iterations are + # active now. We don't allow that since that could lead to different indices + # having different shapes which will be hard to merge later. + raise ValueError("Writing non loop invariant values to TensorArray from " + "inside a while_loop/cond not supported.") + if flow_stacked: + flow = _unstack_flow(flow) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if index_stacked: + raise ValueError(f"Need indices for {handle} to be loop invariant.") + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return wrap(flow_out, False) + else: + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + # TODO(agarwal): Note that if flow is unstacked and value is stacked, then + # this may or may not be a safe situation. flow is unstacked both for a + # freshly created TensorArray, as well as after unstacked values are + # written to it. If it is the latter, then we cannot write a stacked value + # now since that may cause runtime errors due to different shapes in the + # array. At the moment we are not able to handle this gracefully and + # distinguish between the two cases. That would require some heuristic + # traversal of the graph to figure out whether all the writes are + # unstacked or not. 
+ flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + else: + if not index_stacked: + raise ValueError(f"Need indices for {handle} to be not loop invariant.") + # Note that even when index_stacked is true, actual values in index may + # still not be unique. However that will cause runtime error when executing + # the scatter operation below. + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +def _transpose_first_two_dims(value): + # TODO(agarwal): optimize if one of the dims == 1. + value_shape = array_ops.shape(value) + v0 = value_shape[0] + v1 = value_shape[1] + value = array_ops.reshape(value, [v0, v1, -1]) + value = array_ops.transpose(value, [1, 0, 2]) + new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) + return array_ops.reshape(value, new_shape) + + +@RegisterPFor("TensorArrayGatherV3") +def _convert_tensor_array_gather_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + flow, flow_stacked, _ = pfor_input.input(2) + if flow_stacked: + flow = _unstack_flow(flow) + dtype = pfor_input.get_attr("dtype") + # TODO(agarwal): support element_shape attr? + + n = pfor_input.pfor.loop_len_vector + value = data_flow_ops.tensor_array_gather_v3( + handle, indices, flow, dtype=dtype) + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + # flow_stacked indicates if values in the TensorArray are stacked or not. + if indices_stacked: + if flow_stacked: + raise ValueError( + "It looks like TensorArrayGatherV3 was called on a TensorArray " + "whose values are not loop-invariant, and the indices were also " + "not loop invariant. 
This is currently unsupported.") + else: + value = _unflatten_first_dim(value, n) + return wrap(value, True) + else: + if flow_stacked: + # Since elements in this array are stacked and `value` was produced by + # gather, its first two dims are "gathered elements" and "stack + # dimension". Our semantics require these two to be flipped. + value = _transpose_first_two_dims(value) + return wrap(value, flow_stacked) + else: + # Values in the TensorArray should be unstacked (since different iterations + # couldn't write to the same location). So whether output is stacked or not + # depends on indices_stacked. + if indices_stacked: + value = _unflatten_first_dim(value, n) + return wrap(value, indices_stacked) + + +@RegisterPFor("TensorArrayScatterV3") +def _convert_tensor_array_scatter_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + indices, indices_stacked, _ = pfor_input.input(1) + indices = array_ops.reshape(indices, [-1]) + value, value_stacked, _ = pfor_input.input(2) + flow, flow_stacked, _ = pfor_input.input(3) + + if flow_stacked: + flow = _unstack_flow(flow) + + is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) + if is_inside: + if indices_stacked: + raise ValueError(f"Need indices for {handle} to be loop invariant.") + # Note that flow_stacked indicates if existing values in the array are + # stacked or not. + if not flow_stacked and not value_stacked: + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return wrap(flow_out, False) + if not value_stacked: + # TODO(agarwal): tile in the second dimension directly instead of + # transposing below. + value = _stack(value, pfor_input.pfor.loop_len_vector).t + + value = _transpose_first_two_dims(value) + # TODO(agarwal): Note that if a previous write was unstacked, flow will be + # unstacked, and a stacked value may be written here which may cause + # runtime error due to different elements having different shape. 
We do + # not try to prevent that. + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, + flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + if not indices_stacked: + raise ValueError(f"Need indices for {handle} to be not loop invariant.") + if not value_stacked: + value = _stack(value, pfor_input.pfor.loop_len_vector).t + value = _flatten_first_two_dims(value) + flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow) + return _stack(flow_out, pfor_input.pfor.loop_len_vector) + + +@RegisterPFor("TensorArrayGradV3") +def _convert_tensor_array_grad_v3(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + flow, flow_stacked, _ = pfor_input.input(1) + if flow_stacked: + flow = _unstack_flow(flow) + source = pfor_input.get_attr("source") + # TODO(agarwal): For now, we assume that gradients are stacked if the + # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong + # will give runtime error due to incorrect shape being written to the + # accumulator. It is difficult to know in advance if gradients written will be + # stacked or not. Note that flow being stacked is not indicative of the + # gradient being stacked or not. Revisit this later. + shape_to_prepend = pfor_input.pfor.loop_len_vector + grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( + handle=handle, + flow_in=flow, + shape_to_prepend=shape_to_prepend, + source=source) + flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t + return [wrap(grad_handle, False), wrap(flow_out, True)] + + +def _stack_tensor_list_shape(shape, first_dim): + shape_value = tensor_util.constant_value(shape) + # Note that negative values in the shape are used to signify unknown shapes + # and are handled in a special way. 
+ if shape_value is not None: + shape_value = numpy_compat.np_asarray(shape_value) + if -1 in shape_value: + return constant_op.constant(-1) + elif not shape_value.size: + return first_dim + else: + shape = array_ops.reshape(shape, [-1]) + return tf_cond.cond( + math_ops.reduce_any(shape < 0), + lambda: constant_op.constant(-1), + lambda: array_ops.concat([first_dim, shape], axis=0)) + + +def _tile_variant_with_length(t, length): + """stacks `t` `length` times.""" + if _is_variant_with_internal_stacking(t): + # The content of TensorLists is vectorized, not the variant itself. + return t + original_tensor = t + t.set_shape([]) + t = array_ops.reshape(t, [-1]) + with ops.device("CPU:0"): + result = array_ops.tile(t, length) + # TODO(b/169968286): Should regular shape functions do handle data + # propagation here? + handle_data_util.copy_handle_data(original_tensor, result) + return result + + +def _tile_variant(t, pfor_input: _PforInput): + """stacks `t` according to its loop context.""" + return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) + + +def _untile_variant(t): + if _is_variant_with_internal_stacking(t): + # The content of TensorLists is vectorized, not the variant itself. + if not t.shape.is_compatible_with([]): + raise AssertionError( + ("Unexpectedly saw a vectorized variant (e.g. 
TensorList) with " + f"non-scalar shape: {t!r}")) + return t + return array_ops.gather(t, 0) + + +@RegisterPFor("OptionalFromValue") +def _convert_optional_from_value(pfor_input: _PforInput): + pfor_input.stack_inputs() + return wrap( + gen_optional_ops.optional_from_value([x.t for x in pfor_input.inputs]), + True, + ) + + +@RegisterPFor("OptionalGetValue") +def _convert_optional_get_value(pfor_input: _PforInput): + handle = pfor_input.stacked_input(0) + output_types = pfor_input.get_attr("output_types") + original_output_shapes = pfor_input.get_attr("output_shapes") + output_shapes = [] + for shape in original_output_shapes: + shape = tensor_shape.TensorShape(shape) + loop_len_value = tensor_util.constant_value(pfor_input.pfor.loop_len_vector) + loop_len_shape = tensor_shape.TensorShape( + [loop_len_value[0] if loop_len_value is not None else None] + ) + shape = loop_len_shape.concatenate(shape) + output_shapes.append(shape.as_proto()) + results = gen_optional_ops.optional_get_value( + handle, output_types, output_shapes + ) + return [wrap(t, True) for t in results] + + +@RegisterPFor("TensorListReserve") +def _convert_tensor_list_reserve(pfor_input: _PforInput): + element_shape = pfor_input.unstacked_input(0) + num_elements = pfor_input.unstacked_input(1) + element_dtype = pfor_input.get_attr("element_dtype") + + # Prepend a dimension to element_shape. 
+ element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) + handle = list_ops.tensor_list_reserve( + element_shape, num_elements, element_dtype=element_dtype) + + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListElementShape") +def _convert_tensor_list_element_shape(pfor_input: _PforInput): + handle = _untile_variant(pfor_input.stacked_input(0)) + shape_type = pfor_input.get_attr("shape_type") + shape = list_ops.tensor_list_element_shape(handle, shape_type) + shape = array_ops.reshape(shape, [-1]) + shape = shape[1:] + return wrap(shape, False) + + +@RegisterPFor("TensorListLength") +def _convert_tensor_list_length(pfor_input: _PforInput): + handle = _untile_variant(pfor_input.stacked_input(0)) + return wrap(list_ops.tensor_list_length(handle), False) + + +def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): + if element_shape is None: + element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) + length = list_ops.tensor_list_length(handle) + new_handle = list_ops.tensor_list_reserve( + _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) + + def _body_fn(i, h): + elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) + elem = _stack(elem, loop_len_vector).t + return i + 1, list_ops.tensor_list_set_item(h, i, elem) + + return while_loop.while_loop(lambda i, _: i < length, _body_fn, + [0, new_handle])[1] + + +@RegisterPFor("TensorListGetItem") +def _convert_tensor_list_get_item(pfor_input: _PforInput): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + element_shape = pfor_input.unstacked_input(2) + element_dtype = pfor_input.get_attr("element_dtype") + + if handle_stacked: + handle = _untile_variant(handle) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) + if index_stacked: + # We use a sequential loop since that may be more 
efficient than first + # gathering and concatenating all the element corresponding to `index`, + # and then doing a gather on it. + def _map_fn(i): + item_i = list_ops.tensor_list_get_item( + handle, + index[i], + element_dtype=element_dtype) + return array_ops.gather(item_i, i) + + output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) + return wrap(output, True) + else: + output = list_ops.tensor_list_get_item( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype) + return wrap(output, True) + else: + assert index_stacked + return wrap( + list_ops.tensor_list_gather( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype), True) + + +@RegisterPFor("TensorListSetItem") +def _convert_tensor_array_set_item(pfor_input: _PforInput): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + item, item_stacked, _ = pfor_input.input(2) + + if not handle_stacked: + # Special case where we can statically guarantee that the indices are + # disjoint. + if index is pfor_input.pfor.all_indices: + if not item_stacked: + item = _stack(item, pfor_input.pfor.loop_len_vector).t + return wrap( + list_ops.tensor_list_scatter(item, index, input_handle=handle), False) + else: + handle = _stack_tensor_list(handle, item.dtype, + pfor_input.pfor.loop_len_vector) + else: + handle = _untile_variant(handle) + + if index_stacked: + # TODO(agarwal): handle this. 
+ raise ValueError("Vectorizing writes to a TensorList with loop " + "variant indices is currently unsupported.") + + else: + if not item_stacked: + item = _stack(item, pfor_input.pfor.loop_len_vector).t + handle = list_ops.tensor_list_set_item(handle, index, item) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListPushBack") +def _convert_tensor_list_push_back(pfor_input: _PforInput): + handle, handle_stacked, _ = pfor_input.input(0) + tensor, tensor_stacked, _ = pfor_input.input(1) + if handle_stacked: + handle = _untile_variant(handle) + else: + handle = _stack_tensor_list(handle, tensor.dtype, + pfor_input.pfor.loop_len_vector) + if not tensor_stacked: + tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t + handle = list_ops.tensor_list_push_back(handle, tensor) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListPopBack") +def _convert_tensor_array_push_back(pfor_input: _PforInput): + handle = pfor_input.stacked_input(0) + element_shape = pfor_input.unstacked_input(1) + handle = _untile_variant(handle) + + if element_shape.shape.ndims == 0: + # Default / unspecified + vectorized_shape = -1 + else: + # PopBack has an element shape set when it's the gradient of PushBack, only + # used when the list is uninitialized. 
+ n = math_ops.cast(pfor_input.pfor.loop_len_vector, element_shape.dtype) + vectorized_shape = array_ops.concat([n, element_shape], axis=0) + + output_handle, tensor = gen_list_ops.tensor_list_pop_back( + input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), + element_shape=vectorized_shape) + return wrap(output_handle, True), wrap(tensor, True) + + +@RegisterPFor("TensorListConcatV2") +def _convert_tensor_list_concat_v2(pfor_input: _PforInput): + input_handle = pfor_input.stacked_input(0) + element_shape = pfor_input.unstacked_input(1) + leading_dims = pfor_input.unstacked_input(2) + element_dtype = pfor_input.get_attr("element_dtype") + + handle = _untile_variant(input_handle) + length = list_ops.tensor_list_length(handle) + # Note that element_shape attribute can have incomplete shapes. This doesn't + # seem to work well when creating another list and then doing a concat on it. + # Hence we try to find the dynamic shape here. + element_shape = tf_cond.cond( + length > 0, lambda: array_ops.shape( + list_ops.tensor_list_get_item(handle, 0, element_dtype, None)), + lambda: constant_op.constant([0, 0], dtype=dtypes.int32)) + # The code below creates a copy of the list with each elements' first two + # dimensions transposed. + new_element_shape = array_ops.concat( + [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0) + + # Create a new TensorList with elements transposed. 
+ def _transpose_elem(i, h): + elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None) + elem = _transpose_first_two_dims(elem) + return i + 1, list_ops.tensor_list_set_item(h, i, elem) + + new_handle = list_ops.tensor_list_reserve(new_element_shape, length, + element_dtype) + new_handle = while_loop.while_loop(lambda i, _: i < length, _transpose_elem, + [0, new_handle])[1] + output, lengths = gen_list_ops.tensor_list_concat_v2( + input_handle=new_handle, + element_dtype=element_dtype, + element_shape=new_element_shape, + leading_dims=leading_dims) + output = _transpose_first_two_dims(output) + return wrap(output, True), wrap(lengths, False) + + +@RegisterPFor("TensorListStack") +def _convert_tensor_list_stack(pfor_input: _PforInput): + handle = pfor_input.stacked_input(0) + input_shape = pfor_input.unstacked_input(1) + element_dtype = pfor_input.get_attr("element_dtype") + num_elements = pfor_input.get_attr("num_elements") + + handle = _untile_variant(handle) + input_shape = _stack_tensor_list_shape(input_shape, + pfor_input.pfor.loop_len_vector) + output = list_ops.tensor_list_stack( + handle, + element_dtype, + element_shape=input_shape, + num_elements=num_elements) + output = _transpose_first_two_dims(output) + return wrap(output, True) + + +@RegisterPFor("TensorListGather") +def _convert_tensor_list_gather(pfor_input: _PforInput): + handle, handle_stacked, _ = pfor_input.input(0) + index, index_stacked, _ = pfor_input.input(1) + element_shape = pfor_input.unstacked_input(2) + element_dtype = pfor_input.get_attr("element_dtype") + + if handle_stacked: + handle = _untile_variant(handle) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) + if index_stacked: + # We use a sequential loop since that may be more efficient than first + # gathering and concatenating all the element corresponding to `index`, + # and then doing a gather on it. 
+ def _map_fn(i): + item_i = list_ops.tensor_list_gather( + handle, + index[i], + element_dtype=element_dtype) + axis = array_ops.rank(index) - 1 + return array_ops.gather(item_i, i, axis=axis) + + output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) + return wrap(output, True) + else: + output = list_ops.tensor_list_gather( + handle, + index, + element_shape=element_shape, + element_dtype=element_dtype) + return wrap(output, True) + else: + assert index_stacked + index_shape = array_ops.shape(index) + index = array_ops.reshape(index, [-1]) + values = list_ops.tensor_list_gather( + handle, index, element_shape=element_shape, element_dtype=element_dtype) + final_shape = array_ops.concat( + [index_shape, array_ops.shape(values)[1:]], axis=0) + return wrap(array_ops.reshape(values, final_shape), True) + + +@RegisterPFor("TensorListScatterIntoExistingList") +def _convert_tensor_list_scatter(pfor_input: _PforInput): + pfor_input.stack_inputs([1]) + handle, handle_stacked, _ = pfor_input.input(0) + item = pfor_input.stacked_input(1) + indices, indices_stacked, _ = pfor_input.input(2) + if handle_stacked: + handle = _untile_variant(handle) + else: + handle = _stack_tensor_list(handle, item.dtype, + pfor_input.pfor.loop_len_vector) + + item = _transpose_first_two_dims(item) + if indices_stacked: + # Pretend the list is a dense tensor: + # list_as_dense: Tensor[list_len, loop_len, ...] + # And indices are a tensor with shape (before transpose): + # indices: Tensor[loop_len, num_scatters] + # The item to scatter has shape (before transpose): + # item: Tensor[loop_len, num_scatters, ...] + # + # We want list_as_dense[indices[i, j], i] = item[i, j] + # + # Since we're not just indexing along the first axis of `list_as_dense`, we + # need to first extract the relevant list entries based on `indices`, + # scatter into them according to the loop index, and re-scatter the chunks + # we updated back into the list. 
+ indices = _transpose_first_two_dims(indices) + indices_flat = array_ops.reshape(indices, [-1]) + # In many cases `indices` will be unique across pfor iterations, but this is + # not guaranteed. If there are duplicates, we need to map multiple updates + # to a single chunk extracted from the list. The last update should win. + unique_indices = array_ops.unique(indices_flat) + gathered_items = list_ops.tensor_list_gather( + handle, unique_indices.y, element_dtype=item.dtype, + element_shape=array_ops.shape(item)[1:]) + loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0]) + scatters_per_op = array_ops.shape(indices)[0] + + unique_indices_loop_idx = array_ops.reshape(array_ops.tile( + loop_idx[None, :], [scatters_per_op, 1]), [-1]) + scatter_indices = array_ops_stack.stack( + [unique_indices.idx, unique_indices_loop_idx], + axis=1) + # This op does *not* guarantee last-update-wins on GPU, so semantics may not + # be exactly preserved for duplicate updates there. + scattered = array_ops.tensor_scatter_nd_update( + tensor=gathered_items, + indices=scatter_indices, + updates=_flatten_first_two_dims(item)) + handle = list_ops.tensor_list_scatter( + scattered, unique_indices.y, input_handle=handle) + else: + handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorListFromTensor") +def _convert_tensor_list_from_tensor(pfor_input: _PforInput): + tensor = pfor_input.stacked_input(0) + element_shape = pfor_input.unstacked_input(1) + tensor = _transpose_first_two_dims(tensor) + element_shape = _stack_tensor_list_shape(element_shape, + pfor_input.pfor.loop_len_vector) + handle = list_ops.tensor_list_from_tensor(tensor, element_shape) + return wrap(_tile_variant(handle, pfor_input), True) + + +@RegisterPFor("TensorScatterUpdate") +def _convert_tensor_scatter_update(pfor_input: _PforInput): + pfor_input.stack_inputs([0, 1, 2]) + tensor = pfor_input.stacked_input(0) + 
indices = pfor_input.stacked_input(1) + updates = pfor_input.stacked_input(2) + + indices_shape = array_ops.shape(indices) + indices_rank = array_ops.rank(indices) + loop_length = indices_shape[0] + + # Create a loop count range and extend its dimensions to match `indices`. + loop_count_shape = array_ops.tensor_scatter_nd_update( + array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length]) + loop_count = array_ops.reshape(math_ops.range(loop_length), loop_count_shape) + + # Tile the loop count range for the batch dimensions (all except the first and + # last dimensions of indices). + # Rank(indices) >= 3 always for this function so we always have at least 1. + tile_multiplier = array_ops.tensor_scatter_nd_update( + indices_shape, [[0], [indices_rank - 1]], [1, 1]) + meta_index = array_ops.tile(loop_count, tile_multiplier) + + # Insert the loop-identifying index. + indices = array_ops.concat([meta_index, indices], axis=-1) + + result = array_ops.tensor_scatter_nd_update(tensor, indices, updates) + return wrap(result, True) + +# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar +# to TensorArrays, we convert them by changing the dimension of the elements +# inside the stack. +# +# We consider two cases: +# +# 1. StackV2 is constructed and used entirely inside the pfor loop. +# We keep a single Stack and perform the push/pop operations of all the +# iterations in lock-step. We also assume that all the iterations perform these +# operations. In case of dynamic control flow, if only some of the iterations +# try to perform a push/pop, then the conversion may not work correctly and may +# cause undefined behavior. +# TODO(agarwal): test StackV2 with dynamic control flow. +# +# 2. StackV2 is constructed outside the pfor loop. +# Performing stack push/pop in a parallel fashion is ill-defined. 
However given +# that reading stacks created externally is a common operation when computing +# jacobians, we provide some special semantics here as follows. +# - disallow push operations to the stack +# - pop operations are performed in lock step by all iterations, similar to the +# case when the stack is created inside. A single value is popped during the +# lock-step operation and broadcast to all the iterations. Values in the stack +# are assumed to be loop-invariant. +# +# Some other implementation details: +# We use an ugly logic to find whether values in Stack data structure are +# loop invariant or not. When converting push/pop operations, we keep track of +# whether the last conversion used a stacked value or not (see _stack_cache +# below). As a result if an unstacked value is written first, subsequent stacked +# writes are disallowed when they could have been allowed in theory. + +# Map from cache key based on StackV2 handle to a bool indicating whether values +# are stacked or not. +# TODO(agarwal): move _stack_cache inside pfor? +_stack_cache = {} + + +def _stack_cache_key(pfor_input: _PforInput): + """Create cache key corresponding to a stack handle.""" + op_type = pfor_input.op_type + assert op_type in ["StackPushV2", "StackPopV2"], op_type + orig_handle = pfor_input.op.inputs[0] + while orig_handle.op.type in ["Identity", "Enter"]: + orig_handle = orig_handle.op.inputs[0] + assert orig_handle.op.type == "StackV2", orig_handle.op + return ops.get_default_graph(), pfor_input.pfor, orig_handle + + +def _stack_handle_inside_pfor(handle, pfor_input: _PforInput): + while handle.op.type in ["Identity", "Enter"]: + handle = handle.op.inputs[0] + assert handle.op.type == "StackV2", ("Unable to find StackV2 op. 
Got %s" % + handle.op) + return pfor_input.pfor.op_is_inside_loop(handle.op) + + +@RegisterPFor("StackPushV2") +def _convert_stack_push_v2(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + elem, elem_stacked, _ = pfor_input.input(1) + swap_memory = pfor_input.get_attr("swap_memory") + + if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): + raise ValueError("StackPushV2 not allowed on stacks created outside pfor.") + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + if stacked is None: + stacked = elem_stacked + _stack_cache[stack_cache_key] = stacked + else: + # If we previously made it unstacked then we can't revert to being stacked. + if not stacked and elem_stacked: + raise ValueError( + "It looks like the stack was previously determined to be loop " + "invariant, but we are now trying to push a loop dependent value " + "to it. This is currently unsupported.") + if stacked and not elem_stacked: + elem = _stack(elem, pfor_input.pfor.loop_len_vector).t + out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) + return wrap(out, stacked) + + +# Note that inputs to this convertor will be unstacked. However it should get +# called since it is a stateful op. +@RegisterPFor("StackPopV2") +def _convert_stack_pop_v2(pfor_input: _PforInput): + handle = pfor_input.unstacked_input(0) + stack_cache_key = _stack_cache_key(pfor_input) + stacked = _stack_cache.get(stack_cache_key, None) + # If a StackPushV2 has not been converted yet, we default to unstacked since + # the push could be outside of pfor, or the convertor may not be called if the + # inputs are unconverted. 
+ if stacked is None: + stacked = False + _stack_cache[stack_cache_key] = False + elem_type = pfor_input.get_attr("elem_type") + out = data_flow_ops.stack_pop_v2(handle, elem_type) + return wrap(out, stacked) + + +# parsing_ops + + +@RegisterPFor("DecodeCSV") +def _convert_decode_csv(pfor_input: _PforInput): + lines = pfor_input.stacked_input(0) + record_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + field_delim = pfor_input.get_attr("field_delim") + use_quote_delim = pfor_input.get_attr("use_quote_delim") + select_cols = pfor_input.get_attr("select_cols") + if not select_cols: + select_cols = None + return [ + wrap(t, True) for t in gen_parsing_ops.decode_csv( + lines, + record_defaults, + field_delim=field_delim, + use_quote_delim=use_quote_delim, + select_cols=select_cols) + ] + + +@RegisterPFor("ParseSingleExample") +def _convert_parse_single_example(pfor_input: _PforInput): + serialized = pfor_input.stacked_input(0) + dense_defaults = [ + pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) + ] + sparse_keys = pfor_input.get_attr("sparse_keys") + dense_keys = pfor_input.get_attr("dense_keys") + sparse_types = pfor_input.get_attr("sparse_types") + dense_shapes = pfor_input.get_attr("dense_shapes") + output = gen_parsing_ops.parse_example( + serialized=serialized, + names=[], + dense_defaults=dense_defaults, + sparse_keys=sparse_keys, + dense_keys=dense_keys, + sparse_types=sparse_types, + dense_shapes=dense_shapes) + return [wrap(t, True, True) for t in nest.flatten(output)] + + +@RegisterPFor("ParseExampleV2") +def _convert_parse_example_v2(pfor_input: _PforInput): + serialized = pfor_input.stacked_input(0) + sparse_keys = pfor_input.unstacked_input(2) + dense_keys = pfor_input.unstacked_input(3) + ragged_keys = pfor_input.unstacked_input(4) + dense_defaults = [ + pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs) + ] + num_sparse = pfor_input.get_attr("num_sparse") + 
sparse_types = pfor_input.get_attr("sparse_types") + ragged_value_types = pfor_input.get_attr("ragged_value_types") + ragged_split_types = pfor_input.get_attr("ragged_split_types") + dense_shapes = pfor_input.get_attr("dense_shapes") + if serialized.shape.ndims not in (None, 1): + raise ValueError("ParseExampleV2 can only be converted if `serialized` " + f"is scalar. Received shape: {serialized.shape}.") + output = gen_parsing_ops.parse_example_v2( + serialized=serialized, + names=[], + sparse_keys=sparse_keys, + dense_keys=dense_keys, + ragged_keys=ragged_keys, + dense_defaults=dense_defaults, + num_sparse=num_sparse, + sparse_types=sparse_types, + ragged_value_types=ragged_value_types, + ragged_split_types=ragged_split_types, + dense_shapes=dense_shapes) + return [wrap(t, True, True) for t in nest.flatten(output)] + + +# functional_ops + + +def _convert_function_call(func, converter, inputs): + assert isinstance(func.graph, func_graph.FuncGraph), func + assert isinstance(converter, PFor) + + graph_outputs = func.graph.outputs[:len(func.function_type.flat_outputs)] + # TODO(agarwal): consider caching this function definition. + @def_function.function + def f(*args): + assert all(isinstance(arg, WrappedTensor) for arg in args), args + assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) + # Map inputs to function arguments. + for inp, arg in zip(func.graph.inputs, args): + converter._add_conversion(inp, arg) + # Convert output tensors. 
+ return tuple([converter._convert_helper(x).t for x in graph_outputs]) + + call_outputs = f(*inputs) + assert len(call_outputs) == len(graph_outputs) + outputs = [] + for call_output, output_tensor in zip(call_outputs, graph_outputs): + func_output = converter._convert_helper(output_tensor) + outputs.append( + wrap(call_output, func_output.is_stacked, func_output.is_sparse_stacked) + ) + return outputs + + +@RegisterPFor("StatefulPartitionedCall") +@RegisterPFor("PartitionedCall") +def _convert_partitioned_call(pfor_input: _PforInput): + func_name = pfor_input.get_attr("f").name + func = pfor_input.op.graph._get_function(compat.as_bytes(func_name)) + assert isinstance(func.graph, func_graph.FuncGraph), ( + "Could not find FuncGraph object for %s. Got func %s" % (func_name, func)) + pfor = pfor_input.pfor + converter = PFor( + loop_var=pfor.loop_var, + loop_len=pfor.loop_len_vector[0], + pfor_ops=func.graph.get_operations(), + fallback_to_while_loop=pfor.fallback_to_while_loop, + all_indices=pfor.all_indices, + all_indices_partitioned=pfor.all_indices_partitioned, + pfor_config=pfor.pfor_config) + return _convert_function_call(func, converter, pfor_input.inputs) + + +def _partition_inputs_for_indices(inputs, indices): + new_inputs = [] + for inp in inputs: + if inp.is_stacked: + new_inputs.append(wrap(array_ops.gather(inp.t, indices), True)) + else: + new_inputs.append(inp) + return new_inputs + + +def _outputs_for_branch(func_name, indices, pfor_input: _PforInput, inputs): + if indices is None: + indices = pfor_input.pfor.all_indices + partitioned = pfor_input.pfor.all_indices_partitioned + else: + partitioned = True + func = pfor_input.op.graph._get_function(func_name) + converter = PFor( + loop_var=pfor_input.pfor.loop_var, + loop_len=array_ops.size(indices), + pfor_ops=func.graph.get_operations(), + fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop, + all_indices=indices, + all_indices_partitioned=partitioned, + 
pfor_config=pfor_input.pfor.pfor_config) + outputs = _convert_function_call(func, converter, inputs) + stacked_outputs = [] + for out in outputs: + if not out.is_stacked: + stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) + else: + stacked_outputs.append(out.t) + return stacked_outputs + + +# TODO(agarwal): Currently the converted code aggressively tiles loop variant +# outputs from the then/else branches. Instead, it could do so only if at least +# one of the branch outputs is loop variant. +@RegisterPFor("StatelessIf") +@RegisterPFor("If") +def _convert_if(pfor_input: _PforInput): + cond, cond_stacked, _ = pfor_input.input(0) + inputs = pfor_input.inputs[1:] + then_branch = pfor_input.get_attr("then_branch") + else_branch = pfor_input.get_attr("else_branch") + + if cond_stacked: + cond_int = math_ops.cast(cond, dtypes.int32) + # Compute loop indices for the different branches + false_indices, true_indices = data_flow_ops.dynamic_partition( + pfor_input.pfor.all_indices, cond_int, 2) + # Compute indices for cond being True or False. + if pfor_input.pfor.all_indices_partitioned: + else_indices, then_indices = data_flow_ops.dynamic_partition( + math_ops.range(pfor_input.pfor.loop_len_vector[0]), + cond_int, 2) + else: + else_indices, then_indices = false_indices, true_indices + # Partition inputs + then_inputs = _partition_inputs_for_indices(inputs, then_indices) + else_inputs = _partition_inputs_for_indices(inputs, else_indices) + + # Convert "then" branch. + then_outputs = _outputs_for_branch(then_branch.name, true_indices, + pfor_input, then_inputs) + + # Convert "else" branch. + else_outputs = _outputs_for_branch(else_branch.name, false_indices, + pfor_input, else_inputs) + + assert len(then_outputs) == len(else_outputs) + # Note that if the "then" and "else" branches are updating the same state, + # and possibly reading them as well, it could lead to undefined behavior + # since the ordering of those operations is not well defined. 
+ # One possibility is to order all the "then" branches to execute before all + # the "else" branches so that the side-effects in the former are visible to + # the latter. For now, we leave that as undefined behavior. + outputs = [] + # Merge outputs + for then_output, else_output in zip(then_outputs, else_outputs): + out = data_flow_ops.dynamic_stitch([then_indices, else_indices], + [then_output, else_output]) + outputs.append(wrap(out, True)) + return outputs + else: + outputs = tf_cond.cond( + cond, + lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs), + lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs)) + return [wrap(t, True) for t in outputs] + + +@RegisterPFor("Case") +@RegisterPFor("StatelessCase") +def _convert_stateless_case(pfor_input: _PforInput): + branch_idx, is_stacked, _ = pfor_input.input(0) + branches = pfor_input.get_attr("branches") + inputs = pfor_input.inputs[1:] + + if is_stacked: + logging.info("Running stacked flow") + + # Compute loop indices for the different branches + switch_indices = data_flow_ops.dynamic_partition( + pfor_input.pfor.all_indices, branch_idx, len(branches)) + if pfor_input.pfor.all_indices_partitioned: + partitioned_indices = data_flow_ops.dynamic_partition( + math_ops.range(pfor_input.pfor.loop_len_vector[0]), branch_idx, + len(branches)) + else: + partitioned_indices = switch_indices + # Partition inputs + input_list = [] + for indices in partitioned_indices: + input_list.append(_partition_inputs_for_indices(inputs, indices)) + + outputs = [] + for (b, indices, inputs) in zip(branches, switch_indices, input_list): + out = _outputs_for_branch(b.name, indices, pfor_input, inputs) + outputs.extend(out) + + out = data_flow_ops.dynamic_stitch(partitioned_indices, outputs) + return [wrap(out, True)] + else: + new_branches = [] + for b in branches: + def new_function(func=b.name): + return _outputs_for_branch(func, None, pfor_input, + pfor_input.inputs[1:]) + + 
new_branches.append(new_function) + + outputs = [] + outputs = control_flow_switch_case.switch_case(branch_idx, new_branches) + return [wrap(t, True) for t in outputs] + + +class WhileV2: + """Object for vectorizing V2 while_loop op.""" + + def __init__(self, pfor_input: _PforInput): + self._pfor_input = pfor_input + self._pfor = pfor_input.pfor + cond_func_name = pfor_input.get_attr("cond").name + self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes( + cond_func_name)) + body_func_name = pfor_input.get_attr("body").name + self._body_func = pfor_input.op.graph._get_function(compat.as_bytes( + body_func_name)) + if self._cond_func is None or self._body_func is None: + raise ValueError("Error extracting cond and body functions for op " + f"{self._pfor_input.op}.") + # Indices of inputs that are passed unchanged through the while loop body. + # Typically these are tensors captured from outside the body context. + self._body_pass_through_indices = set() + for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs, + self._body_func.graph.outputs)): + if id(inp) == id(out): + self._body_pass_through_indices.add(i) + self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations") + + def _output_shapes(self): + # Calculate output shape for vectorized loop. This will be used as + # shape_invariant. Merges shape inference outputs with the `output_shapes` + # attribute of the op. 
+ output_shapes = [out.shape for out in self._pfor_input.op.outputs] + shapes = self._pfor_input.get_attr("output_shapes") + if not shapes: + shapes = [tensor_shape.TensorShape(None) for _ in output_shapes] + else: + shapes = [tensor_shape.TensorShape(shape) for shape in shapes] + for i, shape in enumerate(shapes): + shape = shape.merge_with(output_shapes[i]) + pfor_input = self._pfor_input.input(i) + if pfor_input.is_stacked: + if _is_variant_with_internal_stacking(pfor_input.t): + shape = tensor_shape.TensorShape([]).concatenate(shape) + else: + shape = tensor_shape.TensorShape([None]).concatenate(shape) + output_shapes[i] = shape + assert len(output_shapes) == self._pfor_input.num_inputs + return output_shapes + + def _init_values(self): + """Create arguments passed to converted while_loop.""" + loop_len = self._pfor.loop_len_vector[0] + inputs = [] + # TensorArrays for outputs of converted while loop + output_tas = [] + + with ops.name_scope("while_init"): + for inp in self._pfor_input.inputs: + inputs.append(inp.t) + variant_type_id = _variant_type_id(inp.t) + if variant_type_id in _INTERNAL_STACKING_TYPE_IDS: + if variant_type_id != full_type_pb2.TFT_ARRAY: + raise NotImplementedError( + "While loop conversion is only supported for TensorLists. Got " + f"another variant {inp.t}, probably an optional. Please file " + "a bug.") + + # For TensorLists, the input format is: + # + # List[user_list_len, Tensor[loop_len, ...]] + # + # rather than the usual + # + # Tensor[loop_len, ...] + # + # The body of the loop will take and return lists in this "internal + # vectorization" format, so we want to keep it that way as much as + # possible. 
We'll accumulate finished iterations (only relevant for + # pfor-loop-variant while_loop conditions) in an accumulator with + # type : + # + # List[user_list_len, List[loop_len, Tensor[...]]] + # + # This means that each while_loop iteration, we'll iterate over the + # length of the TensorList, dividing done/remaining pfor loop indices + # and scattering the done indices into the inner nested list of the + # accumulator. + element_shape = list_ops.tensor_list_element_shape( + inp.t, dtypes.int32) + if inp.is_stacked: + # Shapes may be tf.constant(-1) for fully dynamic, in which case + # slicing is an error. + element_shape = tf_cond.cond( + math_ops.equal(array_ops.rank(element_shape), 0), + lambda: element_shape, + lambda: element_shape[1:]) + dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype + + def _init_loop_body(index, output_ta): + output_ta = output_ta.write( + index, + list_ops.tensor_list_reserve(element_shape, loop_len, dtype)) + return index + 1, output_ta + + length = list_ops.tensor_list_length(inp.t) + output_ta = tensor_array_ops.TensorArray( + inp.t.dtype, # Variant; this is a nested TensorList + size=length, + dynamic_size=True, + infer_shape=False) + _, output_ta = while_loop.while_loop(lambda index, _: index < length, + _init_loop_body, [0, output_ta]) + else: + output_ta = tensor_array_ops.TensorArray( + inp.t.dtype, + size=loop_len, + dynamic_size=False, + infer_shape=True) + output_tas.append(output_ta) + # See documentation for __call__ for the structure of init_values. + indices = ( + math_ops.range(self._pfor.loop_len_vector[0]) + if self._pfor.all_indices_partitioned else self._pfor.all_indices) + return [True, indices] + inputs + output_tas + + def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): + """Handles case when condition is pfor loop invariant.""" + # Note that all iterations end together. So we don't need to partition the + # inputs. 
+ not_all_done = array_ops.reshape(conditions, []) + return not_all_done, indices, inputs, output_tas + + def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, + output_tas): + """Handles case when condition is pfor loop dependent.""" + # Compute if all iterations are done. + not_all_done = math_ops.reduce_any(conditions) + conditions_int = math_ops.cast(conditions, dtypes.int32) + # Partition the indices. + done_indices, new_indices = data_flow_ops.dynamic_partition( + indices, conditions_int, 2) + + new_inputs = [] + new_output_tas = [] + for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): + pass_through = i in self._body_pass_through_indices + if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: + shape_and_type = _parse_variant_shapes_and_types(inp)[0] + element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32) + user_list_len = list_ops.tensor_list_length(inp) + + def _split_vectorized_ta_element(index, new_inp, new_out_ta): + elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype, + element_shape) + if stacked: + done_elem, new_elem = data_flow_ops.dynamic_partition( + elem, conditions_int, 2) + new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem) + else: + done_elem = _stack(elem, [array_ops.size(done_indices)]).t + done_accum = new_out_ta.read(index) + done_accum = list_ops.tensor_list_scatter( + tensor=done_elem, indices=done_indices, input_handle=done_accum) + new_out_ta = new_out_ta.write(index, done_accum) + return index + 1, new_inp, new_out_ta + + length = list_ops.tensor_list_length(inp) + new_inp = list_ops.tensor_list_reserve( + tensor_shape.TensorShape([None]) + + tensor_shape.TensorShape(shape_and_type.shape)[1:], + user_list_len, shape_and_type.dtype) + _, new_inp, out_ta = while_loop.while_loop( + lambda index, unused_new_inp, unused_new_out_ta: index < length, + _split_vectorized_ta_element, [0, new_inp, output_tas[i]]) + else: + # Partition the 
inputs. + if stacked: + done_inp, new_inp = data_flow_ops.dynamic_partition( + inp, conditions_int, 2) + else: + if not pass_through: + done_inp = _stack(inp, [array_ops.size(done_indices)]).t + new_inp = inp + + out_ta = output_tas[i] + if not pass_through: + # Note that done_indices can be empty. done_inp should also be empty + # in that case. + out_ta = out_ta.scatter(done_indices, done_inp) + new_inputs.append(new_inp) + new_output_tas.append(out_ta) + + assert len(new_output_tas) == len(output_tas) + assert len(new_inputs) == len(inputs) + return not_all_done, new_indices, new_inputs, new_output_tas + + def _process_body(self, inputs_stacked, new_indices, cond_stacked, + new_inputs, not_all_done): + """Convert the body function.""" + # This is used to store the indices of inputs to the while op that need to + # be stacked. This stacking may be needed in cases where the input to the + # while_loop is loop_invariant but the corresponding output is not. + mismatching_stacked_indices = [] + + def true_fn(): + """Converts the body function for all but last iteration.""" + wrapped_inputs = [wrap(inp, stacked) for inp, stacked in + zip(new_inputs, inputs_stacked)] + # Note the iterative process below to figure out loop invariance. + # Here we iterate on vectorization process till a fixed point. The issue + # is that the while body can take pfor loop invariant inputs but return + # loop variant outputs. For any loop variant output, the corresponding + # input has to be then made loop variant (since subsequent while + # iterations will need to see loop variant values). + # However once we make a new input loop variant, we might make other + # outputs loop variant. Hence we need to iterate till we get fixed point. 
+ while True: + if self._pfor.all_indices_partitioned: + indices = array_ops.gather(self._pfor.all_indices, new_indices) + else: + indices = new_indices + body_pfor = PFor( + loop_var=self._pfor.loop_var, + loop_len=array_ops.size(new_indices), + pfor_ops=self._body_func.graph.get_operations(), + fallback_to_while_loop=self._pfor.fallback_to_while_loop, + all_indices=indices, + all_indices_partitioned=(self._pfor.all_indices_partitioned or + cond_stacked), + pfor_config=self._pfor.pfor_config) + stacking_mismatch = False + outputs = _convert_function_call(self._body_func, + body_pfor, + wrapped_inputs) + for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)): + if out.is_stacked != inp.is_stacked: + stacking_mismatch = True + mismatching_stacked_indices.append(i) + stacked = _stack(inp.t, [array_ops.size(new_indices)]) + if inp.t.dtype == dtypes.variant: + stacked = wrap( + _tile_variant_with_length(stacked.t, + [array_ops.size(new_indices)])) + wrapped_inputs[i] = stacked + if not stacking_mismatch: + if mismatching_stacked_indices: + # We needed to stack some inputs. This code will be abandoned and + # should not get executed. Hence we simply return `new_inputs` to + # make sure the graph construction code completes. + with ops.control_dependencies([ + control_flow_assert.Assert( + False, ["pfor ERROR: this branch should never execute"]) + ]): + return [array_ops.identity(x) for x in new_inputs] + else: + return [out.t for out in outputs] + + # If all are done, we simply return `new_inputs`. Else we need to run the + # body function. + return tf_cond.cond( + not_all_done, + true_fn, + lambda: list(new_inputs)), mismatching_stacked_indices + + def __call__(self): + """Converter for the V2 while_loop. + + The conversion of a while_loop is another while_loop. + + The arguments to this converted while_loop are as follows: + not_all_done: Boolean scalar Tensor indicating if all the pfor iterations + are done. 
+ indices: int32 1-D Tensor storing the id of the pfor iterations that are not + done. + args: Remaining arguments. These can be divided into 2 categories: + - The first set of arguments correspond one-to-one to the inputs to the + unvectorized while_loop. + - The second set are TensorArrays, corresponding one-to-one to each output + of the unvectorized while_loop. Each TensorArray has `PFor.loop_len` + elements, i.e. the number of pfor iterations. At the end, the i'th + element of each TensorArray will contain the output computed by the i'th + iteration of pfor. Note that elements can be written into these tensors + arrays in any order, depending on when the corresponding pfor iteration + is done. + In each iteration, the while_loop body recomputes the condition for all + active pfor iterations to see which of them are now done. It then partitions + all the inputs and passes them along to the converted body. Values for all + the iterations that are done are written to TensorArrays indexed by the pfor + iteration number. When all iterations are done, the TensorArrays are stacked + to get the final value. + + Returns: + List of converted outputs. + """ + output_shapes = self._output_shapes() + # Note that we use these lists as a hack since we need the `body` to compute + # these values during construction of the while_loop graph. + cond_is_stacked = [None] + indices_to_stack = [] + + def cond(not_all_done, *_): + return not_all_done + + def body(not_all_done, indices, *args): + # See documentation for __call__ for the structure of *args. + num_inputs = self._pfor_input.num_inputs + inputs = args[:num_inputs] + output_tas = args[num_inputs:] + inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs] + assert len(inputs) >= len(output_tas) + assert len(inputs) == len(inputs_stacked) + # Convert condition + with ops.name_scope("while_cond"): + # Note that we set all_indices_partitioned to True here. 
At this point + # we don't know if indices will be partitioned. Hence we use the + # conservative value. + cond_pfor = PFor( + loop_var=self._pfor.loop_var, + loop_len=array_ops.size(indices), + pfor_ops=self._cond_func.graph.get_operations(), + fallback_to_while_loop=self._pfor.fallback_to_while_loop, + all_indices=indices, + all_indices_partitioned=True, + pfor_config=self._pfor.pfor_config) + + wrapped_inputs = [wrap(inp, stacked) for inp, stacked + in zip(inputs, inputs_stacked)] + conditions, cond_stacked, _ = _convert_function_call( + self._cond_func, + cond_pfor, + wrapped_inputs)[0] + cond_is_stacked[0] = cond_stacked + + # Recompute the new condition, write outputs of done iterations, and + # partition the inputs if needed. + if not cond_stacked: + (not_all_done, new_indices, new_inputs, + new_output_tas) = self._process_cond_unstacked(conditions, indices, + inputs, output_tas) + else: + (not_all_done, new_indices, new_inputs, + new_output_tas) = self._process_cond_stacked(conditions, indices, + inputs, inputs_stacked, + output_tas) + # Convert body + with ops.name_scope("while_body"): + # Compute the outputs from the body. + new_outputs, mismatching_stacked_indices = self._process_body( + inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done) + + indices_to_stack[:] = mismatching_stacked_indices + for i, new_output in enumerate(new_outputs): + new_output.set_shape(output_shapes[i]) + new_args = ([not_all_done, new_indices] + new_outputs + + list(new_output_tas)) + return tuple(new_args) + + # Note that we run the code below in a function since we might abandon the + # generated code in cases where the conversion dictates that some inputs be + # further stacked. Hence we run the graph construction using + # `get_concrete_function` and avoid calling the constructed function if not + # needed. + @def_function.function + def while_fn(): + # Create init_values that will be passed to the while_loop. 
+ init_values = self._init_values() + ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in + self._pfor_input.outputs] + shape_invariants = ( + [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])] + + output_shapes + ta_shape_invariants) + + while_outputs = while_loop.while_loop( + cond, + body, + init_values, + shape_invariants=shape_invariants, + parallel_iterations=self._parallel_iterations) + if indices_to_stack: + # This function will be abandoned. + return while_outputs + else: + num_inputs = self._pfor_input.num_inputs + new_inputs = while_outputs[2:num_inputs+2] + output_tas = while_outputs[num_inputs+2:] + assert cond_is_stacked[0] is not None + outputs = [] + for i, inp in enumerate(new_inputs): + if cond_is_stacked[0]: + if i in self._body_pass_through_indices: + outputs.append(init_values[i + 2]) + else: + ta = output_tas[i] + if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: + shape_and_type = _parse_variant_shapes_and_types(inp)[0] + length = list_ops.tensor_list_length(inp) + + # We have been accumulating values in a: + # + # List[user_list_len, List[loop_len, Tensor[...]]] + # + # We want to return an output in the same format as the input: + # + # List[user_list_len, Tensor[loop_len, ...]] + # + # So we need to loop over the list and stack its contents. 
+ def _stack_loop_body(index, output_list): + current_value = ta.read(index) + output_list = list_ops.tensor_list_set_item( + output_list, index, + list_ops.tensor_list_stack( + current_value, shape_and_type.dtype)) + return index + 1, output_list + + output_list = list_ops.tensor_list_reserve( + tensor_shape.TensorShape(shape_and_type.shape), length, + shape_and_type.dtype) + _, output_list = while_loop.while_loop( + lambda index, _: index < length, _stack_loop_body, + [0, output_list]) + outputs.append(output_list) + else: + outputs.append(ta.stack()) + else: + outputs.append(inp) + return outputs + + _ = while_fn.get_concrete_function() + if indices_to_stack: + # Need to abandon the current conversion, stack some inputs and restart. + self._pfor_input.stack_inputs( + stack_indices=indices_to_stack, tile_variants=True) + # Note that this call will recurse at most one time. The first call will + # do the required stacking, based on the iterative procedure in + # _process_body, and the next invocation to __call__ should not need to do + # any more stacking. + # We invoke `self()` here as a way to discard any corrupted state. 
+ return self() + else: + outputs = while_fn() + wrapped_outputs = [] + for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)): + if i not in self._body_pass_through_indices and cond_is_stacked[0]: + wrapped_outputs.append(wrap(out, True)) + else: + wrapped_outputs.append(wrap(out, inp.is_stacked)) + return wrapped_outputs + + +@RegisterPFor("StatelessWhile") +@RegisterPFor("While") +def _convert_while(pfor_input: _PforInput): + converter = WhileV2(pfor_input) + return converter() + + +# spectral_ops + + +@RegisterPForWithArgs("FFT", gen_spectral_ops.fft) +@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d) +@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d) +@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft) +@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d) +@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d) +def _convert_fft(pfor_input: _PforInput, _, op_func): + return wrap(op_func(pfor_input.stacked_input(0)), True) + + +@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex") +@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex") +@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex") +@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal") +@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal") +@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal") +def _convert_rfft(pfor_input: _PforInput, _, op_func, attr_name): + inp = pfor_input.stacked_input(0) + fft_length = pfor_input.unstacked_input(1) + attr = pfor_input.get_attr(attr_name) + return wrap(op_func(inp, fft_length, attr), True) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/test_util.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa3f4559d4bfcce3c3a4131a5e0b4164b2b055f --- /dev/null +++ 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/parallel_for/test_util.py @@ -0,0 +1,76 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utility.""" + +import numpy as np + +from tensorflow.python.ops import variables +from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + + +class PForTestCase(test.TestCase): + """Base class for test cases.""" + + def _run_targets(self, targets1, targets2=None, run_init=True): + targets1 = nest.flatten(targets1) + targets2 = ([] if targets2 is None else nest.flatten(targets2)) + assert len(targets1) == len(targets2) or not targets2 + if run_init: + init = variables.global_variables_initializer() + self.evaluate(init) + return self.evaluate(targets1 + targets2) + + # TODO(agarwal): Allow tests to pass down tolerances. 
+ def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5): + outputs = self._run_targets(targets1, targets2) + outputs = nest.flatten(outputs) # flatten SparseTensorValues + n = len(outputs) // 2 + for i in range(n): + if outputs[i + n].dtype != np.object_: + self.assertAllClose(outputs[i + n], outputs[i], rtol=rtol, atol=atol) + else: + self.assertAllEqual(outputs[i + n], outputs[i]) + + def _test_loop_fn(self, + loop_fn, + iters, + parallel_iterations=None, + fallback_to_while_loop=False, + rtol=1e-4, + atol=1e-5): + t1 = pfor_control_flow_ops.pfor( + loop_fn, + iters=iters, + fallback_to_while_loop=fallback_to_while_loop, + parallel_iterations=parallel_iterations) + loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, t1) + t2 = pfor_control_flow_ops.for_loop(loop_fn, loop_fn_dtypes, iters=iters, + parallel_iterations=parallel_iterations) + + def _check_shape(a, b): + msg = ( + "Inferred static shapes are different between two loops:" + f" {a.shape} vs {b.shape}." + ) + # TODO(b/268146947): should assert bool(a.shape) == bool(b.shape), + # since both should be either defined or undefined. But it does not work. + if b.shape: + self.assertEqual(a.shape.as_list()[0], b.shape.as_list()[0], msg) + # TODO(b/268146947): self.assertShapeEqual(a, b, msg) does not work. + + nest.map_structure(_check_shape, t1, t2) + self.run_and_assert_equal(t1, t2, rtol=rtol, atol=atol) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c9482bc3cc02240fb0b86cb4865044c4181ab7f5 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_batch_gather_ops.py @@ -0,0 +1,60 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Batch gather operations for RaggedTensors.""" + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.ragged import ragged_gather_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import dispatch + + +#=============================================================================== +# ragged.batch_gather +#=============================================================================== +@dispatch.dispatch_for_api(array_ops.batch_gather) +def batch_gather(params: ragged_tensor.RaggedOrDense, + indices: ragged_tensor.RaggedOrDense, + name=None): + """Gathers slices from `params` according to `indices` with batch dims. + + This operation is similar to `gather`, but it assumes that the leading `N` + dimensions of `indices` and `params` are batch dimensions, and performs a + gather within each batch. In particular, when using this operation with `N` + batch dimensions `B1...BN`: + + * `indices` has shape `[B1...BN, I]` + * `params` has shape `[B1...BN, P1...PM]`. + * `result` has shape `[B1...BN, I, P2...PM]`. + * `result[b1...bN, i, p2...pM] = + params[b1...bN, indices[b1...bN, i], p2...pM]` + + Args: + params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`, + `M>0`). + indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`). 
+ name: A name for the operation (optional). + + Returns: + A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`. + `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`. + + #### Example: + + >>> params = tf.ragged.constant([['a', 'b', 'c'], ['d'], [], ['e']]) + >>> indices = tf.ragged.constant([[1, 2, 0], [], [], [0, 0]]) + >>> tf.compat.v1.batch_gather(params, indices) + + """ + return ragged_gather_ops.gather(params, indices, batch_dims=-1, name=name) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_check_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_check_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6f8b96abc618fcd21ed789750180201e19c1c329 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_check_ops.py @@ -0,0 +1,27 @@ +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Asserts and Boolean Checks for RaggedTensors.""" + +from tensorflow.python.ops import check_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import dispatch + + +@dispatch.dispatch_for_api(check_ops.assert_type) +def assert_type(tensor: ragged_tensor.Ragged, tf_type, message=None, name=None): + return check_ops.assert_type(tensor.flat_values, tf_type, + message=message, name=name) + + diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_conversion_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_conversion_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e71f5cad7929fc034eaa2bad2f6b81110c36b1ed --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_conversion_ops.py @@ -0,0 +1,180 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Ops to convert between RaggedTensors and other tensor types.""" + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import indexed_slices +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_ragged_conversion_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import ragged_tensor + + +def from_tensor(tensor, + lengths=None, + padding=None, + ragged_rank=1, + row_splits_dtype=dtypes.int64, + name=None): + if ragged_tensor.is_ragged(tensor): + return tensor + else: + return ragged_tensor.RaggedTensor.from_tensor( + tensor, + lengths=lengths, + padding=padding, + ragged_rank=ragged_rank, + row_splits_dtype=row_splits_dtype, + name=name) + + +def to_tensor(rt_input, default_value=None, name=None): + if ragged_tensor.is_ragged(rt_input): + return rt_input.to_tensor(default_value, name) + else: + return rt_input + + +def ragged_to_dense(rt_input, default_value=None, shape=None): + """Create a dense tensor from a ragged tensor.""" + return rt_input.to_tensor(default_value=default_value, shape=shape) + + +@ops.RegisterGradient("RaggedTensorToTensor") +def _ragged_tensor_to_tensor_grad(op, grad): + """Gradient for RaggedToTensor op.""" + # Extract inputs from the op. + flat_values = op.inputs[1] + default_value = op.inputs[2] + row_partition_tensors = op.inputs[3:] + row_partition_types = op.get_attr("row_partition_types") + flat_value_shape = array_ops.shape(flat_values) + ragged_rank = sum( + 1 for typ in row_partition_types if typ != b"FIRST_DIM_SIZE") + + # Create two tensors that correspond 1:1 with grad (and op.output): + # * indices[i1...iN] is the index in `flat_values` of the value used to + # populate output[i1...iN] (if the value came from `flat_values`) or + # -1 (if the value came from `default_value`). 
+ # * mask[i1...iN] is true if output[i1...iN] came from `flat_values`, or + # false if it came from `default_value`. + indices = gen_ragged_conversion_ops.ragged_tensor_to_tensor( + shape=array_ops.shape(grad)[:1 + ragged_rank], + values=math_ops.range(flat_value_shape[0]), + default_value=-1, + row_partition_types=row_partition_types, + row_partition_tensors=row_partition_tensors) + mask = math_ops.not_equal(indices, -1) + + # Select out the gradients & indices that came from `flat_values`, and use + # those to construct the gradient for `flat_values` (as an IndexedSlices). + values_grad = indexed_slices.IndexedSlices( + values=array_ops.boolean_mask(grad, mask), + indices=array_ops.boolean_mask(indices, mask), + dense_shape=flat_value_shape) + + # Select out the gradients that came from `default_value`, and sum them to + # get the gradient for the default. Note that the default_value may have + # been broadcast as part of the RaggedTensorToTensor operation, so we also + # need to reduce any dimensions that might have been broadcast. + default_grads = array_ops.boolean_mask(grad, ~mask) + dims_to_reduce = math_ops.range( + array_ops.rank(default_grads) - + _rank_ignoring_leading_dims_with_size_1(default_value)) + default_grad = math_ops.reduce_sum(default_grads, axis=dims_to_reduce) + + # Restore any leading dims with size one. + default_grad = array_ops.reshape(default_grad, array_ops.shape(default_value)) + + return ([None, values_grad, default_grad] + + [None for _ in row_partition_tensors]) + + +def _rank_ignoring_leading_dims_with_size_1(value): + """Returns `rank(value)`, ignoring any leading dimensions with size 1.""" + # Compute the result using static shape, if possible. + if value.shape.rank is not None: + ndims = value.shape.rank + for dim in value.shape.dims: + if dim.value == 1: + ndims -= 1 + elif dim.value is None: + ndims = None # Can't compute the result using static shape. 
+ break + else: + break + if ndims is not None: + return ndims + + # Otherwise, we need to compute the result dynamically. The math we use to + # do this is a bit round-about, so here's an example to illustrate: + # shape = [1, 1, 3, 5, 1, 4] # shape(value) + # dim_is_one = [1, 1, 0, 0, 1, 0] # equal(shape, 1) + # leading_ones = [1, 1, 0, 0, 0, 0] # cumprod(dim_is_one) + # num_leading_ones = 2 # reduce_sum(leading_ones) + # result = 4 # rank(value) - num_leading_ones + shape = array_ops.shape(value) + dim_is_one = math_ops.cast(math_ops.equal(shape, 1), dtypes.int32) + leading_ones = math_ops.cumprod(dim_is_one) + num_leading_ones = math_ops.reduce_sum(leading_ones) + return array_ops.rank(value) - num_leading_ones + + +def to_sparse(rt_input, name=None): + return rt_input.to_sparse(name) + + +def from_sparse(st_input, name=None): + return ragged_tensor.RaggedTensor.from_sparse(st_input, name) + + +@ops.RegisterGradient("RaggedTensorFromVariant") +def _ragged_tensor_from_variant_grad(op, *grads): + """Gradient for RaggedTensorFromVariant op.""" + + variant_rank = op.inputs[0].shape.rank + if variant_rank == 0: + batched_input = False + elif variant_rank == 1: + batched_input = True + elif variant_rank is None: + batched_input = (op.get_attr("output_ragged_rank") > 0) + else: + # TODO(edloper): Add a batch_dims argument to RaggedTensorToVariant, so + # we can support this. 
+ raise ValueError("Unable to compute gradient: RaggedTensorToVariant " + "can currently only generate 0D or 1D output.") + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + rt_nested_splits=op.outputs[:-1], + rt_dense_values=grads[-1], + batched_input=batched_input) + ] + + +@ops.RegisterGradient("RaggedTensorToVariant") +def _ragged_tensor_to_variant_grad(op, encoded_ragged_grad): + """Gradient for RaggedTensorToVariant op.""" + dense_values = op.inputs[-1] + ragged_rank = len(op.inputs) - 1 + row_splits = 0 if ragged_rank == 0 else op.inputs[0] + values_grad = gen_ragged_conversion_ops.ragged_tensor_to_variant_gradient( + encoded_ragged_grad=encoded_ragged_grad, + row_splits=row_splits, + dense_values_shape=array_ops.shape(dense_values), + Tvalues=op.inputs[-1].dtype) + result = [None] * ragged_rank + [values_grad] + return result diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_getitem.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_getitem.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ceb9240be243fcdbe471c431337a712ea34a0c --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_getitem.py @@ -0,0 +1,477 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Python-style indexing and slicing for RaggedTensors.""" + +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import cond +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import ragged_gather_ops +from tensorflow.python.ops.ragged import ragged_math_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import dispatch +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("__operators__.ragged_getitem", v1=[]) +@dispatch.add_dispatch_support +def ragged_tensor_getitem(rt_input, key): + """Returns the specified piece of this RaggedTensor. + + Supports multidimensional indexing and slicing, with one restriction: + indexing into a ragged inner dimension is not allowed. This case is + problematic because the indicated value may exist in some rows but not + others. In such cases, it's not obvious whether we should (1) report an + IndexError; (2) use a default value; or (3) skip that value and return a + tensor with fewer rows than we started with. Following the guiding + principles of Python ("In the face of ambiguity, refuse the temptation to + guess"), we simply disallow this operation. + + Args: + rt_input: The RaggedTensor to slice. + key: Indicates which piece of the RaggedTensor to return, using standard + Python semantics (e.g., negative values index from the end). 
`key` + may have any of the following types: + + * `int` constant + * Scalar integer `Tensor` + * `slice` containing integer constants and/or scalar integer + `Tensor`s + * `Ellipsis` + * `tf.newaxis` + * `tuple` containing any of the above (for multidimensional indexing) + + Returns: + A `Tensor` or `RaggedTensor` object. Values that include at least one + ragged dimension are returned as `RaggedTensor`. Values that include no + ragged dimensions are returned as `Tensor`. See above for examples of + expressions that return `Tensor`s vs `RaggedTensor`s. + + Raises: + ValueError: If `key` is out of bounds. + ValueError: If `key` is not supported. + TypeError: If the indices in `key` have an unsupported type. + + Examples: + + >>> # A 2-D ragged tensor with 1 ragged dimension. + >>> rt = tf.ragged.constant([['a', 'b', 'c'], ['d', 'e'], ['f'], ['g']]) + >>> rt[0].numpy() # First row (1-D `Tensor`) + array([b'a', b'b', b'c'], dtype=object) + >>> rt[:3].to_list() # First three rows (2-D RaggedTensor) + [[b'a', b'b', b'c'], [b'd', b'e'], [b'f']] + >>> rt[3, 0].numpy() # 1st element of 4th row (scalar) + b'g' + + >>> # A 3-D ragged tensor with 2 ragged dimensions. + >>> rt = tf.ragged.constant([[[1, 2, 3], [4]], + ... [[5], [], [6]], + ... [[7]], + ... 
[[8, 9], [10]]]) + >>> rt[1].to_list() # Second row (2-D RaggedTensor) + [[5], [], [6]] + >>> rt[3, 0].numpy() # First element of fourth row (1-D Tensor) + array([8, 9], dtype=int32) + >>> rt[:, 1:3].to_list() # Items 1-3 of each row (3-D RaggedTensor) + [[[4]], [[], [6]], [], [[10]]] + >>> rt[:, -1:].to_list() # Last item of each row (3-D RaggedTensor) + [[[4]], [[6]], [[7]], [[10]]] + """ + if not isinstance(rt_input, ragged_tensor.RaggedTensor): + raise TypeError("Ragged __getitem__ expects a ragged_tensor.") + scope_tensors = [rt_input] + list(_tensors_in_key_list(key)) + if isinstance(key, (list, tuple)): + key = list(key) + else: + key = [key] + with ops.name_scope(None, "RaggedGetItem", scope_tensors): + return _ragged_getitem(rt_input, key) + + +def _ragged_getitem(rt_input, key_list): + """Helper for indexing and slicing ragged tensors with __getitem__(). + + Extracts the specified piece of the `rt_input`. See + `RaggedTensor.__getitem__` for examples and restrictions. + + Args: + rt_input: The `RaggedTensor` from which a piece should be returned. + key_list: The list of keys specifying which piece to return. Each key + corresponds with a separate dimension. + + Returns: + The indicated piece of rt_input. + + Raises: + ValueError: If `key_list` is not supported. + TypeError: If any keys in `key_list` have an unsupported type. + """ + if not key_list: + return rt_input + row_key = key_list[0] + inner_keys = key_list[1:] + + if row_key is Ellipsis: + expanded_key_list = _expand_ellipsis(key_list, rt_input.shape.ndims) + return _ragged_getitem(rt_input, expanded_key_list) + + # Adding a new axis: Get rt_input[inner_keys], and wrap it in a RaggedTensor + # that puts all values in a single row. 
+ if row_key is array_ops.newaxis: + inner_rt = _ragged_getitem(rt_input, inner_keys) + nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0) + if nsplits.value is not None: + nsplits = nsplits.value + else: + nsplits = array_ops.shape(inner_rt.row_splits, + out_type=inner_rt.row_splits.dtype)[0] + return ragged_tensor.RaggedTensor.from_uniform_row_length( + inner_rt, nsplits - 1, nrows=1, validate=False) + + # Slicing a range of rows: first slice the outer dimension, and then + # call `_ragged_getitem_inner_dimensions` to handle the inner keys. + if isinstance(row_key, slice): + sliced_rt_input = _slice_ragged_row_dimension(rt_input, row_key) + if rt_input.uniform_row_length is not None: + # If the inner dimension has uniform_row_length, then preserve it (by + # re-wrapping the values in a new RaggedTensor). Note that the row + # length won't have changed, since we're slicing a range of rows (and not + # slicing the rows themselves). + sliced_rt_input = ragged_tensor.RaggedTensor.from_uniform_row_length( + sliced_rt_input.values, rt_input.uniform_row_length, + nrows=sliced_rt_input.nrows()) + return _ragged_getitem_inner_dimensions(sliced_rt_input, inner_keys) + + # Indexing a single row: slice values to get the indicated row, and then + # use a recursive call to __getitem__ to handle the inner keys. + else: + starts = rt_input.row_splits[:-1] + limits = rt_input.row_splits[1:] + if context.executing_eagerly(): + # In python, __getitem__ should throw IndexError for out of bound + # indices. This will allow iteration run correctly as python will + # translate IndexError into StopIteration for next()/__next__(). + # Below is an example: + # import tensorflow as tf + # r = tf.ragged.constant([[1., 2.], [3., 4., 5.], [6.]]) + # for elem in r: + # print(elem) + # In non eager mode, the exception is thrown when session runs + # so we don't know if out of bound happens before. 
+ # In eager mode, however, it is possible to find out when to + # throw out of bound IndexError. + # In the following row_key >= len(starts) is checked. In case of + # TypeError which happens when row_key is not an integer, the exception + # will simply be ignored as it will be processed later anyway. + try: + if int(row_key) >= len(starts): + raise IndexError("Row key {} out of bounds".format(row_key)) + except (TypeError, ValueError): + pass + row = rt_input.values[starts[row_key]:limits[row_key]] + return row.__getitem__(inner_keys) + + +def _slice_ragged_row_dimension(rt_input, row_key): + """Slice the outer dimension of `rt_input` according to the given `slice`. + + Args: + rt_input: The `RaggedTensor` to slice. + row_key: The `slice` object that should be used to slice `rt_input`. + + Returns: + A `RaggedTensor` containing the indicated slice of `rt_input`. + """ + if row_key.start is None and row_key.stop is None and row_key.step is None: + return rt_input + + # Use row_key to slice the starts & limits. + new_starts = rt_input.row_splits[:-1][row_key] + new_limits = rt_input.row_splits[1:][row_key] + zero_pad = array_ops.zeros([1], rt_input.row_splits.dtype) + + # If there's no slice step, then we can just select a single continuous + # span of `ragged.values(rt_input)`. + if row_key.step is None or row_key.step == 1: + # Construct the new splits. If new_starts and new_limits are empty, + # then this reduces to [0]. Otherwise, this reduces to: + # concat([[new_starts[0]], new_limits]) + new_splits = array_ops.concat( + [zero_pad[array_ops.size(new_starts):], new_starts[:1], new_limits], + axis=0) + values_start = new_splits[0] + values_limit = new_splits[-1] + return ragged_tensor.RaggedTensor.from_row_splits( + rt_input.values[values_start:values_limit], new_splits - values_start, + validate=False) + + # If there is a slice step (aka a strided slice), then use ragged_gather to + # collect the necessary elements of `ragged.values(rt_input)`. 
+ else: + return _build_ragged_tensor_from_value_ranges(new_starts, new_limits, 1, + rt_input.values) + + +def _ragged_getitem_inner_dimensions(rt_input, key_list): + """Retrieve inner dimensions, keeping outermost dimension unchanged. + + Args: + rt_input: The `RaggedTensor` or `Tensor` from which a piece should be + extracted. + key_list: The __getitem__ keys for slicing the inner dimensions. + + Returns: + A `RaggedTensor`. + + Raises: + ValueError: If key_list is not supported. + """ + if not key_list: + return rt_input + + if not isinstance(rt_input, ragged_tensor.RaggedTensor): + return rt_input.__getitem__([slice(None, None, None)] + key_list) + + column_key = key_list[0] + if column_key is Ellipsis: + expanded_key_list = _expand_ellipsis(key_list, rt_input.values.shape.ndims) + return _ragged_getitem_inner_dimensions(rt_input, expanded_key_list) + + # Adding a new axis to a ragged inner dimension: recursively get the inner + # dimensions of rt_input with key_list[1:], and then wrap the result in a + # RaggedTensor that puts each value in its own row. + if column_key is array_ops.newaxis: + inner_rt = _ragged_getitem_inner_dimensions(rt_input, key_list[1:]) + nsplits = tensor_shape.dimension_at_index(inner_rt.row_splits.shape, 0) + if nsplits.value is not None: + nsplits = nsplits.value + else: + nsplits = array_ops.shape( + inner_rt.row_splits, out_type=inner_rt.row_splits.dtype + )[0] + return ragged_tensor.RaggedTensor.from_uniform_row_length( + inner_rt, 1, nrows=nsplits - 1, validate=False) + + # Slicing a range of columns in a ragged inner dimension. We use a + # recursive call to process the values, and then assemble a RaggedTensor + # with those values. + if isinstance(column_key, slice): + if (column_key.start is None and column_key.stop is None and + column_key.step is None): + # Trivial slice: recursively process all values, & splits is unchanged. 
+ return rt_input.with_values( + _ragged_getitem_inner_dimensions(rt_input.values, key_list[1:])) + else: + if not ( + isinstance(column_key.start, (tensor_lib.Tensor, int, type(None))) + and isinstance(column_key.stop, (tensor_lib.Tensor, int, type(None))) + ): + raise TypeError("slice offsets must be integers or None") + + # Nontrivial slice: use ragged_gather to extract the indicated slice as + # a new RaggedTensor (inner_rt), and then recursively process its values. + starts = rt_input.row_splits[:-1] + limits = rt_input.row_splits[1:] + step = 1 if column_key.step is None else column_key.step + lower_bound = _if_ge_zero(step, lambda: starts, lambda: starts - 1) + upper_bound = _if_ge_zero(step, lambda: limits, lambda: limits - 1) + # inner_rt_starts[i] = index to start gathering for row i. + if column_key.start is None: + inner_rt_starts = _if_ge_zero(step, lambda: starts, lambda: limits - 1) + else: + start_offset = math_ops.cast(column_key.start, starts.dtype) + inner_rt_starts = _if_ge_zero( + column_key.start, + lambda: math_ops.minimum(starts + start_offset, upper_bound), + lambda: math_ops.maximum(limits + start_offset, lower_bound)) + # inner_rt_limits[i] = index to stop gathering for row i. + if column_key.stop is None: + inner_rt_limits = _if_ge_zero(step, lambda: limits, lambda: starts - 1) + else: + stop_offset = math_ops.cast(column_key.stop, starts.dtype) + inner_rt_limits = _if_ge_zero( + column_key.stop, + lambda: math_ops.minimum(starts + stop_offset, upper_bound), + lambda: math_ops.maximum(limits + stop_offset, lower_bound)) + inner_rt = _build_ragged_tensor_from_value_ranges( + inner_rt_starts, inner_rt_limits, column_key.step, rt_input.values) + # If the row dimension is uniform, then calculate the new + # uniform_row_length, and rebuild inner_rt using that uniform_row_lengths. 
+ if rt_input.uniform_row_length is not None: + new_row_length = _slice_length(rt_input.uniform_row_length, column_key) + inner_rt = ragged_tensor.RaggedTensor.from_uniform_row_length( + inner_rt.values, new_row_length, rt_input.nrows()) + return inner_rt.with_values( + _ragged_getitem_inner_dimensions(inner_rt.values, key_list[1:])) + + # Indexing a single column in a ragged inner dimension: raise an Exception. + # See RaggedTensor.__getitem__.__doc__ for an explanation of why indexing + # into a ragged inner dimension is problematic. + if rt_input.uniform_row_length is None: + raise ValueError("Cannot index into an inner ragged dimension.") + + # Indexing a single column in a uniform inner dimension: check that the + # given index is in-bounds, and then use a strided slice over rt_input.values + # to take the indicated element from each row. + row_length = rt_input.uniform_row_length + column_key = math_ops.cast(column_key, row_length.dtype) + oob_err_msg = "Index out of bounds when indexing into a ragged tensor" + oob_checks = [ + check_ops.assert_greater_equal( + column_key, -row_length, message=oob_err_msg), + check_ops.assert_less(column_key, row_length, message=oob_err_msg), + ] + with ops.control_dependencies(oob_checks): + offset = _if_ge_zero(column_key, lambda: column_key, + lambda: row_length + column_key) + sliced_rt = rt_input.values[offset::row_length] + return _ragged_getitem_inner_dimensions(sliced_rt, key_list[1:]) + + +def _slice_length(value_length, slice_key): + """Computes the number of elements in a slice of a value with a given length. + + Returns the equivalent of: `len(range(value_length)[slice_key])` + + Args: + value_length: Scalar int `Tensor`: the length of the value being sliced. + slice_key: A `slice` object used to slice elements from the value. + + Returns: + The number of elements in the sliced value. 
+ """ + # Note: we could compute the slice length without creating a zeros tensor + # with some variant of (stop-start)//step, but doing so would require more + # ops (for checking bounds, handling negative indices, negative step sizes, + # etc); and we expect this to be an uncommon operation, so we use this + # simpler implementation. + zeros = array_ops.zeros(value_length, dtype=dtypes.bool) + return array_ops.size(zeros[slice_key], out_type=value_length.dtype) + + +def _expand_ellipsis(key_list, num_remaining_dims): + """Expands the ellipsis at the start of `key_list`. + + Assumes that the first element of `key_list` is Ellipsis. This will either + remove the Ellipsis (if it corresponds to zero indices) or prepend a new + `slice(None, None, None)` (if it corresponds to more than zero indices). + + Args: + key_list: The arguments to `__getitem__()`. + num_remaining_dims: The number of dimensions remaining. + + Returns: + A copy of `key_list` with he ellipsis expanded. + Raises: + ValueError: If ragged_rank.shape.ndims is None + IndexError: If there are too many elements in `key_list`. 
+ """ + if num_remaining_dims is None: + raise ValueError("Ellipsis not supported for unknown shape RaggedTensors") + num_indices = sum(1 for idx in key_list if idx is not array_ops.newaxis) + if num_indices > num_remaining_dims + 1: + raise IndexError("Too many indices for RaggedTensor") + elif num_indices == num_remaining_dims + 1: + return key_list[1:] + else: + return [slice(None, None, None)] + key_list + + +def _tensors_in_key_list(key_list): + """Generates all Tensors in the given slice spec.""" + if isinstance(key_list, tensor_lib.Tensor): + yield key_list + if isinstance(key_list, (list, tuple)): + for v in key_list: + for tensor in _tensors_in_key_list(v): + yield tensor + if isinstance(key_list, slice): + for tensor in _tensors_in_key_list(key_list.start): + yield tensor + for tensor in _tensors_in_key_list(key_list.stop): + yield tensor + for tensor in _tensors_in_key_list(key_list.step): + yield tensor + + +def _build_ragged_tensor_from_value_ranges(starts, limits, step, values): + """Returns a `RaggedTensor` containing the specified sequences of values. + + Returns a RaggedTensor `output` where: + + ```python + output.shape[0] = starts.shape[0] + output[i] = values[starts[i]:limits[i]:step] + ``` + + Requires that `starts.shape == limits.shape` and + `0 <= starts[i] <= limits[i] <= values.shape[0]`. + + Args: + starts: 1D integer Tensor specifying the start indices for the sequences of + values to include. + limits: 1D integer Tensor specifying the limit indices for the sequences of + values to include. + step: Integer value specifying the step size for strided slices. + values: The set of values to select from. + + Returns: + A `RaggedTensor`. + + Raises: + ValueError: Until the prerequisite ops are checked in. + """ + # Use `ragged_range` to get the index of each value we should include. 
+ if step is None: + step = 1 + step = ops.convert_to_tensor(step, name="step") + if step.dtype.is_integer: + step = math_ops.cast(step, starts.dtype) + else: + raise TypeError("slice strides must be integers or None") + value_indices = ragged_math_ops.range(starts, limits, step, + row_splits_dtype=starts.dtype) + + # Use `ragged_gather` or `array_ops.gather` to collect the values. + if isinstance(values, ragged_tensor.RaggedTensor): + gathered_values = ragged_gather_ops.gather( + params=values, indices=value_indices.values) + else: + gathered_values = array_ops.gather( + params=values, indices=value_indices.values) + + # Assemble the RaggedTensor from splits & values. + return value_indices.with_values(gathered_values) + + +def _if_ge_zero(value, true_fn, false_fn): + """Returns `true_fn() if value >= 0 else false_fn()`.""" + # If `value` is statically known, then don't use a control flow op. + if isinstance(value, tensor_lib.Tensor): + const_value = tensor_util.constant_value(value) + if const_value is None: + return cond.cond(value >= 0, true_fn, false_fn) + else: + value = const_value + if value >= 0: + return true_fn() + else: + return false_fn() diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_map_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_map_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..06d9f0624d08306ecb1b3b4e838116f6fd86b0ed --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_map_ops.py @@ -0,0 +1,174 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional operations for RaggedTensors.""" + +from tensorflow.python.ops import map_fn as map_fn_lib +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import nest + + +def map_fn(fn, + elems, + dtype=None, + parallel_iterations=None, + back_prop=True, + swap_memory=False, + infer_shape=True, + name=None): + """map on the list of tensors unpacked from `elems` on dimension 0. + + The simplest version of `map_fn` repeatedly applies the callable `fn` to a + sequence of elements from first to last. The elements are made of the + tensors unpacked from `elems`. `dtype` is the data type of the return + value of `fn`. Users must provide `dtype` if it is different from + the data type of `elems`. + + Suppose that `elems` is unpacked into `values`, a list of tensors. The shape + of the result tensor is `[values.shape[0]] + fn(values[0]).shape`. + + This method also allows multi-arity `elems` and output of `fn`. If `elems` + is a (possibly nested) list or tuple of tensors, then each of these tensors + must have a matching first (unpack) dimension. The signature of `fn` may + match the structure of `elems`. That is, if `elems` is + `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is: + `fn = lambda (t1, [t2, t3, [t4, t5]]):`. + + Furthermore, `fn` may emit a different structure than its input. For example, + `fn` may look like: `fn = lambda t1: return (t1 + 1, t1 - 1)`. 
In this case, + the `dtype` parameter is not optional: `dtype` must be a type or (possibly + nested) tuple of types matching the output of `fn`. + + To apply a functional operation to the nonzero elements of a SparseTensor + one of the following methods is recommended. First, if the function is + expressible as TensorFlow ops, use + + ```python + result = SparseTensor(input.indices, fn(input.values), input.dense_shape) + ``` + + If, however, the function is not expressible as a TensorFlow op, then use + + ```python + result = SparseTensor( + input.indices, map_fn(fn, input.values), input.dense_shape) + ``` + + instead. + + When executing eagerly, map_fn does not execute in parallel even if + `parallel_iterations` is set to a value > 1. You can still get the + performance benefits of running a function in parallel by using the + `tf.contrib.eager.defun` decorator, + + ```python + # Assume the function being used in map_fn is fn. + # To ensure map_fn calls fn in parallel, use the defun decorator. + @tf.contrib.eager.defun + def func(tensor): + return tf.map_fn(fn, tensor) + ``` + + Note that if you use the defun decorator, any non-TensorFlow Python code + that you may have written in your function won't get executed. See + `tf.contrib.eager.defun` for more details. The recommendation would be to + debug without defun but switch to defun to get performance benefits of + running map_fn in parallel. + + Args: + fn: The callable to be performed. It accepts one argument, which will have + the same (possibly nested) structure as `elems`. Its output must have the + same structure as `dtype` if one is provided, otherwise it must have the + same structure as `elems`. + elems: A tensor or (possibly nested) sequence of tensors, each of which will + be unpacked along their first dimension. The nested sequence of the + resulting slices will be applied to `fn`. + dtype: (optional) The output type(s) of `fn`. 
If `fn` returns a structure + of Tensors differing from the structure of `elems`, then `dtype` is not + optional and must have the same structure as the output of `fn`. Use + `RaggedTensorType` to declare an output of type `RaggedTensor`. + parallel_iterations: (optional) The number of iterations allowed to run in + parallel. When graph building, the default value is 10. While executing + eagerly, the default value is set to 1. + back_prop: (optional) True enables support for back propagation. + swap_memory: (optional) True enables GPU-CPU memory swapping. + infer_shape: (optional) False disables tests for consistent output shapes. + name: (optional) Name prefix for the returned tensors. + + Returns: + A possibly nested sequence of potentially ragged tensors. Each + tensor packs the results of applying `fn` to tensors unpacked from `elems` + along the first dimension, from first to last. + + Raises: + TypeError: if `fn` is not callable or the structure of the output of + `fn` and `dtype` do not match, or if elems is a SparseTensor. + ValueError: if the lengths of the output of `fn` and `dtype` do not match. 
+ + #### Examples: + + ```python + elems = np.array([1, 2, 3, 4, 5, 6]) + squares = map_fn(lambda x: x * x, elems) + # squares == [1, 4, 9, 16, 25, 36] + ``` + + ```python + elems = (np.array([1, 2, 3]), np.array([-1, 1, -1])) + alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64) + # alternate == [-1, 2, -3] + ``` + + ```python + elems = np.array([1, 2, 3]) + alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64)) + # alternates[0] == [1, 2, 3] + # alternates[1] == [-1, -2, -3] + ``` + + ```python + elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]]) + mean = map_fn(tf.reduce_mean, elems) + # mean == [2, 4, 6] + ``` + + ```python + elems=ragged.constant([[1, 2, 3], [4, 5], [6, 7]], dtype=tf.int64) + out = map_fn(fn=lambda x: x+1, elems, + dtype=ragged.RaggedTensorType(type=tf.int64, ragged_rank=0)) + # out = tf.ragged.constant([[2, 3, 4], [5, 6], [7, 8]]) + ``` + """ + if dtype is None: + dtype = nest.map_structure(lambda e: e.dtype, elems) + dtype = nest.map_structure(_ragged_type_to_spec, dtype) + return map_fn_lib.map_fn(fn, + elems, + dtype, + parallel_iterations, + back_prop, + swap_memory, + infer_shape, + name) + + +def _ragged_type_to_spec(t): + if isinstance(t, ragged_tensor.RaggedTensorType): + # Note: need to adjust ragged_rank by 1, since RaggedTensorSpec gives the + # type for the mapped `fn` output, but RaggedTensorType gives the type for + # the result of stacking the mapped `fn` outputs. 
+ return ragged_tensor.RaggedTensorSpec( + None, t.dtype, t.ragged_rank - 1, t.row_splits_dtype) + else: + return t diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_squeeze_op.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_squeeze_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6d35ffd493ecfb9823eaca70c349cc2b22a947b5 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_squeeze_op.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Operator Squeeze for RaggedTensors.""" + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_assert +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor +from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch + + +@dispatch.dispatch_for_api(array_ops.squeeze_v2) +def squeeze(input: ragged_tensor.Ragged, axis=None, name=None): # pylint: disable=redefined-builtin + """Ragged compatible squeeze. + + If `input` is a `tf.Tensor`, then this calls `tf.squeeze`. + + If `input` is a `tf.RaggedTensor`, then this operation takes `O(N)` time, + where `N` is the number of elements in the squeezed dimensions. + + Args: + input: A potentially ragged tensor. The input to squeeze. + axis: An optional list of ints. Defaults to `None`. If the `input` is + ragged, it only squeezes the dimensions listed. It fails if `input` is + ragged and axis is []. If `input` is not ragged it calls tf.squeeze. Note + that it is an error to squeeze a dimension that is not 1. It must be in + the range of [-rank(input), rank(input)). + name: A name for the operation (optional). + + Returns: + A potentially ragged tensor. Contains the same data as input, + but has one or more dimensions of size 1 removed. 
+ """ + with ops.name_scope(name, 'RaggedSqueeze', [input]): + input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input) + if isinstance(input, tensor.Tensor): + return array_ops.squeeze(input, axis, name) + + if axis is None: + raise ValueError('Ragged.squeeze must have an axis argument.') + if isinstance(axis, int): + axis = [axis] + elif ((not isinstance(axis, (list, tuple))) or + (not all(isinstance(d, int) for d in axis))): + raise TypeError('Axis must be a list or tuple of integers.') + + dense_dims = [] + ragged_dims = [] + # Normalize all the dims in axis to be positive + axis = [ + array_ops.get_positive_axis(d, input.shape.ndims, 'axis[%d]' % i, + 'rank(input)') for i, d in enumerate(axis) + ] + for dim in axis: + if dim > input.ragged_rank: + dense_dims.append(dim - input.ragged_rank) + else: + ragged_dims.append(dim) + + # Make sure the specified ragged dimensions are squeezable. + assertion_list = [] + scalar_tensor_one = constant_op.constant(1, dtype=input.row_splits.dtype) + for i, r in enumerate(input.nested_row_lengths()): + if i + 1 in ragged_dims: + assertion_list.append( + control_flow_assert.Assert( + math_ops.reduce_all(math_ops.equal(r, scalar_tensor_one)), + ['the given axis (axis = %d) is not squeezable!' % (i + 1)])) + if 0 in ragged_dims: + scalar_tensor_two = constant_op.constant(2, dtype=dtypes.int32) + assertion_list.append( + control_flow_assert.Assert( + math_ops.equal( + array_ops.size(input.row_splits), scalar_tensor_two), + ['the given axis (axis = 0) is not squeezable!'])) + + # Till now, we are sure that the ragged dimensions are squeezable. + squeezed_rt = None + squeezed_rt = control_flow_ops.with_dependencies(assertion_list, + input.flat_values) + + if dense_dims: + # Gives error if the dense dimension is not squeezable. 
+ squeezed_rt = array_ops.squeeze(squeezed_rt, dense_dims) + + remaining_row_splits = [] + remaining_row_splits = list() + for i, row_split in enumerate(input.nested_row_splits): + # each row_splits tensor is for dimension #(i+1) . + if (i + 1) not in ragged_dims: + remaining_row_splits.append(row_split) + # Take care of the first row if it is to be squeezed. + if remaining_row_splits and 0 in ragged_dims: + remaining_row_splits.pop(0) + + squeezed_rt = RaggedTensor.from_nested_row_splits(squeezed_rt, + remaining_row_splits) + + # Corner case: when removing all the ragged dimensions and the output is + # a scalar tensor e.g. ragged.squeeze(ragged.constant([[[1]]])). + if set(range(0, input.ragged_rank + 1)).issubset(set(ragged_dims)): + squeezed_rt = array_ops.squeeze(squeezed_rt, [0], name) + + return squeezed_rt + + +@dispatch.dispatch_for_api(array_ops.squeeze) +def _ragged_squeeze_v1(input: ragged_tensor.Ragged, # pylint: disable=redefined-builtin + axis=None, + name=None, + squeeze_dims=None): + axis = deprecation.deprecated_argument_lookup('axis', axis, 'squeeze_dims', + squeeze_dims) + return squeeze(input, axis, name) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_string_ops.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_string_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..41f79781c7b4a8895d0ec050e7361edfb791371b --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_string_ops.py @@ -0,0 +1,948 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ragged operations for working with string Tensors.""" + +import typing + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import cond +from tensorflow.python.ops import gen_string_ops +from tensorflow.python.ops import map_fn as map_fn_lib +from tensorflow.python.ops import string_ops +from tensorflow.python.ops.ragged import ragged_array_ops +from tensorflow.python.ops.ragged import ragged_functional_ops +from tensorflow.python.ops.ragged import ragged_math_ops +from tensorflow.python.ops.ragged import ragged_tensor +from tensorflow.python.util import compat as util_compat +from tensorflow.python.util import deprecation +from tensorflow.python.util import dispatch +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("strings.bytes_split") +@dispatch.add_dispatch_support +def string_bytes_split(input, name=None): # pylint: disable=redefined-builtin + """Split string elements of `input` into bytes. 
+ + Examples: + + >>> tf.strings.bytes_split('hello').numpy() + array([b'h', b'e', b'l', b'l', b'o'], dtype=object) + >>> tf.strings.bytes_split(['hello', '123']) + + + Note that this op splits strings into bytes, not unicode characters. To + split strings into unicode characters, use `tf.strings.unicode_split`. + + See also: `tf.io.decode_raw`, `tf.strings.split`, `tf.strings.unicode_split`. + + Args: + input: A string `Tensor` or `RaggedTensor`: the strings to split. Must + have a statically known rank (`N`). + name: A name for the operation (optional). + + Returns: + A `RaggedTensor` of rank `N+1`: the bytes that make up the source strings. + """ + with ops.name_scope(name, "StringsByteSplit", [input]): + input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, + name="input") + if isinstance(input, ragged_tensor.RaggedTensor): + return input.with_flat_values(string_bytes_split(input.flat_values)) + + rank = input.shape.ndims + if rank is None: + raise ValueError("input must have a statically-known rank.") + + if rank == 0: + return string_bytes_split(array_ops_stack.stack([input]))[0] + elif rank == 1: + indices, values, shape = gen_string_ops.string_split( + input, delimiter="", skip_empty=False) + return ragged_tensor.RaggedTensor.from_value_rowids( + values=values, value_rowids=indices[:, 0], nrows=shape[0], + validate=False) + else: + return string_bytes_split(ragged_tensor.RaggedTensor.from_tensor(input)) + + +# pylint: disable=redefined-builtin +@tf_export("strings.unicode_encode") +@dispatch.add_dispatch_support +def unicode_encode(input, + output_encoding, + errors="replace", + replacement_char=65533, + name=None): + r"""Encodes each sequence of Unicode code points in `input` into a string. + + `result[i1...iN]` is the string formed by concatenating the Unicode + codepoints `input[1...iN, :]`, encoded using `output_encoding`. + + Args: + input: An `N+1` dimensional potentially ragged integer tensor with shape + `[D1...DN, num_chars]`. 
+ output_encoding: Unicode encoding that should be used to encode each + codepoint sequence. Can be `"UTF-8"`, `"UTF-16-BE"`, or `"UTF-32-BE"`. + errors: Specifies the response when an invalid codepoint is encountered + (optional). One of: + * `'replace'`: Replace invalid codepoint with the + `replacement_char`. (default) + * `'ignore'`: Skip invalid codepoints. + * `'strict'`: Raise an exception for any invalid codepoint. + replacement_char: The replacement character codepoint to be used in place of + any invalid input when `errors='replace'`. Any valid unicode codepoint may + be used. The default value is the default unicode replacement character + which is 0xFFFD (U+65533). + name: A name for the operation (optional). + + Returns: + A `N` dimensional `string` tensor with shape `[D1...DN]`. + + #### Example: + + >>> input = tf.ragged.constant( + ... [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]) + >>> print(unicode_encode(input, 'UTF-8')) + tf.Tensor([b'G\xc3\xb6\xc3\xb6dnight' b'\xf0\x9f\x98\x8a'], + shape=(2,), dtype=string) + """ + with ops.name_scope(name, "UnicodeEncode", [input]): + input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input) + if input_tensor.shape.ndims is None: + raise ValueError("Rank of input_tensor must be statically known.") + if ragged_tensor.is_ragged(input_tensor): + if input_tensor.flat_values.shape.ndims > 1: + # If the flat_values of our ragged tensor is multi-dimensional, we can + # process it separately and our output will have the same nested splits + # as our input. + return input_tensor.with_flat_values( + unicode_encode(input_tensor.flat_values, output_encoding, errors, + replacement_char)) + elif input_tensor.ragged_rank > 1: + # Recursively process the values of the ragged tensor. 
+ return input_tensor.with_values( + unicode_encode(input_tensor.values, output_encoding, errors, + replacement_char)) + else: + # Our ragged tensor is of the correct shape (rank 1 flat_values tensor + # with ragged_rank of 1) so we can process it as normal. + return gen_string_ops.unicode_encode( + input_values=input_tensor.values, + input_splits=input_tensor.row_splits, + output_encoding=output_encoding, + errors=errors, + replacement_char=replacement_char) + else: + if input_tensor.shape.ndims == 2: + # The input tensor is of the correct 2-D shape, it's just not ragged. + return unicode_encode( + ragged_tensor.RaggedTensor.from_tensor(input_tensor), + output_encoding, errors, replacement_char) + elif input_tensor.shape.ndims > 2: + # We need to initially flatten the input tensor to 2-D, and then can + # reshape the output of our processed flattened tensor. + flat_input_tensor = array_ops.reshape( + input_tensor, + array_ops_stack.stack([-1, array_ops.shape(input_tensor)[-1]])) + flat_output_tensor = unicode_encode(flat_input_tensor, output_encoding, + errors, replacement_char) + return array_ops.reshape(flat_output_tensor, input_tensor.shape[:-1]) + elif input_tensor.shape.ndims == 0: + raise ValueError("input_tensor's rank must be at least 1.") + else: + # Our input tensor is rank 1, so we create a ragged tensor with an added + # dimension to create the correct input shape & type, and then remove + # the additional dimension from the output and return the string scalar. 
+ ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits( + input_tensor, + array_ops_stack.stack( + [0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]), + validate=False) + output_tensor = unicode_encode(ragged_input_tensor, output_encoding, + errors, replacement_char) + return array_ops.reshape(output_tensor, []) + + +# pylint: disable=redefined-builtin +@tf_export("strings.unicode_decode") +@dispatch.add_dispatch_support +def unicode_decode(input, + input_encoding, + errors="replace", + replacement_char=0xFFFD, + replace_control_characters=False, + name=None): + r"""Decodes each string in `input` into a sequence of Unicode code points. + + `result[i1...iN, j]` is the Unicode codepoint for the `j`th character in + `input[i1...iN]`, when decoded using `input_encoding`. + + Args: + input: An `N` dimensional potentially ragged `string` tensor with shape + `[D1...DN]`. `N` must be statically known. + input_encoding: String name for the unicode encoding that should be used to + decode each string. + errors: Specifies the response when an input string can't be converted + using the indicated encoding. One of: + * `'strict'`: Raise an exception for any illegal substrings. + * `'replace'`: Replace illegal substrings with `replacement_char`. + * `'ignore'`: Skip illegal substrings. + replacement_char: The replacement codepoint to be used in place of invalid + substrings in `input` when `errors='replace'`; and in place of C0 control + characters in `input` when `replace_control_characters=True`. + replace_control_characters: Whether to replace the C0 control characters + `(U+0000 - U+001F)` with the `replacement_char`. + name: A name for the operation (optional). + + Returns: + A `N+1` dimensional `int32` tensor with shape `[D1...DN, (num_chars)]`. + The returned tensor is a `tf.Tensor` if `input` is a scalar, or a + `tf.RaggedTensor` otherwise. 
+ + #### Example: + + >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')] + >>> tf.strings.unicode_decode(input, 'UTF-8').to_list() + [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]] + """ + with ops.name_scope(name, "UnicodeDecode", [input]): + return _unicode_decode(input, input_encoding, errors, replacement_char, + replace_control_characters, with_offsets=False) + + +@tf_export("strings.unicode_decode_with_offsets") +@dispatch.add_dispatch_support +def unicode_decode_with_offsets(input, + input_encoding, + errors="replace", + replacement_char=0xFFFD, + replace_control_characters=False, + name=None): + r"""Decodes each string into a sequence of code points with start offsets. + + This op is similar to `tf.strings.decode(...)`, but it also returns the + start offset for each character in its respective string. This information + can be used to align the characters with the original byte sequence. + + Returns a tuple `(codepoints, start_offsets)` where: + + * `codepoints[i1...iN, j]` is the Unicode codepoint for the `j`th character + in `input[i1...iN]`, when decoded using `input_encoding`. + * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th + character in `input[i1...iN]`, when decoded using `input_encoding`. + + Args: + input: An `N` dimensional potentially ragged `string` tensor with shape + `[D1...DN]`. `N` must be statically known. + input_encoding: String name for the unicode encoding that should be used to + decode each string. + errors: Specifies the response when an input string can't be converted + using the indicated encoding. One of: + * `'strict'`: Raise an exception for any illegal substrings. + * `'replace'`: Replace illegal substrings with `replacement_char`. + * `'ignore'`: Skip illegal substrings. 
+ replacement_char: The replacement codepoint to be used in place of invalid + substrings in `input` when `errors='replace'`; and in place of C0 control + characters in `input` when `replace_control_characters=True`. + replace_control_characters: Whether to replace the C0 control characters + `(U+0000 - U+001F)` with the `replacement_char`. + name: A name for the operation (optional). + + Returns: + A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`. + + * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`. + * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`. + + The returned tensors are `tf.Tensor`s if `input` is a scalar, or + `tf.RaggedTensor`s otherwise. + + #### Example: + + >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')] + >>> result = tf.strings.unicode_decode_with_offsets(input, 'UTF-8') + >>> result[0].to_list() # codepoints + [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]] + >>> result[1].to_list() # offsets + [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]] + + """ + with ops.name_scope(name, "UnicodeDecodeWithOffsets", [input]): + return _unicode_decode(input, input_encoding, errors, replacement_char, + replace_control_characters, with_offsets=True) + + +@tf_export("strings.unicode_split") +@dispatch.add_dispatch_support +def unicode_split(input, + input_encoding, + errors="replace", + replacement_char=0xFFFD, + name=None): + r"""Splits each string in `input` into a sequence of Unicode code points. + + `result[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its + `j`th character, when decoded using `input_encoding`. + + Args: + input: An `N` dimensional potentially ragged `string` tensor with shape + `[D1...DN]`. `N` must be statically known. + input_encoding: String name for the unicode encoding that should be used to + decode each string. + errors: Specifies the response when an input string can't be converted + using the indicated encoding. 
One of: + * `'strict'`: Raise an exception for any illegal substrings. + * `'replace'`: Replace illegal substrings with `replacement_char`. + * `'ignore'`: Skip illegal substrings. + replacement_char: The replacement codepoint to be used in place of invalid + substrings in `input` when `errors='replace'`. + name: A name for the operation (optional). + + Returns: + A `N+1` dimensional `int32` tensor with shape `[D1...DN, (num_chars)]`. + The returned tensor is a `tf.Tensor` if `input` is a scalar, or a + `tf.RaggedTensor` otherwise. + + #### Example: + + >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')] + >>> tf.strings.unicode_split(input, 'UTF-8').to_list() + [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'], + [b'\xf0\x9f\x98\x8a']] + """ + with ops.name_scope(name, "UnicodeSplit", [input]): + codepoints = _unicode_decode(input, input_encoding, errors, + replacement_char, False, with_offsets=False) + return unicode_encode( + ragged_array_ops.expand_dims(codepoints, -1), + output_encoding=input_encoding, + errors=errors, + replacement_char=replacement_char) + + +@tf_export("strings.unicode_split_with_offsets") +@dispatch.add_dispatch_support +def unicode_split_with_offsets(input, + input_encoding, + errors="replace", + replacement_char=0xFFFD, + name=None): + r"""Splits each string into a sequence of code points with start offsets. + + This op is similar to `tf.strings.decode(...)`, but it also returns the + start offset for each character in its respective string. This information + can be used to align the characters with the original byte sequence. + + Returns a tuple `(chars, start_offsets)` where: + + * `chars[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its + `j`th character, when decoded using `input_encoding`. + * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th + character in `input[i1...iN]`, when decoded using `input_encoding`. 
+ + Args: + input: An `N` dimensional potentially ragged `string` tensor with shape + `[D1...DN]`. `N` must be statically known. + input_encoding: String name for the unicode encoding that should be used to + decode each string. + errors: Specifies the response when an input string can't be converted + using the indicated encoding. One of: + * `'strict'`: Raise an exception for any illegal substrings. + * `'replace'`: Replace illegal substrings with `replacement_char`. + * `'ignore'`: Skip illegal substrings. + replacement_char: The replacement codepoint to be used in place of invalid + substrings in `input` when `errors='replace'`. + name: A name for the operation (optional). + + Returns: + A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`. + + * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`. + * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`. + + The returned tensors are `tf.Tensor`s if `input` is a scalar, or + `tf.RaggedTensor`s otherwise. 
+ + #### Example: + + >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')] + >>> result = tf.strings.unicode_split_with_offsets(input, 'UTF-8') + >>> result[0].to_list() # character substrings + [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'], + [b'\xf0\x9f\x98\x8a']] + >>> result[1].to_list() # offsets + [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]] + + """ + with ops.name_scope(name, "UnicodeSplitWithOffsets", [input]): + codepoints, offsets = _unicode_decode(input, input_encoding, errors, + replacement_char, False, + with_offsets=True) + chars = unicode_encode( + ragged_array_ops.expand_dims(codepoints, -1), + output_encoding=input_encoding, + errors=errors, + replacement_char=replacement_char) + return chars, offsets + + +def _unicode_decode(input, input_encoding, errors, replacement_char, + replace_control_characters, with_offsets): + """Decodes each string into a sequence of codepoints.""" + input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input") + input_ndims = input.shape.ndims + if input_ndims is None: + raise ValueError("Rank of `input` must be statically known.") + + if input_ndims > 1: + # Convert to a ragged tensor with ragged_rank = input_ndims - 1. + if not ragged_tensor.is_ragged(input): + input = ragged_tensor.RaggedTensor.from_tensor( + input, ragged_rank=input_ndims - 1) + elif input.ragged_rank < input_ndims - 1: + input = input.with_flat_values( + ragged_tensor.RaggedTensor.from_tensor( + input.flat_values, + ragged_rank=input_ndims - input.ragged_rank - 1)) + + # Reshape the input to a flat vector, and apply the gen_string_ops op. 
+ if ragged_tensor.is_ragged(input): + flat_input = array_ops.reshape(input.flat_values, [-1]) + else: + flat_input = array_ops.reshape(input, [-1]) + + if with_offsets: + decode_op = gen_string_ops.unicode_decode_with_offsets + else: + decode_op = gen_string_ops.unicode_decode + flat_result = decode_op( + input=flat_input, + input_encoding=input_encoding, + errors=errors, + replacement_char=replacement_char, + replace_control_characters=replace_control_characters) + + if input_ndims == 0: + codepoints = flat_result.char_values + if with_offsets: + offsets = flat_result.char_to_byte_starts + else: + codepoints = ragged_tensor.RaggedTensor.from_row_splits( + flat_result.char_values, flat_result.row_splits, validate=False) + if input_ndims > 1: + codepoints = input.with_flat_values(codepoints) + if with_offsets: + offsets = ragged_tensor.RaggedTensor.from_row_splits( + flat_result.char_to_byte_starts, flat_result.row_splits, + validate=False) + if input_ndims > 1: + offsets = input.with_flat_values(offsets) + + if with_offsets: + return codepoints, offsets + else: + return codepoints + + +@tf_export("strings.split", v1=[]) +@dispatch.add_dispatch_support +def string_split_v2(input, sep=None, maxsplit=-1, name=None): # pylint: disable=redefined-builtin + """Split elements of `input` based on `sep` into a `RaggedTensor`. + + Let N be the size of `input` (typically N will be the batch size). Split each + element of `input` based on `sep` and return a `RaggedTensor` containing the + split tokens. Empty tokens are ignored. + + Example: + + >>> tf.strings.split('hello world').numpy() + array([b'hello', b'world'], dtype=object) + >>> tf.strings.split(['hello world', 'a b c']) + + + If `sep` is given, consecutive delimiters are not grouped together and are + deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and + `sep` of `"<>"` returns `["1", "2", "", "3"]`. 
If `sep` is None or an empty + string, consecutive whitespace are regarded as a single separator, and the + result will contain no empty strings at the start or end if the string has + leading or trailing whitespace. + + Note that the above mentioned behavior matches python's str.split. + + Args: + input: A string `Tensor` of rank `N`, the strings to split. If + `rank(input)` is not known statically, then it is assumed to be `1`. + sep: `0-D` string `Tensor`, the delimiter string. + maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. + name: A name for the operation (optional). + + Raises: + ValueError: If sep is not a string. + + Returns: + A `RaggedTensor` of rank `N+1`, the strings split according to the + delimiter. + """ + with ops.name_scope(name, "StringSplit", [input]): + input = ragged_tensor.convert_to_tensor_or_ragged_tensor( + input, dtype=dtypes.string, name="input") + if isinstance(input, ragged_tensor.RaggedTensor): + return input.with_flat_values( + string_split_v2(input.flat_values, sep, maxsplit)) + + rank = input.shape.ndims + if rank == 0: + return string_split_v2(array_ops_stack.stack([input]), sep, maxsplit)[0] + elif rank == 1 or rank is None: + sparse_result = string_ops.string_split_v2( + input, sep=sep, maxsplit=maxsplit) + return ragged_tensor.RaggedTensor.from_value_rowids( + values=sparse_result.values, + value_rowids=sparse_result.indices[:, 0], + nrows=sparse_result.dense_shape[0], + validate=False) + else: + return string_split_v2( + ragged_tensor.RaggedTensor.from_tensor(input), sep, maxsplit) + + +@tf_export(v1=["string_split"]) +@dispatch.add_dispatch_support +@deprecation.deprecated_args(None, + "delimiter is deprecated, please use sep instead.", + "delimiter") +def string_split(source, sep=None, skip_empty=True, delimiter=None, + result_type="SparseTensor", name=None): # pylint: disable=invalid-name + """Split elements of `source` based on `delimiter`. 
+ + Let N be the size of `source` (typically N will be the batch size). Split each + element of `source` based on `delimiter` and return a `SparseTensor` + or `RaggedTensor` containing the split tokens. Empty tokens are ignored. + + If `sep` is an empty string, each element of the `source` is split + into individual strings, each containing one byte. (This includes splitting + multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is + treated as a set of delimiters with each considered a potential split point. + + Examples: + + >>> print(tf.compat.v1.string_split(['hello world', 'a b c'])) + SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...), + values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...), + dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64)) + + >>> print(tf.compat.v1.string_split(['hello world', 'a b c'], + ... result_type="RaggedTensor")) + + + Args: + source: `1-D` string `Tensor`, the strings to split. + sep: `0-D` string `Tensor`, the delimiter character, the string should + be length 0 or 1. Default is ' '. + skip_empty: A `bool`. If `True`, skip the empty strings from the result. + delimiter: deprecated alias for `sep`. + result_type: The tensor type for the result: one of `"RaggedTensor"` or + `"SparseTensor"`. + name: A name for the operation (optional). + + Raises: + ValueError: If delimiter is not a string. + + Returns: + A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according + to the delimiter. The first column of the indices corresponds to the row + in `source` and the second column corresponds to the index of the split + component in this row. 
+ """ + with ops.name_scope(name, "StringSplit", [source]): + sparse_result = string_ops.string_split( + source, sep=sep, skip_empty=skip_empty, delimiter=delimiter) + if result_type == "SparseTensor": + return sparse_result + elif result_type == "RaggedTensor": + return ragged_tensor.RaggedTensor.from_value_rowids( + values=sparse_result.values, + value_rowids=sparse_result.indices[:, 0], + nrows=sparse_result.dense_shape[0], + validate=False) + else: + raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.") + + +# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit), +# but we need to add the result_type argument. +@tf_export(v1=["strings.split"]) +@dispatch.add_dispatch_support +def strings_split_v1(input=None, sep=None, maxsplit=-1, # pylint: disable=redefined-builtin + result_type="SparseTensor", source=None, name=None): + """Split elements of `input` based on `sep`. + + Let N be the size of `input` (typically N will be the batch size). Split each + element of `input` based on `sep` and return a `SparseTensor` or + `RaggedTensor` containing the split tokens. Empty tokens are ignored. + + Examples: + + >>> print(tf.compat.v1.strings.split(['hello world', 'a b c'])) + SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...), + values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...), + dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64)) + + >>> print(tf.compat.v1.strings.split(['hello world', 'a b c'], + ... result_type="RaggedTensor")) + + + If `sep` is given, consecutive delimiters are not grouped together and are + deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and + `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty + string, consecutive whitespace are regarded as a single separator, and the + result will contain no empty strings at the start or end if the string has + leading or trailing whitespace. 
+ + Note that the above mentioned behavior matches python's str.split. + + Args: + input: A string `Tensor` of rank `N`, the strings to split. If + `rank(input)` is not known statically, then it is assumed to be `1`. + sep: `0-D` string `Tensor`, the delimiter character. + maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result. + result_type: The tensor type for the result: one of `"RaggedTensor"` or + `"SparseTensor"`. + source: alias for "input" argument. + name: A name for the operation (optional). + + Raises: + ValueError: If sep is not a string. + + Returns: + A `SparseTensor` or `RaggedTensor` of rank `N+1`, the strings split + according to the delimiter. + """ + input = deprecation.deprecated_argument_lookup( + "input", input, "source", source) + with ops.name_scope(name, "StringSplit", [input]): + input = ragged_tensor.convert_to_tensor_or_ragged_tensor( + input, dtype=dtypes.string, name="input") + + if input.shape.rank == 0: + input = array_ops.expand_dims(input, 0) + + if result_type == "SparseTensor": + if input.shape.rank == 1: + return string_ops.string_split_v2(input, sep=sep, maxsplit=maxsplit) + else: + return string_split_v2(input, sep=sep, maxsplit=maxsplit).to_sparse() + elif result_type == "RaggedTensor": + return string_split_v2(input, sep=sep, maxsplit=maxsplit) + else: + raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.") + + +@dispatch.dispatch_for_api(string_ops.reduce_join_v2) +def reduce_join(inputs: ragged_tensor.Ragged, + axis=None, + keepdims=None, + separator="", + name=None): + """For docs, see: _RAGGED_REDUCE_DOCSTRING.""" + return ragged_math_ops.ragged_reduce_aggregate( + string_ops.reduce_join, string_ops.unsorted_segment_join, inputs, axis, + keepdims, separator, name or "RaggedSegmentJoin") + + +@tf_export("strings.ngrams") +@dispatch.add_dispatch_support +def ngrams(data, + ngram_width, + separator=" ", + pad_values=None, + padding_width=None, + preserve_short_sequences=False, + 
name=None): + """Create a tensor of n-grams based on `data`. + + Creates a tensor of n-grams based on `data`. The n-grams are created by + joining windows of `width` adjacent strings from the inner axis of `data` + using `separator`. + + The input data can be padded on both the start and end of the sequence, if + desired, using the `pad_values` argument. If set, `pad_values` should contain + either a tuple of strings or a single string; the 0th element of the tuple + will be used to pad the left side of the sequence and the 1st element of the + tuple will be used to pad the right side of the sequence. The `padding_width` + arg controls how many padding values are added to each side; it defaults to + `ngram_width-1`. + + If this op is configured to not have padding, or if it is configured to add + padding with `padding_width` set to less than ngram_width-1, it is possible + that a sequence, or a sequence plus padding, is smaller than the ngram + width. In that case, no ngrams will be generated for that sequence. This can + be prevented by setting `preserve_short_sequences`, which will cause the op + to always generate at least one ngram per non-empty sequence. + + Examples: + + >>> tf.strings.ngrams(["A", "B", "C", "D"], 2).numpy() + array([b'A B', b'B C', b'C D'], dtype=object) + >>> tf.strings.ngrams(["TF", "and", "keras"], 1).numpy() + array([b'TF', b'and', b'keras'], dtype=object) + + Args: + data: A Tensor or RaggedTensor containing the source data for the ngrams. + ngram_width: The width(s) of the ngrams to create. If this is a list or + tuple, the op will return ngrams of all specified arities in list order. + Values must be non-Tensor integers greater than 0. + separator: The separator string used between ngram elements. Must be a + string constant, not a Tensor. + pad_values: A tuple of (left_pad_value, right_pad_value), a single string, + or None. 
If None, no padding will be added; if a single string, then that + string will be used for both left and right padding. Values must be Python + strings. + padding_width: If set, `padding_width` pad values will be added to both + sides of each sequence. Defaults to `ngram_width`-1. Must be greater than + 0. (Note that 1-grams are never padded, regardless of this value.) + preserve_short_sequences: If true, then ensure that at least one ngram is + generated for each input sequence. In particular, if an input sequence is + shorter than `min(ngram_width) + 2*pad_width`, then generate a single + ngram containing the entire sequence. If false, then no ngrams are + generated for these short input sequences. + name: The op name. + + Returns: + A RaggedTensor of ngrams. If `data.shape=[D1...DN, S]`, then + `output.shape=[D1...DN, NUM_NGRAMS]`, where + `NUM_NGRAMS=S-ngram_width+1+2*padding_width`. + + Raises: + TypeError: if `pad_values` is set to an invalid type. + ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an + invalid value. 
+ """ + + with ops.name_scope(name, "StringNGrams", [data]): + if pad_values is None: + left_pad = "" + right_pad = "" + elif isinstance(pad_values, (list, tuple)): + if (not isinstance(pad_values[0], util_compat.bytes_or_text_types) or + not isinstance(pad_values[1], util_compat.bytes_or_text_types)): + raise TypeError( + "pad_values must be a string, tuple of strings, or None.") + left_pad = pad_values[0] + right_pad = pad_values[1] + else: + if not isinstance(pad_values, util_compat.bytes_or_text_types): + raise TypeError( + "pad_values must be a string, tuple of strings, or None.") + left_pad = pad_values + right_pad = pad_values + + if padding_width is not None and padding_width < 1: + raise ValueError("padding_width must be greater than 0.") + + if padding_width is not None and pad_values is None: + raise ValueError("pad_values must be provided if padding_width is set.") + + data = ragged_tensor.convert_to_tensor_or_ragged_tensor( + data, name="data", dtype=dtypes.string) + + # preserve the shape of the data if it is a tensor + to_tensor = False + if isinstance(data, tensor_lib.Tensor): + dense_shape = array_ops.concat([array_ops.shape(data)[:-1], [-1]], axis=0) + to_tensor = True + + if not isinstance(data, ragged_tensor.RaggedTensor): + if data.shape.ndims is None: + raise ValueError("Rank of data must be known.") + elif data.shape.ndims == 0: + raise ValueError("Data must have rank>0") + elif data.shape.ndims == 1: + rt = ragged_tensor.RaggedTensor.from_row_starts( + data, [0], validate=False) + return ngrams(rt, ngram_width, separator, pad_values, padding_width, + preserve_short_sequences, name)[0] + else: + data = ragged_tensor.RaggedTensor.from_tensor( + data, ragged_rank=data.shape.ndims - 1) + + if data.ragged_rank > 1: + output = data.with_values( + ngrams(data.values, ngram_width, separator, pad_values, padding_width, + preserve_short_sequences, name)) + return array_ops.reshape(output.flat_values, + dense_shape) if to_tensor else output + + if 
pad_values is None: + padding_width = 0 + + if pad_values is not None and padding_width is None: + padding_width = -1 + + if not isinstance(ngram_width, (list, tuple)): + ngram_widths = [ngram_width] + else: + ngram_widths = ngram_width + for width in ngram_widths: + if width < 1: + raise ValueError("All ngram_widths must be greater than 0. Got %s" % + ngram_width) + + output, output_splits = gen_string_ops.string_n_grams( + data=data.flat_values, + data_splits=data.row_splits, + separator=separator, + ngram_widths=ngram_widths, + left_pad=left_pad, + right_pad=right_pad, + pad_width=padding_width, + preserve_short_sequences=preserve_short_sequences) + + # if the input is Dense tensor, the output should also be a dense tensor + output = ragged_tensor.RaggedTensor.from_row_splits( + values=output, row_splits=output_splits, validate=False) + return array_ops.reshape(output.flat_values, + dense_shape) if to_tensor else output + + +@dispatch.dispatch_for_api(string_ops.string_format) +def string_format( + template: str, + inputs: typing.Union[ragged_tensor.Ragged, + typing.List[ragged_tensor.RaggedOrDense]], + placeholder="{}", + summarize=3, + name=None): + """Version of tf.strings.format that handles RaggedTensors.""" + if tensor_util.is_tf_type(inputs) or ragged_tensor.is_ragged(inputs): + inputs = [inputs] + + split_template = template.split(placeholder) + if len(inputs) != len(split_template) - 1: + raise ValueError("num placeholders in template and num inputs must match" + ": {} vs {}".format(len(split_template) - 1, len(inputs))) + + with ops.name_scope(name, "StringFormat", [inputs]): + output_pieces = [constant_op.constant(split_template[0])] + for i, input in enumerate(inputs): + if ragged_tensor.is_ragged(input): + output_pieces.append(ragged_tensor_to_string(input, summarize)) + else: + output_pieces.append(string_ops.string_format( + "{}", [input], summarize=summarize)) + output_pieces.append(constant_op.constant(split_template[i + 1])) + if 
len(output_pieces) == 1: + return output_pieces[0] + else: + return string_ops.reduce_join(output_pieces) + + +def ragged_tensor_to_string(rt, summarize=None): + """Returns a scalar string tensor with the contents of a RaggedTensor. + + Requires that `rt.shape.rank` is not `None`. + + Note: this converts the entire `RaggedTensor` into a single string scalar. + If you want to convert individual elements, use `tf.strings.as_string(rt)`. + + >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]]) + >>> ragged_tensor_to_string(rt1).numpy() + b'[[1, 2, 3], [4, 5]]' + + >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]) + >>> ragged_tensor_to_string(rt2).numpy() + b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]" + + >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]]) + >>> ragged_tensor_to_string(rt3, summarize=2).numpy() + b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]' + + Args: + rt: The RaggedTensor that should be converted to a string. + summarize: If specified, then only the first and last `summarize` elements + within each dimension are included in the string. If `-1` or `None`, then + all elements are included. + """ + if (summarize is not None and summarize != -1 and + not (isinstance(summarize, int) and summarize > 0)): + raise ValueError("Expected summarize to be -1 or a positive int, got %r" % + summarize) + with ops.name_scope(None, "AsString", [rt]): + rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt) + if rt.shape.rank is None: + raise ValueError("RaggedTensor to_string requires that rt.shape.rank " + "is not None.") + # Convert all elements of `rt` to strings. 
+ if rt.dtype == dtypes.string: + escaped = string_ops.regex_replace(rt.flat_values, r"(['\\])", r"\\\1") + str_t = rt.with_flat_values("'" + escaped + "'") + else: + str_t = rt.with_flat_values(string_ops.as_string(rt.flat_values)) + + return _ragged_tensor_to_string(str_t, summarize) + + +def _ragged_tensor_to_string(string_tensor, summarize): + """Returns a scalar string tensor with the contents of `string_tensor`. + + Args: + string_tensor: A potentially ragged tensor with dtype=string. + summarize: Include only the first and last `summarize` elements of each + dimension. If `-1` or `None`, then include all elements. + + Returns: + A scalar string Tensor. + """ + if string_tensor.shape.rank == 1: + pieces = string_tensor + else: + pieces = map_fn_lib.map_fn( + lambda s: _ragged_tensor_to_string(s, summarize), + string_tensor, + fn_output_signature=tensor_lib.TensorSpec(None, dtypes.string)) + if summarize not in (-1, None): + pieces = cond.cond( + _nrows(string_tensor) <= 2 * summarize, + lambda: pieces, + lambda: array_ops.concat( # pylint: disable=g-long-lambda + [pieces[:summarize], ["..."], pieces[-summarize:]], + axis=0)) + return "[" + string_ops.reduce_join(pieces, separator=", ") + "]" + + +def _nrows(tensor, out_type=dtypes.int32): + if isinstance(tensor, ragged_tensor.RaggedTensor): + return tensor.nrows(out_type=out_type) + else: + return array_ops.shape(tensor, out_type=out_type)[0] + + +@dispatch.dispatch_for_api(string_ops.string_join) +def string_join(inputs: typing.List[ragged_tensor.RaggedOrDense], + separator="", + name=None): + """RaggedTensor implementation for tf.strings.join.""" + if len(inputs) < 0: + raise ValueError("tf.strings.join: expected at least one input.") + with ops.name_scope(name, "RaggedStringJoin", inputs): + return ragged_functional_ops.map_flat_values(string_ops.string_join, inputs, + separator) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..a92d425a4c748eab10ff468a432272912a0689b7 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py @@ -0,0 +1,3149 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Classes for storing ragged tensors and their values.""" + +import functools +import operator + +import typing +import numpy as np + +from tensorflow.core.protobuf import struct_pb2 +from tensorflow.python import tf2 +from tensorflow.python.client import session +from tensorflow.python.framework import composite_tensor +from tensorflow.python.framework import composite_tensor_gradient +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor as tensor_lib +from tensorflow.python.framework import tensor_conversion +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import type_spec +from tensorflow.python.framework import type_spec_registry +from 
tensorflow.python.ops import array_ops +from tensorflow.python.ops import array_ops_stack +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import cond +from tensorflow.python.ops import control_flow_assert +from tensorflow.python.ops import gen_ragged_conversion_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import ragged_config +from tensorflow.python.ops.ragged import ragged_tensor_value +from tensorflow.python.ops.ragged import ragged_util +from tensorflow.python.ops.ragged.row_partition import RowPartition +from tensorflow.python.saved_model import nested_structure_coder +from tensorflow.python.types import core as core_types +from tensorflow.python.types import internal as internal_types +from tensorflow.python.util import dispatch +from tensorflow.python.util.tf_export import tf_export +from tensorflow.tools.docs import doc_controls + +# pylint: disable=protected-access +_convert_row_partition = RowPartition._convert_row_partition +# pylint: enable=protected-access + +# =============================================================================== +# RaggedTensor +# =============================================================================== + + +@tf_export("RaggedTensor") +class RaggedTensor( + composite_tensor.CompositeTensor, + internal_types.NativeObject, + internal_types.RaggedTensor, +): + """Represents a ragged tensor. + + A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are + dimensions whose slices may have different lengths. For example, the inner + (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged, + since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths. + Dimensions whose slices all have the same length are called *uniform + dimensions*. The outermost dimension of a `RaggedTensor` is always uniform, + since it consists of a single slice (and so there is no possibility for + differing slice lengths). 
+ + The total number of dimensions in a `RaggedTensor` is called its *rank*, + and the number of ragged dimensions in a `RaggedTensor` is called its + *ragged-rank*. A `RaggedTensor`'s ragged-rank is fixed at graph creation + time: it can't depend on the runtime values of `Tensor`s, and can't vary + dynamically for different session runs. + + Note that the `__init__` constructor is private. Please use one of the + following methods to construct a `RaggedTensor`: + + * `tf.RaggedTensor.from_row_lengths` + * `tf.RaggedTensor.from_value_rowids` + * `tf.RaggedTensor.from_row_splits` + * `tf.RaggedTensor.from_row_starts` + * `tf.RaggedTensor.from_row_limits` + * `tf.RaggedTensor.from_nested_row_splits` + * `tf.RaggedTensor.from_nested_row_lengths` + * `tf.RaggedTensor.from_nested_value_rowids` + + ### Potentially Ragged Tensors + + Many ops support both `Tensor`s and `RaggedTensor`s + (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a + full listing). The term "potentially ragged tensor" may be used to refer to a + tensor that might be either a `Tensor` or a `RaggedTensor`. The ragged-rank + of a `Tensor` is zero. + + ### Documenting RaggedTensor Shapes + + When documenting the shape of a RaggedTensor, ragged dimensions can be + indicated by enclosing them in parentheses. For example, the shape of + a 3-D `RaggedTensor` that stores the fixed-size word embedding for each + word in a sentence, for each sentence in a batch, could be written as + `[num_sentences, (num_words), embedding_size]`. The parentheses around + `(num_words)` indicate that dimension is ragged, and that the length + of each element list in that dimension may vary for each item. + + ### Component Tensors + + Internally, a `RaggedTensor` consists of a concatenated list of values that + are partitioned into variable-length rows. In particular, each `RaggedTensor` + consists of: + + * A `values` tensor, which concatenates the variable-length rows into a + flattened list. 
For example, the `values` tensor for + `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`. + + * A `row_splits` vector, which indicates how those flattened values are + divided into rows. In particular, the values for row `rt[i]` are stored + in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. + + Example: + + >>> print(tf.RaggedTensor.from_row_splits( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_splits=[0, 4, 4, 7, 8, 8])) + + + ### Alternative Row-Partitioning Schemes + + In addition to `row_splits`, ragged tensors provide support for five other + row-partitioning schemes: + + * `row_lengths`: a vector with shape `[nrows]`, which specifies the length + of each row. + + * `value_rowids` and `nrows`: `value_rowids` is a vector with shape + `[nvals]`, corresponding one-to-one with `values`, which specifies + each value's row index. In particular, the row `rt[row]` consists of the + values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an + integer scalar that specifies the number of rows in the + `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.) + + * `row_starts`: a vector with shape `[nrows]`, which specifies the start + offset of each row. Equivalent to `row_splits[:-1]`. + + * `row_limits`: a vector with shape `[nrows]`, which specifies the stop + offset of each row. Equivalent to `row_splits[1:]`. + + * `uniform_row_length`: A scalar tensor, specifying the length of every + row. This row-partitioning scheme may only be used if all rows have + the same length. + + Example: The following ragged tensors are equivalent, and all represent the + nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`. + + >>> values = [3, 1, 4, 1, 5, 9, 2, 6] + >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8]) + + >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0]) + + >>> RaggedTensor.from_value_rowids( + ... 
values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + + >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8]) + + >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8]) + + >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2) + + + ### Multiple Ragged Dimensions + + `RaggedTensor`s with multiple ragged dimensions can be defined by using + a nested `RaggedTensor` for the `values` tensor. Each nested `RaggedTensor` + adds a single ragged dimension. + + >>> inner_rt = RaggedTensor.from_row_splits( # =rt1 from above + ... values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8]) + >>> outer_rt = RaggedTensor.from_row_splits( + ... values=inner_rt, row_splits=[0, 3, 3, 5]) + >>> print(outer_rt.to_list()) + [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] + >>> print(outer_rt.ragged_rank) + 2 + + The factory function `RaggedTensor.from_nested_row_splits` may be used to + construct a `RaggedTensor` with multiple ragged dimensions directly, by + providing a list of `row_splits` tensors: + + >>> RaggedTensor.from_nested_row_splits( + ... flat_values=[3, 1, 4, 1, 5, 9, 2, 6], + ... nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list() + [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]] + + ### Uniform Inner Dimensions + + `RaggedTensor`s with uniform inner dimensions can be defined + by using a multidimensional `Tensor` for `values`. + + >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32), + ... row_splits=[0, 2, 5]) + >>> print(rt.to_list()) + [[[1, 1, 1], [1, 1, 1]], + [[1, 1, 1], [1, 1, 1], [1, 1, 1]]] + >>> print(rt.shape) + (2, None, 3) + + ### Uniform Outer Dimensions + + `RaggedTensor`s with uniform outer dimensions can be defined by using + one or more `RaggedTensor` with a `uniform_row_length` row-partitioning + tensor. 
For example, a `RaggedTensor` with shape `[2, 2, None]` can be + constructed with this method from a `RaggedTensor` values with shape + `[4, None]`: + + >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> print(values.shape) + (4, None) + >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2) + >>> print(rt6) + + >>> print(rt6.shape) + (2, 2, None) + + Note that `rt6` only contains one ragged dimension (the innermost + dimension). In contrast, if `from_row_splits` is used to construct a similar + `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: + + >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) + >>> print(rt7.shape) + (2, None, None) + + Uniform and ragged outer dimensions may be interleaved, meaning that a + tensor with any combination of ragged and uniform dimensions may be created. + For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could + be constructed as follows: + + ```python + t0 = tf.zeros([1000, 2]) # Shape: [1000, 2] + t1 = RaggedTensor.from_row_lengths(t0, [...]) # [160, None, 2] + t2 = RaggedTensor.from_uniform_row_length(t1, 8) # [20, 8, None, 2] + t3 = RaggedTensor.from_uniform_row_length(t2, 4) # [5, 4, 8, None, 2] + t4 = RaggedTensor.from_row_lengths(t3, [...]) # [3, None, 4, 8, None, 2] + ``` + + """ + + #============================================================================= + # Constructor (private) + #============================================================================= + @doc_controls.do_not_generate_docs + def __init__(self, values, row_partition, internal=False): + """Creates a `RaggedTensor` with a specified partitioning for `values`. 
+ + This constructor is private -- please use one of the following ops to + build `RaggedTensor`s: + + * `tf.RaggedTensor.from_row_lengths` + * `tf.RaggedTensor.from_value_rowids` + * `tf.RaggedTensor.from_row_splits` + * `tf.RaggedTensor.from_row_starts` + * `tf.RaggedTensor.from_row_limits` + * `tf.RaggedTensor.from_nested_row_splits` + * `tf.RaggedTensor.from_nested_row_lengths` + * `tf.RaggedTensor.from_nested_value_rowids` + + Args: + values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. + row_partition: A `RowPartition` object, representing the arrangement of + the lists at the top level. + internal: True if the constructor is being called by one of the factory + methods. If false, an exception will be raised. + + Raises: + ValueError: If internal = False. Note that this method is intended only + for internal use. + TypeError: If values is not a `RaggedTensor` or `Tensor`, or + row_partition is not a `RowPartition`. + """ + + if not internal: + raise ValueError("RaggedTensor constructor is private; please use one " + "of the factory methods instead (e.g., " + "RaggedTensor.from_row_lengths())") + _assert_is_supported_ragged_values_type(values) + if not isinstance(row_partition, RowPartition): + raise TypeError(f"Argument `row_partition` must be a RowPartition. " + f"Received {row_partition}.") + + # Validate shapes. + values.shape.with_rank_at_least(1) + if isinstance(values, RaggedTensor): + # pylint: disable=protected-access + assert row_partition.dtype == values._row_partition.dtype + + self._values = values + self._row_partition = row_partition + + #============================================================================= + # Factory Methods + #============================================================================= + + @classmethod + def _from_row_partition(cls, values, row_partition, validate=True): + """Creates a `RaggedTensor` with a row partition. + + This is used as a way for RaggedTensors to share row partitions. 
+ + The outer dimension of values must be equal to `partition.nvals()`. + + Args: + values: A potentially ragged tensor. + row_partition: a `RowPartition`: can be shared between tensors. + validate: If true, then use assertions to check that the arguments form a + valid `RaggedTensor`. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + Raises: + ValueError: If partition.nvals() != _nrows(values) + """ + if not isinstance(row_partition, RowPartition): + raise TypeError(f"Argument `row_partition` must be a RowPartition. " + f"Received {row_partition}.") + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + values, row_partition = cls._convert_values_and_partition( + values, row_partition, "partition") + if row_partition._has_precomputed_value_rowids(): # pylint: disable=protected-access + value_rowids_shape = row_partition.value_rowids().shape + values.shape[:1].assert_is_compatible_with(value_rowids_shape) + if validate: + msg = "Arguments to _from_row_partition do not form a valid RaggedTensor" + nvals = _nrows(values, row_partition.dtype) + checks = [ + check_ops.assert_equal( + math_ops.cast(row_partition.nvals(), row_partition.dtype), + nvals, + message=msg), + ] + if not isinstance(values, RaggedTensor): + checks.append(check_ops.assert_rank_at_least(values, 1)) + row_partition = row_partition._with_dependencies(checks) # pylint: disable=protected-access + return cls(values=values, internal=True, row_partition=row_partition) + + @classmethod + @dispatch.add_dispatch_support + def from_value_rowids(cls, + values, + value_rowids, + nrows=None, + name=None, + validate=True): + """Creates a `RaggedTensor` with rows partitioned by `value_rowids`. 
+ + The returned `RaggedTensor` corresponds with the python list defined by: + + ```python + result = [[values[i] for i in range(len(values)) if value_rowids[i] == row] + for row in range(nrows)] + ``` + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds + one-to-one with `values`, and specifies each value's row index. Must be + nonnegative, and must be sorted in ascending order. + nrows: An integer scalar specifying the number of rows. This should be + specified if the `RaggedTensor` may containing empty training rows. Must + be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). + Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty). + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + Raises: + ValueError: If `nrows` is incompatible with `value_rowids`. + + #### Example: + + >>> print(tf.RaggedTensor.from_value_rowids( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], + ... nrows=5)) + + + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. 
" + f"Received {validate}.") + + with ops.name_scope(name, "RaggedFromValueRowIds", + [values, value_rowids, nrows]): + row_partition = RowPartition.from_value_rowids( + value_rowids=value_rowids, + nrows=nrows, + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_row_splits(cls, values, row_splits, name=None, validate=True): + """Creates a `RaggedTensor` with rows partitioned by `row_splits`. + + The returned `RaggedTensor` corresponds with the python list defined by: + + ```python + result = [values[row_splits[i]:row_splits[i + 1]] + for i in range(len(row_splits) - 1)] + ``` + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be + empty, and must be sorted in ascending order. `row_splits[0]` must be + zero and `row_splits[-1]` must be `nvals`. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + Raises: + ValueError: If `row_splits` is an empty list. + + #### Example: + + >>> print(tf.RaggedTensor.from_row_splits( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_splits=[0, 4, 4, 7, 8, 8])) + + + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. 
" + f"Received {validate}.") + + with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): + row_partition = RowPartition.from_row_splits( + row_splits=row_splits, + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_row_lengths(cls, values, row_lengths, name=None, validate=True): + """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. + + The returned `RaggedTensor` corresponds with the python list defined by: + + ```python + result = [[values.pop(0) for i in range(length)] + for length in row_lengths] + ``` + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be + nonnegative. `sum(row_lengths)` must be `nvals`. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + #### Example: + + >>> print(tf.RaggedTensor.from_row_lengths( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_lengths=[4, 0, 3, 1, 0])) + + + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. 
" + f"Received {validate}.") + + with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): + row_partition = RowPartition.from_row_lengths( + row_lengths=row_lengths, + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_row_starts(cls, values, row_starts, name=None, validate=True): + """Creates a `RaggedTensor` with rows partitioned by `row_starts`. + + Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`. + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + row_starts: A 1-D integer tensor with shape `[nrows]`. Must be + nonnegative and sorted in ascending order. If `nrows>0`, then + `row_starts[0]` must be zero. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + #### Example: + + >>> print(tf.RaggedTensor.from_row_starts( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_starts=[0, 4, 4, 7, 8])) + + + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. 
" + f"Received {validate}.") + with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): + values = _convert_to_ragged_tensor_values(values) + row_partition = RowPartition.from_row_starts( + row_starts=row_starts, + nvals=_nrows(values), + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_row_limits(cls, values, row_limits, name=None, validate=True): + """Creates a `RaggedTensor` with rows partitioned by `row_limits`. + + Equivalent to: `from_row_splits(values, concat([0, row_limits]))`. + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in + ascending order. If `nrows>0`, then `row_limits[-1]` must be `nvals`. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + + #### Example: + + >>> print(tf.RaggedTensor.from_row_limits( + ... values=[3, 1, 4, 1, 5, 9, 2, 6], + ... row_limits=[4, 4, 7, 8, 8])) + + + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. 
" + f"Received {validate}.") + with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): + values = _convert_to_ragged_tensor_values(values) + row_partition = RowPartition.from_row_limits( + row_limits=row_limits, + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_uniform_row_length(cls, + values, + uniform_row_length, + nrows=None, + validate=True, + name=None): + """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`. + + This method can be used to create `RaggedTensor`s with multiple uniform + outer dimensions. For example, a `RaggedTensor` with shape `[2, 2, None]` + can be constructed with this method from a `RaggedTensor` values with shape + `[4, None]`: + + >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> print(values.shape) + (4, None) + >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2) + >>> print(rt1) + + >>> print(rt1.shape) + (2, 2, None) + + Note that `rt1` only contains one ragged dimension (the innermost + dimension). In contrast, if `from_row_splits` is used to construct a similar + `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: + + >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4]) + >>> print(rt2.shape) + (2, None, None) + + Args: + values: A potentially ragged tensor with shape `[nvals, ...]`. + uniform_row_length: A scalar integer tensor. Must be nonnegative. The + size of the outer axis of `values` must be evenly divisible by + `uniform_row_length`. + nrows: The number of rows in the constructed RaggedTensor. If not + specified, then it defaults to `nvals/uniform_row_length` (or `0` if + `uniform_row_length==0`). `nrows` only needs to be specified if + `uniform_row_length` might be zero. `uniform_row_length*nrows` must be + `nvals`. 
+ validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + name: A name prefix for the RaggedTensor (optional). + + Returns: + A `RaggedTensor` that corresponds with the python list defined by: + + ```python + result = [[values.pop(0) for i in range(uniform_row_length)] + for _ in range(nrows)] + ``` + + `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + with ops.name_scope(name, "RaggedFromUniformRowLength", + [values, uniform_row_length, nrows]): + values = _convert_to_ragged_tensor_values(values) + uniform_row_length = _convert_row_partition( + uniform_row_length, "UniformRowLength", + _get_optional_partition_dtype(values)) + nvals = _nvals_uniform_row_length(values, uniform_row_length) + row_partition = RowPartition.from_uniform_row_length( + uniform_row_length=uniform_row_length, + nvals=nvals, + nrows=nrows, + validate=validate, + dtype_hint=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) + + @classmethod + @dispatch.add_dispatch_support + def from_nested_value_rowids(cls, + flat_values, + nested_value_rowids, + nested_nrows=None, + name=None, + validate=True): + """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors. + + Equivalent to: + + ```python + result = flat_values + for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)): + result = from_value_rowids(result, rowids, nrows) + ``` + + Args: + flat_values: A potentially ragged tensor. + nested_value_rowids: A list of 1-D integer tensors. The `i`th tensor is + used as the `value_rowids` for the `i`th ragged dimension. + nested_nrows: A list of integer scalars. 
The `i`th scalar is used as the + `nrows` for the `i`th ragged dimension. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty). + + Raises: + ValueError: If `len(nested_values_rowids) != len(nested_nrows)`. + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + if isinstance(nested_value_rowids, tensor_lib.Tensor): + raise TypeError(f"Argument `nested_value_rowids` must be a list of " + f"Tensors. Received {nested_value_rowids}.") + if nested_nrows is None: + nested_nrows = [None] * len(nested_value_rowids) + else: + if isinstance(nested_nrows, tensor_lib.Tensor): + raise TypeError(f"Argument `nested_nrows` must be a list of " + f"Tensors. Received {nested_nrows}.") + if len(nested_nrows) != len(nested_value_rowids): + raise ValueError( + f"Argument `nested_nrows` must have the same length as " + f"argument `nested_value_rowids`. len(nested_nrows) = " + f"{len(nested_nrows)} vs. len(nested_values_rowids) = " + f"{len(nested_value_rowids)}.") + + with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] + + list(nested_value_rowids) + list(nested_nrows)): + result = flat_values + for value_rowids, nrows in reversed( + list(zip(nested_value_rowids, nested_nrows))): + result = cls.from_value_rowids( + result, value_rowids, nrows, validate=validate) + return result + + @classmethod + @dispatch.add_dispatch_support + def from_nested_row_splits(cls, + flat_values, + nested_row_splits, + name=None, + validate=True): + """Creates a `RaggedTensor` from a nested list of `row_splits` tensors. 
+ + Equivalent to: + + ```python + result = flat_values + for row_splits in reversed(nested_row_splits): + result = from_row_splits(result, row_splits) + ``` + + Args: + flat_values: A potentially ragged tensor. + nested_row_splits: A list of 1-D integer tensors. The `i`th tensor is + used as the `row_splits` for the `i`th ragged dimension. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty). + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + if isinstance(nested_row_splits, tensor_lib.Tensor): + raise TypeError(f"Argument `nested_row_splits` must be a list of " + f"Tensors. Received {nested_row_splits}.") + with ops.name_scope(name, "RaggedFromNestedRowSplits", + [flat_values] + list(nested_row_splits)): + result = flat_values + for splits in reversed(nested_row_splits): + result = cls.from_row_splits(result, splits, validate=validate) + return result + + @classmethod + @dispatch.add_dispatch_support + def from_nested_row_lengths(cls, + flat_values, + nested_row_lengths, + name=None, + validate=True): + """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors. + + Equivalent to: + + ```python + result = flat_values + for row_lengths in reversed(nested_row_lengths): + result = from_row_lengths(result, row_lengths) + ``` + + Args: + flat_values: A potentially ragged tensor. + nested_row_lengths: A list of 1-D integer tensors. The `i`th tensor is + used as the `row_lengths` for the `i`th ragged dimension. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. 
Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + if isinstance(nested_row_lengths, tensor_lib.Tensor): + raise TypeError(f"Argument `nested_row_lengths` must be a list of " + f"Tensors. Received {nested_row_lengths}.") + with ops.name_scope(name, "RaggedFromNestedRowlengths", + [flat_values] + list(nested_row_lengths)): + result = flat_values + for lengths in reversed(nested_row_lengths): + result = cls.from_row_lengths(result, lengths, validate=validate) + return result + + @classmethod + def _from_nested_row_partitions(cls, + flat_values, + nested_row_partitions, + name=None, + validate=True): + """Creates a `RaggedTensor` from a nested list of row partitions. + + Equivalent to: + + ```python + result = flat_values + for row_partition in reversed(nested_row_partitions): + result = _from_row_partition(result, row_partition) + ``` + + Args: + flat_values: A potentially ragged tensor. + nested_row_partitions: A list of row partitions. The `i`th element is + used as the row partition for the `i`th ragged dimension. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form + a valid `RaggedTensor`. Note: these assertions incur a runtime cost, + since they must be checked for each tensor value. + + Returns: + A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty). + """ + if not isinstance(validate, bool): + raise TypeError(f"Argument `validate` must have type bool. " + f"Received {validate}.") + if isinstance(nested_row_partitions, RowPartition): + raise TypeError(f"Argument `nested_row_partitions` must be a list of " + f"RowPartitions. 
Received {nested_row_partitions}.") + if isinstance(nested_row_partitions, tensor_lib.Tensor): + raise TypeError(f"Argument `nested_row_partitions` must be a list of " + f"RowPartitions. Received {nested_row_partitions}.") + with ops.name_scope(name, "RaggedFromNestedRowPartitions", + [flat_values] + list(nested_row_partitions)): + result = flat_values + for partition in reversed(nested_row_partitions): + result = cls._from_row_partition(result, partition, validate=validate) + return result + + @classmethod + def _convert_values_and_partition(cls, values, row_partition, name): + """Converts `values` and `partition` to Tensors. + + If `values` is a `RaggedTensor`, then converts `values` and `partition` + to have compatible row-partitioning dtypes. In particular, if any of the + row partitioning tensors are `int64`, then all of the other row + partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype() + is true) or an error will be raised (if auto_cast_partition_dtype() is + false). + + Args: + values: The `values` for the `RaggedTensor` being constructed. + row_partition: A RowPartition object for the `RaggedTensor` being + constructed. + name: The name of the RowPartition object. + + Returns: + A tuple (values, partition). + """ + if not isinstance(row_partition, RowPartition): + raise TypeError(f"Argument `row_partition` must be a RowPartition. " + f"Received {row_partition}.") + if isinstance(values, RaggedTensor): + # pylint: disable=protected-access + if values._row_partition.dtype != row_partition.dtype: + if not ragged_config.auto_cast_partition_dtype(): + # pylint: disable=protected-access + # TODO(edloper): get rid of the `name` parameter. + raise ValueError( + f"Argument `row_partition` of RaggedTensor with name: {name} " + f"must have same dtype as Argument `values`. " + f"({row_partition.dtype} vs. 
{values._row_partition.dtype}).") + values = values.with_row_splits_dtype(row_partition.dtype) + else: + values = _convert_to_ragged_tensor_values(values) + + return (values, row_partition) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def dtype(self): + """The `DType` of values in this tensor.""" + return self._values.dtype + + @property + def shape(self): + """The statically known shape of this ragged tensor. + + Returns: + A `TensorShape` containing the statically known shape of this ragged + tensor. Ragged dimensions have a size of `None`. + + Examples: + + >>> tf.ragged.constant([[0], [1, 2]]).shape + TensorShape([2, None]) + + >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape + TensorShape([2, None, 2]) + + """ + nrows = self._row_partition.static_nrows + ncols = self._row_partition.static_uniform_row_length + value_shape = self._values.shape[1:] + return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape) + + def get_shape(self) -> tensor_shape.TensorShape: + """The statically known shape of this ragged tensor. + + Returns: + A `TensorShape` containing the statically known shape of this ragged + tensor. Ragged dimensions have a size of `None`. + + Alias for `shape` property. + + Examples: + + >>> tf.ragged.constant([[0], [1, 2]]).get_shape() + TensorShape([2, None]) + + >>> tf.ragged.constant( + ... [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape() + TensorShape([2, None, 2]) + + """ + return self.shape + + @property + def ragged_rank(self): + """The number of times the RaggedTensor's flat_values is partitioned. 
+ + Examples: + + >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> values.ragged_rank + 1 + + >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2) + >>> rt.ragged_rank + 2 + + Returns: + A Python `int` indicating the number of times the underlying `flat_values` + Tensor has been partitioned to add a new dimension. + I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`. + """ + values_is_ragged = isinstance(self._values, RaggedTensor) + return self._values.ragged_rank + 1 if values_is_ragged else 1 + + @property + def values(self): + """The concatenated rows for this ragged tensor. + + `rt.values` is a potentially ragged tensor formed by flattening the two + outermost dimensions of `rt` into a single dimension. + + `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the + number of items in the outer two dimensions of `rt`). + + `rt.ragged_rank = self.ragged_rank - 1` + + Returns: + A potentially ragged tensor. + + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.values) + tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) + + """ + return self._values + + @property + def _nested_row_partitions(self): + """Returns the row partitions for this `RaggedTensor`.""" + partitions = [self._row_partition] + rt_values = self.values + while isinstance(rt_values, RaggedTensor): + # pylint: disable=protected-access + partitions.append(rt_values._row_partition) + rt_values = rt_values.values + return tuple(partitions) + + @property + def row_splits(self): + """The row-split indices for this ragged tensor's `values`. + + `rt.row_splits` specifies where the values for each row begin and end in + `rt.values`. In particular, the values for row `rt[i]` are stored in + the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. + + Returns: + A 1-D integer `Tensor` with shape `[self.nrows+1]`. + The returned tensor is non-empty, and is sorted in ascending order. 
+ `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to + `self.values.shape[0]`. + + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.row_splits) # indices of row splits in rt.values + tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64) + + """ + return self._row_partition.row_splits() + + @property + def uniform_row_length(self): + """The length of each row in this ragged tensor, or None if rows are ragged. + + >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> print(rt1.uniform_row_length) # rows are ragged. + None + + >>> rt2 = tf.RaggedTensor.from_uniform_row_length( + ... values=rt1, uniform_row_length=2) + >>> print(rt2) + + >>> print(rt2.uniform_row_length) # rows are not ragged (all have size 2). + tf.Tensor(2, shape=(), dtype=int64) + + A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged) + if it can be determined statically (at graph construction time) that the + rows all have the same length. + + Returns: + A scalar integer `Tensor`, specifying the length of every row in this + ragged tensor (for ragged tensors whose rows are uniform); or `None` + (for ragged tensors whose rows are ragged). + """ + return self._row_partition.uniform_row_length() + + @property + def flat_values(self): + """The innermost `values` tensor for this ragged tensor. + + Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is + `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`. + + Conceptually, `flat_values` is the tensor formed by flattening the + outermost dimension and all of the ragged dimensions into a single + dimension. + + `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]` + (where `nvals` is the number of items in the flattened dimensions). + + Returns: + A `Tensor`. 
+ + #### Example: + + >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) + >>> print(rt.flat_values) + tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) + + """ + rt_values = self.values + while isinstance(rt_values, RaggedTensor): + rt_values = rt_values.values + return rt_values + + @property + def nested_row_splits(self): + """A tuple containing the row_splits for all ragged dimensions. + + `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for + all ragged dimensions in `rt`, ordered from outermost to innermost. In + particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where: + + * `value_splits = ()` if `rt.values` is a `Tensor`. + * `value_splits = rt.values.nested_row_splits` otherwise. + + Returns: + A `tuple` of 1-D integer `Tensor`s. + + #### Example: + + >>> rt = tf.ragged.constant( + ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) + >>> for i, splits in enumerate(rt.nested_row_splits): + ... print('Splits for dimension %d: %s' % (i+1, splits.numpy())) + Splits for dimension 1: [0 3] + Splits for dimension 2: [0 3 3 5] + Splits for dimension 3: [0 4 4 7 8 8] + + """ + rt_nested_splits = [self.row_splits] + rt_values = self.values + while isinstance(rt_values, RaggedTensor): + rt_nested_splits.append(rt_values.row_splits) + rt_values = rt_values.values + return tuple(rt_nested_splits) + + def value_rowids(self, name=None): + """Returns the row indices for the `values` in this ragged tensor. + + `rt.value_rowids()` corresponds one-to-one with the outermost dimension of + `rt.values`, and specifies the row containing each value. In particular, + the row `rt[row]` consists of the values `rt.values[j]` where + `rt.value_rowids()[j] == row`. + + Args: + name: A name prefix for the returned tensor (optional). + + Returns: + A 1-D integer `Tensor` with shape `self.values.shape[:1]`. + The returned tensor is nonnegative, and is sorted in ascending order. 
+ + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.values) + tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) + >>> print(rt.value_rowids()) # corresponds 1:1 with rt.values + tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64) + + """ + with ops.name_scope(name, "RaggedValueRowIds", [self]): + return self._row_partition.value_rowids() + + def nested_value_rowids(self, name=None): + """Returns a tuple containing the value_rowids for all ragged dimensions. + + `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors + for + all ragged dimensions in `rt`, ordered from outermost to innermost. In + particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids` + where: + + * `value_ids = ()` if `rt.values` is a `Tensor`. + * `value_ids = rt.values.nested_value_rowids` otherwise. + + Args: + name: A name prefix for the returned tensors (optional). + + Returns: + A `tuple` of 1-D integer `Tensor`s. + + #### Example: + + >>> rt = tf.ragged.constant( + ... [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]]) + >>> for i, ids in enumerate(rt.nested_value_rowids()): + ... print('row ids for dimension %d: %s' % (i+1, ids.numpy())) + row ids for dimension 1: [0 0 0] + row ids for dimension 2: [0 0 0 2 2] + row ids for dimension 3: [0 0 0 0 2 2 2 3] + + """ + with ops.name_scope(name, "RaggedNestedValueRowIds", [self]): + rt_nested_ids = [self.value_rowids()] + rt_values = self.values + while isinstance(rt_values, RaggedTensor): + rt_nested_ids.append(rt_values.value_rowids()) + rt_values = rt_values.values + return tuple(rt_nested_ids) + + def nrows(self, out_type=None, name=None): + """Returns the number of rows in this ragged tensor. + + I.e., the size of the outermost dimension of the tensor. + + Args: + out_type: `dtype` for the returned tensor. Defaults to + `self.row_splits.dtype`. + name: A name prefix for the returned tensor (optional). 
+ + Returns: + A scalar `Tensor` with dtype `out_type`. + + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.nrows()) # rt has 5 rows. + tf.Tensor(5, shape=(), dtype=int64) + + """ + with ops.name_scope(name, "RaggedNRows", [self]): + if out_type is None: + return self._row_partition.nrows() + else: + return math_ops.cast(self._row_partition.nrows(), dtype=out_type) + + def row_starts(self, name=None): + """Returns the start indices for rows in this ragged tensor. + + These indices specify where the values for each row begin in + `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`. + + Args: + name: A name prefix for the returned tensor (optional). + + Returns: + A 1-D integer Tensor with shape `[nrows]`. + The returned tensor is nonnegative, and is sorted in ascending order. + + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.values) + tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) + >>> print(rt.row_starts()) # indices of row starts in rt.values + tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64) + + """ + with ops.name_scope(name, "RaggedRowStarts", [self]): + return self._row_partition.row_starts() + + def row_limits(self, name=None): + """Returns the limit indices for rows in this ragged tensor. + + These indices specify where the values for each row end in + `self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`. + + Args: + name: A name prefix for the returned tensor (optional). + + Returns: + A 1-D integer Tensor with shape `[nrows]`. + The returned tensor is nonnegative, and is sorted in ascending order. 
+ + #### Example: + + >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []]) + >>> print(rt.values) + tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32) + >>> print(rt.row_limits()) # indices of row limits in rt.values + tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64) + + """ + with ops.name_scope(name, "RaggedRowLimits", [self]): + return self._row_partition.row_limits() + + def row_lengths(self, axis=1, name=None): + """Returns the lengths of the rows in this ragged tensor. + + `rt.row_lengths()[i]` indicates the number of values in the + `i`th row of `rt`. + + Args: + axis: An integer constant indicating the axis whose row lengths should be + returned. + name: A name prefix for the returned tensor (optional). + + Returns: + A potentially ragged integer Tensor with shape `self.shape[:axis]`. + + Raises: + ValueError: If `axis` is out of bounds. + + #### Example: + + >>> rt = tf.ragged.constant( + ... [[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []]) + >>> print(rt.row_lengths()) # lengths of rows in rt + tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64) + >>> print(rt.row_lengths(axis=2)) # lengths of axis=2 rows. + + + """ + if axis == 0: + return self._row_partition.nrows() + + if axis == 1: + return self._row_partition.row_lengths() + + with ops.name_scope(name, "RaggedRowLengths", [self]): + axis = array_ops.get_positive_axis( + axis, self.shape.rank, ndims_name="rank(self)") + if axis == 0: + return self.nrows() + elif axis == 1: + splits = self.row_splits + return splits[1:] - splits[:-1] + elif isinstance(self.values, RaggedTensor): + return self.with_values(self.values.row_lengths(axis - 1)) + else: + shape = array_ops.shape(self.values, out_type=self._row_partition.dtype) + return self.with_values( + array_ops.ones(shape[:axis - 1], self._row_partition.dtype) * + shape[axis - 1]) + + def nested_row_lengths(self, name=None): + """Returns a tuple containing the row_lengths for all ragged dimensions. 
+ + `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors + for all ragged dimensions in `rt`, ordered from outermost to innermost. + + Args: + name: A name prefix for the returned tensors (optional). + + Returns: + A `tuple` of 1-D integer `Tensors`. The length of the tuple is equal to + `self.ragged_rank`. + """ + with ops.name_scope(name, "RaggedNestedRowLengths", [self]): + rt_nested_row_lengths = [] + rt = self + while isinstance(rt, RaggedTensor): + rt_nested_row_lengths.append(rt.row_lengths()) + rt = rt.values + return tuple(rt_nested_row_lengths) + + def bounding_shape(self, axis=None, name=None, out_type=None): + """Returns the tight bounding box shape for this `RaggedTensor`. + + Args: + axis: An integer scalar or vector indicating which axes to return the + bounding box for. If not specified, then the full bounding box is + returned. + name: A name prefix for the returned tensor (optional). + out_type: `dtype` for the returned tensor. Defaults to + `self.row_splits.dtype`. + + Returns: + An integer `Tensor` (`dtype=self.row_splits.dtype`). If `axis` is not + specified, then `output` is a vector with + `output.shape=[self.shape.ndims]`. If `axis` is a scalar, then the + `output` is a scalar. If `axis` is a vector, then `output` is a vector, + where `output[i]` is the bounding size for dimension `axis[i]`. 
+ + #### Example: + + >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]]) + >>> rt.bounding_shape().numpy() + array([5, 4]) + + """ + if out_type is None: + out_type = self._row_partition.dtype + else: + out_type = dtypes.as_dtype(out_type) + with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): + nested_splits = self.nested_row_splits + rt_flat_values = self.flat_values + + # Optimized special cases for when axis=0 or axis=1: + if isinstance(axis, int): + if axis == 0: + return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1 + elif axis == 1: + result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0) + if out_type != self._row_partition.dtype: + result = math_ops.cast(result, out_type) + return result + + splits_shape = array_ops.shape(self.row_splits, out_type=out_type) + flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type) + + ragged_dimensions = [splits_shape[0] - 1] + [ + math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0) + for splits in nested_splits + ] + inner_dimensions = flat_values_shape[1:] + + if out_type != self._row_partition.dtype: + ragged_dimensions = [ + math_ops.cast(d, out_type) for d in ragged_dimensions + ] + bbox = array_ops.concat( + [array_ops_stack.stack(ragged_dimensions), inner_dimensions], axis=0) + return bbox if axis is None else array_ops.gather(bbox, axis) + + #============================================================================= + # Transformation + #============================================================================= + + def with_values(self, new_values): + """Returns a copy of `self` with `values` replaced by `new_value`. + + Preserves cached row-partitioning tensors such as `self.cached_nrows` and + `self.cached_value_rowids` if they have values. + + Args: + new_values: Potentially ragged tensor to use as the `values` for the + returned `RaggedTensor`. 
Must have `rank > 0`, and must have the same + number of rows as `self.values`. + + Returns: + A `RaggedTensor`. `result.rank = 1 + new_values.rank`. + `result.ragged_rank = 1 + new_values.ragged_rank` + """ + new_values = _convert_to_ragged_tensor_values(new_values) + new_values.shape.with_rank_at_least(1) + self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) + if (isinstance(new_values, RaggedTensor) and + self._row_partition.dtype != new_values.row_splits.dtype): + if not ragged_config.auto_cast_partition_dtype(): + raise ValueError("self and new_values have mismatched row_splits " + "dtypes; use RaggedTensor.with_row_splits_dtype() to " + "convert them to compatible dtypes.") + new_values = new_values.with_row_splits_dtype(dtypes.int64) + return self.with_row_splits_dtype(dtypes.int64).with_values(new_values) + return RaggedTensor( + values=new_values, row_partition=self._row_partition, internal=True) + + def with_flat_values(self, new_values): + """Returns a copy of `self` with `flat_values` replaced by `new_value`. + + Preserves cached row-partitioning tensors such as `self.cached_nrows` and + `self.cached_value_rowids` if they have values. + + Args: + new_values: Potentially ragged tensor that should replace + `self.flat_values`. Must have `rank > 0`, and must have the same number + of rows as `self.flat_values`. + + Returns: + A `RaggedTensor`. + `result.rank = self.ragged_rank + new_values.rank`. + `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`. + """ + if isinstance(self._values, RaggedTensor): + return self.with_values(self.values.with_flat_values(new_values)) + else: + new_values = _convert_to_ragged_tensor_values(new_values) + return self.with_values(new_values) + + def with_row_splits_dtype(self, dtype): + """Returns a copy of this RaggedTensor with the given `row_splits` dtype. 
+ + For RaggedTensors with multiple ragged dimensions, the `row_splits` for all + nested `RaggedTensor` objects are cast to the given dtype. + + Args: + dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`. + + Returns: + A copy of this RaggedTensor, with the `row_splits` cast to the given + type. + """ + dtype = dtypes.as_dtype(dtype) + if dtype not in (dtypes.int32, dtypes.int64): + raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. " + f"Received {dtype}.") + if self._row_partition.dtype == dtype: + return self + current_values = self._values + if isinstance(current_values, RaggedTensor): + return RaggedTensor( + values=current_values.with_row_splits_dtype(dtype), + row_partition=self._row_partition.with_dtype(dtype), + internal=True) + else: + return RaggedTensor( + values=current_values, + row_partition=self._row_partition.with_dtype(dtype), + internal=True) + + def merge_dims(self, outer_axis, inner_axis): + """Merges outer_axis...inner_axis into a single dimension. + + Returns a copy of this RaggedTensor with the specified range of dimensions + flattened into a single dimension, with elements in row-major order. + + #### Examples: + + >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) + >>> print(rt.merge_dims(0, 1)) + + >>> print(rt.merge_dims(1, 2)) + + >>> print(rt.merge_dims(0, 2)) + tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32) + + To mimic the behavior of `np.flatten` (which flattens all dimensions), use + `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which + flattens all dimensions except the outermost batch dimension), use + `rt.merge_dims(1, -1)`. + + Args: + outer_axis: `int`: The first dimension in the range of dimensions to + merge. May be negative if `self.shape.rank` is statically known. + inner_axis: `int`: The last dimension in the range of dimensions to merge. + May be negative if `self.shape.rank` is statically known. 
+ + Returns: + A copy of this tensor, with the specified dimensions merged into a + single dimension. The shape of the returned tensor will be + `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` + is the total number of slices in the merged dimensions. + """ + outer_axis = array_ops.get_positive_axis( + outer_axis, + self.shape.rank, + axis_name="outer_axis", + ndims_name="rank(self)") + inner_axis = array_ops.get_positive_axis( + inner_axis, + self.shape.rank, + axis_name="inner_axis", + ndims_name="rank(self)") + if not outer_axis <= inner_axis: + raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or " + f"equal to inner_axis ({inner_axis}).") + return merge_dims(self, outer_axis, inner_axis) + + def _set_shape(self, shape): + """Updates the static shape of `self` to be `shape`. + + * If a dimension of `shape` has known rank, and is encoded via + partitioning, then this will update the corresponding partition to + define `_uniform_row_length` and `nrows`. + * If a dimension of `shape` has a known rank, and is encoded as one + of the `flat_values` dimensions, then `flat_values.set_shape()` will + be used to update its shape. + + Warning: Using this method to assert an incorrect shape for a RaggedTensor + (i.e., one that's not consistent with its actual shape) can cause + segmentation faults and very difficult-to-diagnose behavior. Only use this + method if you are certain that the shape is correct. + + Args: + shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`. + """ + # TODO(edloper): Refactor this to not directly access private members + # of RowPartition. + # pylint: disable=protected-access + + shape = tensor_shape.as_shape(shape) + if shape.rank is None: + return # Nothing to do. 
+ + shape = shape.as_list() + + # Outermost dimension + if shape[0] is not None: + self._row_partition._row_splits.set_shape(shape[0] + 1) + + # Partitioned dimensions + dtype = self._row_partition.dtype + for i, partition in enumerate(self._nested_row_partitions): + size = shape[i + 1] + if size is not None: + if partition._uniform_row_length is not None: + old_row_length = tensor_util.constant_value( + partition._uniform_row_length) + if old_row_length is not None: + if size == old_row_length: + continue # already have shape info for this axis. + else: + raise ValueError(f"Inconsistent size for axis {i + 1}: " + f"{old_row_length} vs. {size}.") + partition._uniform_row_length = ops.convert_to_tensor(size, dtype) + if partition._nrows is None: + partition._nrows = array_ops.size( + partition._row_splits, out_type=dtype) - 1 + + # self.flat_values could be a CompositeTensor and doesn't have set_shape. + if hasattr(self.flat_values, "set_shape"): + # Inner dimensions + flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:]) + self.flat_values.set_shape(flat_shape) + + #============================================================================= + # Tensor Type Conversions + #============================================================================= + + @classmethod + @dispatch.add_dispatch_support + def from_tensor(cls, + tensor, + lengths=None, + padding=None, + ragged_rank=1, + name=None, + row_splits_dtype=dtypes.int64): + """Converts a `tf.Tensor` into a `RaggedTensor`. + + The set of absent/default values may be specified using a vector of lengths + or a padding value (but not both). If `lengths` is specified, then the + output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If + 'lengths' is a list of lists or tuple of lists, those lists will be used + as nested row lengths. If `padding` is specified, then any row *suffix* + consisting entirely of `padding` will be excluded from the returned + `RaggedTensor`. 
If neither `lengths` nor `padding` is specified, then the + returned `RaggedTensor` will have no absent/default values. + + Examples: + + >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]]) + >>> tf.RaggedTensor.from_tensor(dt) + + >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]) + + + >>> tf.RaggedTensor.from_tensor(dt, padding=0) + + + >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]], + ... [[0, 0], [3, 0], [0, 0]], + ... [[6, 0], [0, 0], [0, 0]]]) + >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1])) + + + Args: + tensor: The `Tensor` to convert. Must have rank `ragged_rank + 1` or + higher. + lengths: An optional set of row lengths, specified using a 1-D integer + `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows + in `tensor`). If specified, then `output[row]` will contain + `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You + may optionally pass a list or tuple of lengths to this argument, which + will be used as nested row lengths to construct a ragged tensor with + multiple ragged dimensions. + padding: An optional padding value. If specified, then any row suffix + consisting entirely of `padding` will be excluded from the returned + RaggedTensor. `padding` is a `Tensor` with the same dtype as `tensor` + and with `shape=tensor.shape[ragged_rank + 1:]`. + ragged_rank: Integer specifying the ragged rank for the returned + `RaggedTensor`. Must be greater than zero. + name: A name prefix for the returned tensors (optional). + row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits` + tensor. One of `tf.int32` or `tf.int64`. + + Returns: + A `RaggedTensor` with the specified `ragged_rank`. The shape of the + returned ragged tensor is compatible with the shape of `tensor`. + + Raises: + ValueError: If both `lengths` and `padding` are specified. + ValueError: If the rank of `tensor` is 0 or 1. 
+ """ + row_splits_dtype = dtypes.as_dtype(row_splits_dtype) + if lengths is not None and padding is not None: + raise ValueError("Specify argument `lengths` or `padding`, but not both.") + if not isinstance(ragged_rank, int): + raise TypeError(f"Argument `ragged_rank` must be an int. " + f"Received {ragged_rank}.") + if ragged_rank <= 0: + raise ValueError(f"Argument `ragged_rank` must be greater than 0. " + f"Received {ragged_rank}.") + + with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): + tensor = ops.convert_to_tensor(tensor, name="tensor") + if tensor.shape.rank is not None and tensor.shape.rank < 2: + raise ValueError(f"The rank of a RaggedTensor must be greater than 1, " + f"i.e., a list of scalars won't have ragged " + f"dimensions. Received argument `tensor` with rank " + f"{tensor.shape.rank}.") + tensor.shape.with_rank_at_least(ragged_rank + 1) + input_shape = array_ops.shape(tensor, out_type=row_splits_dtype) + ncols = input_shape[1] + + # Handle nested row lengths. + if (lengths is not None and isinstance(lengths, (list, tuple)) and + len(lengths) and not isinstance(lengths[0], (int, float))): + if ragged_rank not in (1, len(lengths)): + # Note: we accept `ragged_rank=1` here because it's the default value; + # i.e., if the user passes in a tuple of lengths, but doesn't specify + # ragged_rank, then we should use that tuple to determine ragged_rank. + # We only want to complain if they pass in an explicit ragged_rank + # that doesn't match len(lengths). + raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, " + f"argument `ragged_rank` must be " + f"len(lengths): {len(lengths)}. Received " + f"ragged_rank: {ragged_rank}.") + # Rather than reconstructing the tensor mask directly, we can + # recreate it as a boolean RaggedTensor, then densify that and use + # that as the mask to clear out the unused data in the passed tensor. 
+ tensor.shape.with_rank_at_least(len(lengths) + 1) + num_tokens = math_ops.reduce_sum(lengths[-1]) + ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool) + ragged_mask = cls.from_nested_row_lengths( + ones_mask, lengths, validate=False) + dense_ragged_mask = ragged_mask.to_tensor(default_value=False) + masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask) + return cls.from_nested_row_lengths(masked_data, lengths, validate=False) + + # Handle ragged_rank>1 via recursion: + # If the output should have multiple ragged dimensions, then first + # flatten the tensor to eliminate all but the last ragged dimension, + # and recursively convert that flattened tensor. Then add on the splits + # for the dimensions that we flattened out. + if ragged_rank > 1: + if tensor.shape.is_fully_defined(): + input_shape = tensor.shape.as_list() + # The total number of elements in each dimension. E.g., if + # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total. + dim_size = np.cumprod(input_shape) + new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:] + else: + dim_size = math_ops.cumprod(input_shape) + new_shape = array_ops.concat( + [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0) + flattened = array_ops.reshape(tensor, new_shape) + result = cls.from_tensor( + flattened, lengths, padding, row_splits_dtype=row_splits_dtype) + + for axis in range(ragged_rank - 1, 0, -1): + dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value + if dim_len is None: + dim_len = input_shape[axis] + else: + dim_len = constant_op.constant(dim_len, row_splits_dtype) + result = RaggedTensor.from_uniform_row_length( + values=result, + uniform_row_length=dim_len, + nrows=dim_size[axis - 1], + validate=False) + return result + + # If padding was specified, then use it to find row lengths. 
+ if padding is not None: + padding = ops.convert_to_tensor( + padding, name="padding", dtype=tensor.dtype) + padding.shape.assert_is_compatible_with(tensor.shape[2:]) + + # Find places where the padding is equal to the tensor. (This will + # broadcast `padding` across the outermost 2 dimensions of `tensor`, + # so `has_default_value.shape = tensor.shape`.) + has_default_value = math_ops.equal(padding, tensor) + + # If the padding isn't a scalar, then require that all values in the + # padding match each item in the tensor. After this block of code, + # `has_default.shape = tensor.shape[:2]`. (Unfortunately, we can't just + # use reduce_all for both cases, becaue when you pass an empty `axis` + # list to reduce_all, it reduces all axes; but we want it to reduce no + # axes -- i.e., to be a no-op.) + tensor_rank = array_ops.rank(tensor) + reduce_axis = math_ops.range(2, tensor_rank) + has_default = cond.cond( + tensor_rank > 2, + lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis), + lambda: has_default_value) + has_default.set_shape(tensor_shape.TensorShape([None, None])) + has_default.set_shape(tensor.shape[:2]) + + # Use has_default to find the length of each row: for each + # non-default item in a row, calculate the length that the row needs to + # have to include that item; and then take the max of those values + # (across each row). + has_nondefault = math_ops.logical_not(has_default) + has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype) + length_for_nondefault_value = ( + has_nondefault * + array_ops.expand_dims(math_ops.range(1, ncols + 1), 0)) + lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1) + + if lengths is not None: + # If we have lengths (either directly supplied, or computed from + # paddings), then use those to construct splits; and then use masking + # to get the corresponding values. 
+ lengths = ragged_util.convert_to_int_tensor(lengths, "lengths", + row_splits_dtype) + lengths.shape.assert_has_rank(1) + lengths = math_ops.minimum(lengths, ncols) + lengths = math_ops.maximum(lengths, 0) + limits = math_ops.cumsum(lengths) + splits = array_ops.concat( + [array_ops.zeros([1], row_splits_dtype), limits], axis=0) + mask = array_ops.sequence_mask(lengths, maxlen=ncols) + values = array_ops.boolean_mask(tensor, mask) + return cls.from_row_splits(values, splits, validate=False) + + # If neither padding nor lengths were specified, then create a splits + # vector that contains no default values, and reshape the input tensor + # to form the values for the RaggedTensor. + values_shape = array_ops.concat( + [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0) + values = array_ops.reshape(tensor, values_shape) + const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value + const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value + if const_nrows is not None: + nrows = constant_op.constant(const_nrows, row_splits_dtype) + else: + nrows = input_shape[0] + if const_ncols is not None: + ncols = constant_op.constant(const_ncols, row_splits_dtype) + else: + ncols = input_shape[1] + return RaggedTensor.from_uniform_row_length( + values=values, uniform_row_length=ncols, nrows=nrows, validate=False) + + def to_tensor(self, default_value=None, name=None, shape=None): + """Converts this `RaggedTensor` into a `tf.Tensor`. + + If `shape` is specified, then the result is padded and/or truncated to + the specified shape. + + Examples: + + >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]]) + >>> print(rt.to_tensor()) + tf.Tensor( + [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32) + >>> print(rt.to_tensor(shape=[5, 2])) + tf.Tensor( + [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32) + + Args: + default_value: Value to set for indices not specified in `self`. Defaults + to zero. 
`default_value` must be broadcastable to + `self.shape[self.ragged_rank + 1:]`. + name: A name prefix for the returned tensors (optional). + shape: The shape of the resulting dense tensor. In particular, + `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or + `self.bounding_shape(i)` (otherwise).`shape.rank` must be `None` or + equal to `self.rank`. + + Returns: + A `Tensor` with shape `ragged.bounding_shape(self)` and the + values specified by the non-empty values in `self`. Empty values are + assigned `default_value`. + """ + with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]): + if default_value is not None: + default_value = ops.convert_to_tensor( + default_value, name="default_value", dtype=self.dtype) + type_tensor_pairs = _get_row_partition_type_tensor_pairs(self) + row_partition_types = [x[0] for x in type_tensor_pairs] + row_partition_tensors = [x[1] for x in type_tensor_pairs] + if default_value is None: + default_value = array_ops.zeros((), self.dtype) + + if (isinstance(shape, (list, tuple)) and + any(isinstance(v, tensor_lib.Tensor) for v in shape) and + all(isinstance(v, (int, tensor_lib.Tensor)) for v in shape)): + shape = array_ops_stack.stack(shape) + + shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype) + tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor( + shape=shape_tensor, + values=self.flat_values, + default_value=default_value, + row_partition_types=row_partition_types, + row_partition_tensors=row_partition_tensors, + ) + + ragged_shape = self.shape + + if ragged_shape.rank is not None and not isinstance( + shape, tensor_lib.Tensor + ): + # Merged self.shape and shape, favoring the second one as it takes + # into account potential padding added to the output. + shape = tensor_shape.as_shape(shape) + if shape.rank is None: + output_shape = ragged_shape + else: + # At this point we can assume that hshape.rank == ragged_shape.rank + # because otherwise it would have failed earlier. 
+ output_shape = [ + s1 if s1 is not None else s2 + for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list()) + ] + tensor.set_shape(output_shape) + + return tensor + + @classmethod + @dispatch.add_dispatch_support + def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64): + """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`. + + Each row of the `output` `RaggedTensor` will contain the explicit values + from the same row in `st_input`. `st_input` must be ragged-right. If not + it is not ragged-right, then an error will be generated. + + Example: + + >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]] + >>> st = tf.sparse.SparseTensor(indices=indices, + ... values=[1, 2, 3, 4, 5], + ... dense_shape=[4, 3]) + >>> tf.RaggedTensor.from_sparse(st).to_list() + [[1, 2, 3], [4], [], [5]] + + Currently, only two-dimensional `SparseTensors` are supported. + + Args: + st_input: The sparse tensor to convert. Must have rank 2. + name: A name prefix for the returned tensors (optional). + row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits` + tensor. One of `tf.int32` or `tf.int64`. + + Returns: + A `RaggedTensor` with the same values as `st_input`. + `output.ragged_rank = rank(st_input) - 1`. + `output.shape = [st_input.dense_shape[0], None]`. + Raises: + ValueError: If the number of dimensions in `st_input` is not known + statically, or is not two. 
+ """ + row_splits_dtype = dtypes.as_dtype(row_splits_dtype) + if not sparse_tensor.is_sparse(st_input): + raise TypeError(f"Argument `st_input` must be of type SparseTensor, but " + f"is of type {type(st_input).__name__}.") + with ops.name_scope(name, "RaggedFromSparse", [st_input]): + st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor( + st_input, name="st_input") + + if st_input.dense_shape.shape.ndims is None: + static_rank_from_dense_shape = None + else: + static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value + + if st_input.indices.shape.ndims is None: + static_rank_from_indices = None + else: + static_rank_from_indices = st_input.indices.shape.dims[1].value + + if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2: + raise ValueError("rank(st_input) must be 2.") + + with ops.control_dependencies( + _assert_sparse_indices_are_ragged_right(st_input.indices)): + # Treat sparse row indices as segment ids to generate a splits tensor + # thta we can pair with the sparse tensor values. (Ignore sparse column + # indices.) + segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype) + num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype) + return cls.from_value_rowids( + st_input.values, segment_ids, num_segments, validate=False) + + def to_sparse(self, name=None): + """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`. + + Example: + + >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]]) + >>> print(rt.to_sparse()) + SparseTensor(indices=tf.Tensor( + [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]], + shape=(6, 2), dtype=int64), + values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32), + dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64)) + + Args: + name: A name prefix for the returned tensors (optional). + + Returns: + A SparseTensor with the same values as `self`. 
+ """ + with ops.name_scope(name, "RaggedToSparse", [self]): + result = gen_ragged_conversion_ops.ragged_tensor_to_sparse( + self.nested_row_splits, self.flat_values, name=name) + return sparse_tensor.SparseTensor(result.sparse_indices, + result.sparse_values, + result.sparse_dense_shape) + + @classmethod + def _from_variant(cls, + variant, + dtype, + output_ragged_rank, + input_ragged_rank=None, + row_splits_dtype=dtypes.int64, + name=None): + """Converts a `variant` Tensor into a `RaggedTensor`. + + The input `variant` could be a scalar, meaning it encodes a single + `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could + have an arbitrary rank, in which case each element is decoded into a + `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then + stacked according to the input shape to output a single `RaggedTensor` + with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not + provided, it is inferred dynamically as `output_ragged_rank` - + `rank(variant)`. If `input_ragged_rank` is provided, the following must be + true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`. + + Example: + + >>> rt = tf.ragged.constant([[0], [1, 2]]) + >>> et = rt._to_variant() + >>> stacked_et = tf.stack([et, et]) + >>> tf.RaggedTensor._from_variant( # scalar input. + ... et, dtype=tf.int32, output_ragged_rank=1).to_list() + [[0], [1, 2]] + >>> tf.RaggedTensor._from_variant( # batched input. + ... stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list() + [[[0], [1, 2]], [[0], [1, 2]]] + + Args: + variant: A `variant` Tensor representing an encoded (possibly + nested-batched) `RaggedTensor`. + dtype: The dtype of the encoded `RaggedTensor`. + output_ragged_rank: The expected ragged rank of the output `RaggedTensor`. + input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is + optional and inferred dynamically if not provided. + row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. 
One + of `tf.int32` or `tf.int64`. + name: A name prefix for the returned tensors (optional). + + Returns: + A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`. + + Raises: + ValueError: If the input rank is known, `input_ragged_rank` is provided + and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does + not hold. + """ + variant = ops.convert_to_tensor( + variant, name="variant", dtype=dtypes.variant) + if (variant.shape.ndims is not None and input_ragged_rank is not None and + output_ragged_rank != input_ragged_rank + variant.shape.ndims): + raise ValueError( + f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal " + f"to `input_ragged_rank` + `variant.shape.ndims` " + f"({input_ragged_rank} + {variant.shape.ndims}).") + input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank + with ops.name_scope( + name, "RaggedFromVariant", + [variant, dtype, input_ragged_rank, output_ragged_rank]): + result = gen_ragged_conversion_ops.ragged_tensor_from_variant( + variant, input_ragged_rank, max(output_ragged_rank, 0), dtype, + row_splits_dtype, name) + return cls.from_nested_row_splits( + result.output_dense_values, + result.output_nested_splits, + validate=False) + + def _to_variant(self, batched_input=False, name=None): + """Converts this `RaggedTensor` into a `variant` Tensor. + + If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the + zero-th dimension, each component `RaggedTensor` is encoded into a scalar + `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor. + If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and + a scalar `variant` Tensor is returned. + + Example: + >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]]) + >>> rt._to_variant().shape.as_list() + [] + >>> rt._to_variant(batched_input=True).shape.as_list() + [3] + + Args: + batched_input: If `True`, the `RaggedTensor` is unbatched and converted to + a `variant` vector. 
Set to `False` by default. + name: A name prefix for the returned tensors (optional). + + Returns: + A `variant` Tensor that encodes this `RaggedTensor`. + """ + with ops.name_scope(name, "RaggedToVariant", [self, batched_input]): + return gen_ragged_conversion_ops.ragged_tensor_to_variant( + self.nested_row_splits, self.flat_values, batched_input, name) + + #============================================================================= + # String Encoding + #============================================================================= + def __repr__(self): + if self._is_eager(): + # The np.array2string in _formatter provides a separator argument, but + # doesn't handle recursive calls correctly. The np.printoptions handles + # recursive calls correctly, but doesn't provide a separator argument. + # Combines them together to print elements separated by comma, while + # avoiding the redundant array prefixes and dtypes. For example, + # the value of tf.ragged.constant([[1, 2], [3, 4]]) will look like + # + # [[1, 2], + # [3, 4]] + with np.printoptions(formatter={"all": _formatter}): + value_text = _formatter(self.numpy()) + return f"" + else: + return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values, + self.row_splits) + + #============================================================================= + # Eager Execution Mode + #============================================================================= + + def numpy(self): + """Returns a numpy `array` with the values for this `RaggedTensor`. + + Requires that this `RaggedTensor` was constructed in eager execution mode. + + Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and + `rank=1`, where each element is a single row. 
+ + #### Examples + + In the following example, the value returned by `RaggedTensor.numpy()` + contains three numpy `array` objects: one for each row (with `rank=1` and + `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`): + + >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy() + array([array([1, 2, 3]), array([4, 5])], dtype=object) + + Uniform dimensions are encoded using multidimensional numpy `array`s. In + the following example, the value returned by `RaggedTensor.numpy()` contains + a single numpy `array` object, with `rank=2` and `dtype=int64`: + + >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy() + array([[1, 2, 3], [4, 5, 6]]) + + Returns: + A numpy `array`. + """ + if not self._is_eager(): + raise ValueError("RaggedTensor.numpy() is only supported in eager mode.") + values = self.values.numpy() + splits = self.row_splits.numpy() + rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)] + if not rows: + return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype) + # Note: if `rows` have ragged lengths, then they will be stored in a + # np.ndarray with dtype=object and rank=1. If they have uniform lengths, + # they will be combined into a single np.ndarray with dtype=row.dtype and + # rank=row.rank+1. + # + # Manually set dtype as numpy now complains when given ragged rows. + has_variable_length_rows = any(len(row) != len(rows[0]) for row in rows) + dtype = np.object_ if has_variable_length_rows else None + return np.array(rows, dtype=dtype) + + def to_list(self): + """Returns a nested Python `list` with the values for this `RaggedTensor`. + + Requires that `rt` was constructed in eager execution mode. + + Returns: + A nested Python `list`. 
+ """ + if not isinstance(self.row_splits, ops.EagerTensor): + raise ValueError("to_list can only be used in eager mode.") + row_splits = self.row_splits.numpy().tolist() + values = self.values + + if isinstance(values, RaggedTensor): + return [ + values[row_splits[i]:row_splits[i + 1]].to_list() + for i in range(len(row_splits) - 1) + ] + else: + # Convert values to a Python list. + if hasattr(values, "numpy"): + values_as_list = values.numpy().tolist() + elif hasattr(values, "to_list"): + values_as_list = values.to_list() + else: + raise ValueError("values must be convertible to a list") + + return [ + values_as_list[row_splits[i]:row_splits[i + 1]] + for i in range(len(row_splits) - 1) + ] + + def _eager_value(self): + """Returns a RaggedTensorValue for self. Requires self._is_eager()=true.""" + value = self.flat_values.numpy() + for row_splits in reversed(self.nested_row_splits): + value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy()) + return value + + def _is_eager(self): + """Returns True if values & row_splits Tensors are all `EagerTensor`s.""" + rt = self + while isinstance(rt, RaggedTensor): + if not isinstance(rt.row_splits, ops.EagerTensor): + return False + rt = rt.values + return isinstance(rt, ops.EagerTensor) + + #============================================================================= + # Operators + #============================================================================= + # To avoid circular dependencies, we define stub methods for operators here, + # and then override them when the ragged_operators module is imported. 
+ + def _overloaded_operator(name): # pylint: disable=no-self-argument + + def stub(*args, **kwargs): + del args, kwargs + raise ValueError( + f"You must import 'tensorflow.python.ops.ragged.ragged_ops' " + f"before using RaggedTensor.{name}.") + + return stub + + __getitem__ = _overloaded_operator("__getitem__") + __ge__ = _overloaded_operator("__ge__") + __gt__ = _overloaded_operator("__gt__") + __le__ = _overloaded_operator("__le__") + __lt__ = _overloaded_operator("__lt__") + __and__ = _overloaded_operator("__and__") + __rand__ = _overloaded_operator("__rand__") + __invert__ = _overloaded_operator("__invert__") + __ror__ = _overloaded_operator("__ror__") + __or__ = _overloaded_operator("__or__") + __xor__ = _overloaded_operator("__xor__") + __rxor__ = _overloaded_operator("__rxor__") + __abs__ = _overloaded_operator("__abs__") + __add__ = _overloaded_operator("__add__") + __radd__ = _overloaded_operator("__radd__") + __div__ = _overloaded_operator("__div__") + __rdiv__ = _overloaded_operator("__rdiv__") + __floordiv__ = _overloaded_operator("__floordiv__") + __rfloordiv__ = _overloaded_operator("__rfloordiv__") + __mod__ = _overloaded_operator("__mod__") + __rmod__ = _overloaded_operator("__rmod__") + __mul__ = _overloaded_operator("__mul__") + __rmul__ = _overloaded_operator("__rmul__") + __neg__ = _overloaded_operator("__neg__") + __pow__ = _overloaded_operator("__pow__") + __rpow__ = _overloaded_operator("__rpow__") + __sub__ = _overloaded_operator("__sub__") + __rsub__ = _overloaded_operator("__rsub__") + __truediv__ = _overloaded_operator("__truediv__") + __rtruediv__ = _overloaded_operator("__rtruediv__") + del _overloaded_operator + + #============================================================================= + # Name Scope + #============================================================================= + + # This private function is used by ops.name_scope to ensure that all of the + # input tensors for the scope belong to the same graph. 
Defining this means + # that you may include `RaggedTensor` objects in the name_scope `values` + # list. + def _as_graph_element(self): + """Convert `self` to a graph element.""" + values = self.values + while isinstance(values, RaggedTensor): + values = values.values + return values + + #============================================================================= + # Composite Tensor + #============================================================================= + + @property + def _type_spec(self): + return RaggedTensorSpec.from_value(self) + + def _shape_invariant_to_type_spec(self, shape): + return RaggedTensorSpec(shape, self.dtype, self.ragged_rank, + self.row_splits.dtype) + + def consumers(self): + return self._consumers() + + __composite_gradient__ = ( + composite_tensor_gradient.WithValuesCompositeTensorGradient()) + + +def is_ragged(value): + """Returns true if `value` is a ragged tensor or ragged tensor value.""" + return isinstance(value, + (RaggedTensor, ragged_tensor_value.RaggedTensorValue)) + + +def match_row_splits_dtypes(*tensors, **kwargs): + """Return a copy of `tensors` with row_splits all having the same dtype. + + Args: + *tensors: A list of Tensors or RaggedTensors. + **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors), + where `dtype` is the data type used by row-splits, and `tensors` is the + converted list of `Tensors` and `RaggedTensors`. + + Returns: + The converted list of `Tensors` and `RaggedTensors`. 
+ """ + return_dtype = kwargs.pop("return_dtype", False) + if kwargs: + raise ValueError(f"Unexpected keyword args {kwargs}.") + + has_int32 = False + has_int64 = False + for tensor in tensors: + if isinstance(tensor, RaggedTensor): + if tensor.row_splits.dtype == dtypes.int32: + has_int32 = True + else: + has_int64 = True + + if has_int32 and has_int64: + if not ragged_config.auto_cast_partition_dtype(): + raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; " + "use RaggedTensor.with_row_splits_dtype() to convert " + "them to compatible dtypes.") + dtype = dtypes.int64 + tensors = tuple( + t.with_row_splits_dtype(dtypes.int64) if isinstance(t, RaggedTensor + ) else t + for t in tensors) + + elif has_int32: + dtype = dtypes.int32 + else: + dtype = dtypes.int64 + + if return_dtype: + return (dtype, tensors) + else: + return tensors + + +# =============================================================================== +# RaggedTensorSpec +# =============================================================================== +@tf_export("RaggedTensorSpec") +@type_spec_registry.register("tf.RaggedTensorSpec") +class RaggedTensorSpec( + type_spec.BatchableTypeSpec, internal_types.RaggedTensorSpec): + """Type specification for a `tf.RaggedTensor`.""" + + __slots__ = [ + "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype", + "_flat_values_spec" + ] + + @property + def dtype(self): + """The `tf.dtypes.DType` specified by this type for the RaggedTensor. + + Examples: + + >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string) + >>> tf.type_spec_from_value(rt).dtype + tf.string + + Returns: + A `tf.dtypes.DType` of the values in the RaggedTensor. + """ + return self._dtype + + @property + def shape(self): + """The statically known shape of the RaggedTensor. 
+ + Examples: + + >>> rt = tf.ragged.constant([[0], [1, 2]]) + >>> tf.type_spec_from_value(rt).shape + TensorShape([2, None]) + + >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1) + >>> tf.type_spec_from_value(rt).shape + TensorShape([2, None, 2]) + + Returns: + A `tf.TensorShape` containing the statically known shape of the + RaggedTensor. Ragged dimensions have a size of `None`. + """ + return self._shape + + @property + def ragged_rank(self): + """The number of times the RaggedTensor's flat_values is partitioned. + + Defaults to `shape.ndims - 1`. + + Examples: + + >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]]) + >>> tf.type_spec_from_value(values).ragged_rank + 1 + + >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2) + >>> tf.type_spec_from_value(rt1).ragged_rank + 2 + + Returns: + A Python `int` indicating the number of times the underlying `flat_values` + Tensor has been partitioned to add a new dimension. + I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`. + """ + return self._ragged_rank + + @property + def row_splits_dtype(self): + """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`. + + Examples: + + >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64) + >>> tf.type_spec_from_value(rt).row_splits_dtype + tf.int64 + + Returns: + A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One + of `tf.int32` or `tf.int64`. + """ + return self._row_splits_dtype + + @property + def flat_values_spec(self): + """The `TypeSpec` of the flat_values of RaggedTensor. + + Returns: + - The TypeSpec of flat_values. + - None when the flat_values is a Tensor. 
+ """ + return self._flat_values_spec + + @property + def value_type(self): + return RaggedTensor if self._ragged_rank > 0 else tensor_lib.Tensor + + def __init__(self, + shape=None, + dtype=dtypes.float32, + ragged_rank=None, + row_splits_dtype=dtypes.int64, + flat_values_spec=None): + """Constructs a type specification for a `tf.RaggedTensor`. + + Args: + shape: The shape of the RaggedTensor, or `None` to allow any shape. If a + shape is specified, then all ragged dimensions must have size `None`. + dtype: `tf.DType` of values in the RaggedTensor. + ragged_rank: Python integer, the number of times the RaggedTensor's + flat_values is partitioned. Defaults to `shape.ndims - 1`. + row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One + of `tf.int32` or `tf.int64`. + flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be + provided when the flat_values is a CompositeTensor rather then Tensor. + If both `dtype` and `flat_values_spec` and are provided, `dtype` must + be the same as `flat_values_spec.dtype`. (experimental) + """ + self._shape = tensor_shape.as_shape(shape) + self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype) + if flat_values_spec is not None: + if dtype is None: + dtype = flat_values_spec.dtype + elif dtype != flat_values_spec.dtype: + raise ValueError("dtype must be the same as flat_values_spec.dtype") + elif dtype is None: + raise ValueError( + "At least one of dtype or flat_values_spec must be provided") + self._dtype = dtypes.as_dtype(dtype) + self._flat_values_spec = flat_values_spec + + rank = self._shape.ndims + if ragged_rank is None: + if rank is None: + raise ValueError("Must specify ragged_rank or " + "a shape with a known rank.") + ragged_rank = rank - 1 + self._ragged_rank = ragged_rank + if not isinstance(self._ragged_rank, int): + raise TypeError(f"Argument `ragged_rank` must be an int. 
" + f"Received {ragged_rank}.") + + if rank is not None: + if ragged_rank >= rank: + raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less " + f"than rank ({rank}).") + + def is_compatible_with(self, spec_or_value): + # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values. + if self._ragged_rank == 0: + if self._flat_values_spec is None: + if isinstance( + spec_or_value, (tensor_lib.Tensor, tensor_lib.TensorSpec)): + return tensor_lib.TensorSpec( + self._shape, self._dtype).is_compatible_with(spec_or_value) + elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)): + return self._flat_values_spec.is_compatible_with(spec_or_value) + return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value) + + def _serialize(self): + if self._flat_values_spec is None: + return (self._shape, self._dtype, self._ragged_rank, + self._row_splits_dtype) + else: + return (self._shape, self._dtype, self._ragged_rank, + self._row_splits_dtype, self._flat_values_spec) + + @property + def _component_specs(self): + if self._ragged_rank <= 0: + if self._flat_values_spec is not None: + return [self._flat_values_spec] + else: + return [tensor_lib.TensorSpec(self._shape, self._dtype)] + + flat_values_spec = self._flat_values_spec + if flat_values_spec is None: + flat_values_shape = tensor_shape.TensorShape([None]).concatenate( + self._shape[self._ragged_rank + 1:]) + flat_values_spec = tensor_lib.TensorSpec(flat_values_shape, self._dtype) + outer_dim = tensor_shape.dimension_at_index(self._shape, 0) + outer_splits_shape = [None if outer_dim is None else outer_dim + 1] + inner_splits_spec = tensor_lib.TensorSpec([None], self._row_splits_dtype) + + specs = ([ + flat_values_spec, + tensor_lib.TensorSpec(outer_splits_shape, self._row_splits_dtype) + ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)]) + return specs + + def _to_components(self, value): + if is_ragged(value): + return [value.flat_values] + 
list(value.nested_row_splits) + else: + return [value] + + def _from_components(self, tensor_list): + result = tensor_list[0] + if (all(isinstance(t, np.ndarray) for t in tensor_list) and + not tf2.enabled()): + for row_splits in reversed(tensor_list[1:]): + result = ragged_tensor_value.RaggedTensorValue(result, row_splits) + else: + if isinstance(tensor_list[0], np.ndarray): + tensor_list = [ops.convert_to_tensor(t) for t in tensor_list] + result = tensor_list[0] + for row_splits in reversed(tensor_list[1:]): + result = RaggedTensor( + result, + RowPartition.from_row_splits(row_splits, validate=False), + internal=True) + if self._shape.ndims is not None: + if isinstance(result, RaggedTensor): + result._set_shape(self._shape) # pylint: disable=protected-access + # TODO(xjun): MaskedTensor doesn't implement set_shape. + if self.flat_values_spec is not None and hasattr(result.flat_values, + "set_shape"): + result.flat_values.set_shape(self.flat_values_spec.shape) + elif isinstance(result, tensor_lib.Tensor): + result.set_shape(self._shape) + return result + + # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops + # to (un)box the component tensors in a way that allows for batching & + # unbatching. + @property + def _flat_tensor_specs(self): + # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is + # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of + # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches, + # etc.), so the flat shape must be unknown. + return [tensor_lib.TensorSpec(None, dtypes.variant)] + + def _to_tensor_list(self, value): + # TODO(edloper): Update gen_ragged_conversion_ops that convert to and + # from variant to include all of the row-partitioning tensors. 
+ if self._flat_values_spec is not None: + raise ValueError("Customized value_type is not supported.") + if isinstance(value, RaggedTensor): + if value.ragged_rank != self._ragged_rank: + raise ValueError( + f"Ragged rank of value {value.ragged_rank} does not match " + f"ragged rank of type {self._ragged_rank}.") + # pylint: disable=protected-access + return [value._to_variant(batched_input=False)] + else: + if self._ragged_rank > 0: + raise ValueError( + f"Expected a RaggedTensor if ragged rank={self._ragged_rank}" + f" but got {type(value).__name__}." + ) + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + (), value, batched_input=False) + ] + + def _to_batched_tensor_list(self, value): + if self._flat_values_spec is not None: + raise ValueError("Customized value_type is not supported.") + if isinstance(value, RaggedTensor): + if value.ragged_rank != self._ragged_rank: + raise ValueError( + f"Ragged rank of value {value.ragged_rank} does not match " + f"ragged rank of type {self._ragged_rank}.") + # pylint: disable=protected-access + return [value._to_variant(batched_input=True)] + else: + if self._ragged_rank > 0: + raise ValueError( + f"Expected a RaggedTensor if ragged rank={self._ragged_rank}" + f" but got {type(value).__name__}." + ) + return [ + gen_ragged_conversion_ops.ragged_tensor_to_variant( + rt_nested_splits=(), rt_dense_values=value, batched_input=True) + ] + + def _from_compatible_tensor_list(self, tensor_list): + if self._flat_values_spec is not None: + raise ValueError("Customized value_type is not supported.") + result = RaggedTensor._from_variant( # pylint: disable=protected-access + tensor_list[0], + dtype=self._dtype, + row_splits_dtype=self._row_splits_dtype, + output_ragged_rank=self._ragged_rank) + if self._shape.ndims is not None: + if isinstance(result, RaggedTensor): + result._set_shape(self._shape) # pylint: disable=protected-access + # TODO(xjun): MaskedTensor doesn't implement set_shape. 
+ if self.flat_values_spec is not None and hasattr(self.flat_values, + "set_shape"): + result.flat_values.set_shape(self.flat_values_spec.shape) + else: + result.set_shape(self._shape) + return result + + def _batch(self, batch_size): + if self._flat_values_spec is not None: + raise ValueError("Customized value_type is not supported.") + return RaggedTensorSpec( + tensor_shape.TensorShape([batch_size]).concatenate(self._shape), + self._dtype, self._ragged_rank + 1, self._row_splits_dtype) + + def _unbatch(self): + if self._flat_values_spec is not None: + raise ValueError("Customized value_type is not supported.") + # Note: Negative ragged_rank is allowed here because the dataset could be + # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is + # consistent. Errors are handled in + # RaggedTensorSpec._from_compatible_tensor_list() + return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1, + self._row_splits_dtype) + + def _to_legacy_output_types(self): + return self._dtype + + def _to_legacy_output_shapes(self): + return self._shape + + def _to_legacy_output_classes(self): + return self + + @classmethod + def from_value(cls, value): + if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or + isinstance(value.flat_values, tensor_lib.Tensor)): + return cls( + shape=value.shape, + dtype=value.values.dtype, + ragged_rank=value.ragged_rank, + row_splits_dtype=value.row_splits.dtype) + else: + flat_values_spec = type_spec.type_spec_from_value(value.flat_values) + # Relax shape[0] to None, as it is connected to dynamic ragged shapes. 
+ flat_values_spec = flat_values_spec._unbatch()._batch(None) # pylint: disable=protected-access + return cls( + shape=value.shape, + dtype=value.values.dtype, + ragged_rank=value.ragged_rank, + row_splits_dtype=value.row_splits.dtype, + flat_values_spec=flat_values_spec) + + +nested_structure_coder.register_codec( + nested_structure_coder.BuiltInTypeSpecCodec( + RaggedTensorSpec, struct_pb2.TypeSpecProto.RAGGED_TENSOR_SPEC + ) +) + + +type_spec.register_type_spec_from_value_converter( + ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value) + + +# =============================================================================== +# Convert value -> tensor +# =============================================================================== +def convert_to_tensor_or_ragged_tensor(value, + dtype=None, + preferred_dtype=None, + name=None): + """Converts value to a `RaggedTensor` or `Tensor`. + + * If `value` is a `RaggedTensor`, then return it as-is. + * If `value` is a `RaggedTensorValue`, return a corresponding constant + `RaggedTensor`. + * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`. + + Args: + value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has + a registered `Tensor` conversion function. + dtype: Optional element type for the returned tensor. If missing the type + is inferred from the type of `value`. + preferred_dtype: Optional element type for the returned tensor, used when + dtype is None. This argument has no effect if `value` is already a + tensor, or when conversion is not possible. + name: Optional name to use if a new `Tensor` is created. + + Returns: + A `Tensor` or `RaggedTensor`. 
+ """ + if isinstance(value, RaggedTensor): + if dtype and not dtype.is_compatible_with(value.dtype): + raise ValueError(f"Tensor conversion requested dtype {dtype.name} for " + f"RaggedTensor with dtype {value.dtype.name}: {value}.") + return value + elif isinstance(value, ragged_tensor_value.RaggedTensorValue): + with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []): + flat_values = ops.convert_to_tensor( + value=value.flat_values, + dtype=dtype, + dtype_hint=preferred_dtype, + name="flat_values") + return RaggedTensor.from_nested_row_splits( + flat_values, value.nested_row_splits, validate=False) + else: + return tensor_conversion.convert_to_tensor_v2_with_dispatch( + value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name + ) + + +def _convert_to_ragged_tensor_values(value): + """Converts value to supported RaggedTensor value. + + * If `value` is an object of supported value type, then return it as-is. + * Otherwise convert it to Tensor or RaggedTensor. + + Args: + value: An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor + value types, or an object whose type has a registered `Tensor` conversion + function. + + Returns: + An object of `Tensor`, `RaggedTensor` or registerred RaggedTensor + value types + """ + if _is_supported_ragged_values_type(value): + return value + else: + return convert_to_tensor_or_ragged_tensor(value, name="values") + + +# =============================================================================== +# Register RaggedTensor for use with session.run. 
+# =============================================================================== +def _ragged_tensor_value_from_components(components): + components = list(components) + value = components.pop() + while components: + value = ragged_tensor_value.RaggedTensorValue(value, components.pop()) + return value + + +def _ragged_tensor_session_fetch(rt): + components = rt.nested_row_splits + (rt.flat_values,) + return (components, _ragged_tensor_value_from_components) + + +def _ragged_tensor_session_feed(feed_key, feed_val): + key_components = feed_key.nested_row_splits + (feed_key.flat_values,) + val_components = feed_val.nested_row_splits + (feed_val.flat_values,) + return zip(key_components, val_components) + + +def _ragged_tensor_session_feed_for_partial_run(feed_key): + return feed_key.nested_row_splits + (feed_key.flat_values,) + + +session.register_session_run_conversion_functions( + RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed, + _ragged_tensor_session_feed_for_partial_run) + + +# =============================================================================== +# RaggedTensorType +# =============================================================================== +class RaggedTensorType: + """Encoding of a static type for a `RaggedTensor`. + + Use this type to express/declare that an output must have the type of + `RaggedTensor`. + """ + + def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64): + """Initializes a RaggedTensorType object. + + Args: + dtype: data type of the `RaggedTensor`'s inner values. + ragged_rank: ragged_rank of the declared `RaggedTensor`. + row_splits_dtype: data type for the `RaggedTensor`'s row splits. + One of: `tf.int32` or `tf.int64`. 
+ """ + row_splits_dtype = dtypes.as_dtype(row_splits_dtype) + self._dtype = dtype + self._ragged_rank = ragged_rank + self._row_splits_dtype = row_splits_dtype + + dtype = property(lambda self: self._dtype) + ragged_rank = property(lambda self: self._ragged_rank) + row_splits_dtype = property(lambda self: self._row_splits_dtype) + + def __repr__(self): + return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank, + self.row_splits_dtype) + + +# =============================================================================== +# Helper Functions +# =============================================================================== +def _assert_sparse_indices_are_ragged_right(indices): + """Checks that the given SparseTensor.indices tensor is ragged-right. + + Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right + because the entry `[3, 1]` skips a cell. + + Args: + indices: The SparseTensor indices to check. + + Returns: + A list of control dependency op tensors. + """ + index_prefix = indices[:, :-1] + index_suffix = indices[:, -1] + + # Check whether each index is starting a new row in the innermost dimension + # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]). + # (Note: this skips the first index; we will check that separately below.) + index_prefix_changed = math_ops.reduce_any( + math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1) + + # Check two cases: + # * For indices that start a new row: index_suffix[i] must be zero. + # * For indices that continue a row: index_suffix[i] must be equal to + # index_suffix[i-1]+1. + index_ok = array_ops.where( + index_prefix_changed, math_ops.equal(index_suffix[1:], 0), + math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1)) + + # Also check that the very first index didn't skip any cells. The first + # index starts a new row (by definition), so its suffix should be zero. 
+ sparse_indices_are_ragged_right = math_ops.logical_and( + math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)), + math_ops.reduce_all(index_ok)) + + message = [ + "SparseTensor is not right-ragged", "SparseTensor.indices =", indices + ] + return [control_flow_assert.Assert(sparse_indices_are_ragged_right, message)] + + +@ops.RegisterGradient("RaggedTensorToSparse") +def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad, + sparse_values_grad, + unused_sparse_shape_grad): + """Gradient for RaggedTensorToSparse.""" + op_inputs_nested_row_splits = op.inputs[:-1] + op_inputs_flat_values = op.inputs[-1] + + # No gradient for the RaggedTensor's nested_row_splits. + nested_row_splits_gradient = [None] * len(op_inputs_nested_row_splits) + + # Gradient for the RaggedTensor's flat_values is formed by reshaping + # the gradient for the SparseTensor's values. + flat_values_shape = array_ops.shape(op_inputs_flat_values) + flat_values_gradient = array_ops.reshape(sparse_values_grad, + flat_values_shape) + + return nested_row_splits_gradient + [flat_values_gradient] + + +def _assert_monotonic_increasing(tensor, message=None): + return check_ops.assert_non_negative( + tensor[1:] - tensor[:-1], message=message) + + +def _assert_zero(tensor, message=None): + return check_ops.assert_equal( + tensor, constant_op.constant(0, dtype=tensor.dtype), message=message) + + +def _nrows(tensor, out_type=dtypes.int32): + if isinstance(tensor, RaggedTensor): + return tensor.nrows(out_type=out_type) + else: + return array_ops.shape(tensor, out_type=out_type)[0] + + +def merge_dims(value, outer_axis, inner_axis): + """Merges value[outer_axis...inner_axis] into a single dimension. + + See `RaggedTensor.merge_dims()` for more details. This helper differs from + `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor. + + Args: + value: A `RaggedTensor` or `Tensor` + outer_axis: `int` + inner_axis: `int` + + Returns: + A flattened `RaggedTensor` or `Tensor`. 
+ """ + if outer_axis == inner_axis: + return value + + # Flatten outer dimensions of a RaggedTensor by just taking its values. + while outer_axis == 0 and isinstance(value, RaggedTensor): + value = value.values + inner_axis -= 1 + if inner_axis == 0: + return value + + # Flatten non-Ragged tensors using tf.reshape(). + if not isinstance(value, RaggedTensor): + if value.shape.is_fully_defined(): + old_shape = value.shape.as_list() + new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:] + else: + old_shape = array_ops.shape(value) + new_shape = array_ops.concat( + [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0) + return array_ops.reshape(value, new_shape) + + # Handle outer_axis>1 via recursion. + if outer_axis > 1: + return value.with_values( + merge_dims(value.values, outer_axis - 1, inner_axis - 1)) + + # At this point, we know outer_axis == 1, and value is a RaggedTensor. + # So we need to flatten the values and build a corresponding splits tensor. + new_values = value.values + new_splits = value.row_splits + for axis in range(outer_axis, inner_axis): + if isinstance(new_values, RaggedTensor): + # Flatten a single ragged dimension. + new_splits = array_ops.gather(new_values.row_splits, new_splits) + new_values = new_values.values + else: + # Flatten all remaining dense dimensions. 
+ shape_split = inner_axis - axis + 1 + if new_values.shape.is_fully_defined(): + old_shape = new_values.shape.as_list() + new_shape = [-1] + old_shape[shape_split:] + flat_size = _prod(old_shape[1:shape_split]) + else: + old_shape = array_ops.shape(new_values) + new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0) + flat_size = math_ops.cast( + math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype) + new_values = array_ops.reshape(new_values, new_shape) + new_splits = new_splits * flat_size + break + return RaggedTensor.from_row_splits(new_values, new_splits) + + +def _prod(lst): + """Returns the product of the numbers in a list.""" + return functools.reduce(operator.mul, lst, 1) + + +def _get_row_partition_type_tensor_pairs_tail(partition): + """Gets a row partition type tensor pair for the tail. + + If value_rowid is defined, then it is used. Otherwise, row_splits + are used. + + Args: + partition: a RowPartition. + + Returns: + A list of (row_partition_type, row_partition_tensor) pairs. + """ + if partition._has_precomputed_value_rowids(): # pylint: disable=protected-access + return ("VALUE_ROWIDS", partition.value_rowids()) + else: + return ("ROW_SPLITS", partition.row_splits()) + + +def _get_row_partition_type_tensor_pairs(rt_input): + """Gets a list of the row partitions for rt_input. + + If value_rowids are defined, then they are used. Otherwise, row_splits + are used. If the outermost level has value_rowids defind, then nrows is + also added. + + Args: + rt_input: a ragged tensor. + + Returns: + A list of (row_partition_type, row_partition_tensor) pairs. 
+ """ + partitions = rt_input._nested_row_partitions # pylint: disable=protected-access + tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]] + + if partitions[0]._value_rowids is not None: # pylint: disable=protected-access + return [("FIRST_DIM_SIZE", partitions[0].nrows()), + ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail + else: + return [("ROW_SPLITS", partitions[0].row_splits())] + tail + + +def _shape_as_tensor(shape, dtype): + """Takes shape and coerces it to a shape as a tensor. + + If the object is already a tensor, simply passes it on (result is guaranteed + to be int64 or int32, but not necessarily dtype). + If not, creates a tensor of type dtype. + + Result is either a scalar equal to -1 if the shape is unknown_rank. + Otherwise, it is a vector, where unknown dimensions are represented with a + value of -1. + + In C++, see TensorShapeFromTensor for parsing shapes in kernels, and + InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for + use in the shape inference function. + + Args: + shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]], + Tuple[Optional[Int]]. + dtype: tf.int64 or tf.int32 + + Returns: + a scalar or vector tensor of dtype tf.int32 or tf.int64. + """ + if dtype != dtypes.int64 and dtype != dtypes.int32: + raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.") + + if isinstance(shape, tensor_lib.Tensor): + if shape.dtype != dtypes.int64 and shape.dtype != dtypes.int32: + return math_ops.cast(shape, dtype) + return shape + shape = tensor_shape.as_shape(shape) + if not shape: + # Imply rank is unknown using a -1 scalar. + return constant_op.constant(-1, dtype=dtype) + shape = [(-1 if x is None else x) for x in shape.as_list()] + # At this point, shape is List[Int]. 
+ return constant_op.constant(shape, dtype=dtype) + + +def _nvals_uniform_row_length(values, uniform_row_length): + """Get the number of values for uniform row length constructor.""" + const_nvals = tensor_shape.dimension_at_index(values.shape, 0).value + if const_nvals is not None: + nvals = constant_op.constant(const_nvals, uniform_row_length.dtype) + elif isinstance(values, RaggedTensor): + nvals = values.nrows(out_type=uniform_row_length.dtype) + else: + nvals = array_ops.shape(values, out_type=uniform_row_length.dtype)[0] + return nvals + + +def _get_optional_partition_dtype(values): + """Returns the partition dtype, or None if None exists.""" + if isinstance(values, RaggedTensor): + # pylint: disable=protected-access + return values._row_partition.dtype + return None + + +_SUPPORTED_RAGGED_VALUE_TYPES = (tensor_lib.Tensor, RaggedTensor) + + +# TODO(edloper): Consider whether we should change the registry to be on +# TypeSpecs rather than ValueTypes. +def _add_supported_value_type(cls): + """Register the `cls` as supported value type of RaggedTenosr. + + The cls must be a subclass of CompositeTensor, and must support: + - Spec: + The Spec must be a `BatchableTypeSpec` + - Properties: + - x.shape + - x.dtype + - Methods: + - x.__getitem__(idx) (method: returns a supported value type) + - x.set_shape(shape) + - Ops: + - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor. + - tf.tile(x) + - assert_rank_at_least(x) + - tf.ones_like(x) + - tf.gather(params=x, indices=Tensor) + - tf.add(x, y) + - tf.boolean_mask(x, ...) 
+ - @TODO(edloper): Complete this list + + Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not + currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor: + - rt.to_tensor() + - rt.to_sparse_tensor() + - rt._to_variant() + - rt._from_variant() + - tf.ragged.cross([rt]) + - tf.gather(params=x, indices=rt) # rt used for indices + - RaggedTensorSpec methods: + - _batch + - _unbatch + - _to_tensor_list + - _to_batched_tensor_list + - _from_compatible_tensor_list + + Args: + cls: The type to be added to supported value types. + """ + if not issubclass(cls, composite_tensor.CompositeTensor): + raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.") + if not hasattr(cls, "shape"): + raise ValueError("cls must support the `shape` property.") + if not hasattr(cls, "dtype"): + raise ValueError("cls must support the `dtype` property.") + global _SUPPORTED_RAGGED_VALUE_TYPES + _SUPPORTED_RAGGED_VALUE_TYPES += (cls,) + + +def _is_supported_ragged_values_type(value): + return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES) + + +def _assert_is_supported_ragged_values_type(value): + if not _is_supported_ragged_values_type(value): + ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES) + raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.") + + +def _formatter(x): + """Separate Numpy array elements with comma.""" + if isinstance(x, np.ndarray): + if x.size != 0: + return np.array2string(x, separator=", ") + else: + # When x.size==0, np.array2string always returns `[]`. This isn't always + # what we want. E.g., if `x.shape=[0, 3]`, then we want `[[], [], []]`. + return repr(x.tolist()) + else: + return str(x) + +# Type annotation indicating that a value is ragged. Includes RaggedTensor +# as well as the (deprecated) RaggedTensorValue class from TF 1.x. 
+Ragged = typing.Union[RaggedTensor, ragged_tensor_value.RaggedTensorValue]
+
+# Type annotation indicating that a value is a ragged tensor, a dense tensor,
+# or a value that can be converted to a tensor (e.g. np.array).
+# TODO(edloper): Add Variable to TensorLike, and remove it from here.
+RaggedOrDense = typing.Union[Ragged, core_types.TensorLike]
+
+# RaggedTensor must import ragged_ops to ensure that all dispatched ragged ops
+# are registered. Ragged ops import RaggedTensor, so import at bottom of the
+# file to avoid a partially-initialized module error.
+from tensorflow.python.ops.ragged import ragged_ops  # pylint: disable=unused-import, g-bad-import-order, g-import-not-at-top
diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_util.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b91d0214ec2c07d2b66ac03e4d38bf6e65ecdc
--- /dev/null
+++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/tensorflow/python/ops/ragged/ragged_util.py
@@ -0,0 +1,138 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Private convenience functions for RaggedTensors.
+
+None of these methods are exposed in the main "ragged" package.
+""" + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_ragged_math_ops +from tensorflow.python.ops import math_ops + + +def assert_splits_match(nested_splits_lists): + """Checks that the given splits lists are identical. + + Performs static tests to ensure that the given splits lists are identical, + and returns a list of control dependency op tensors that check that they are + fully identical. + + Args: + nested_splits_lists: A list of nested_splits_lists, where each split_list is + a list of `splits` tensors from a `RaggedTensor`, ordered from outermost + ragged dimension to innermost ragged dimension. + + Returns: + A list of control dependency op tensors. + Raises: + ValueError: If the splits are not identical. + """ + error_msg = "Inputs must have identical ragged splits" + for splits_list in nested_splits_lists: + if len(splits_list) != len(nested_splits_lists[0]): + raise ValueError(error_msg) + return [ + check_ops.assert_equal(s1, s2, message=error_msg) + for splits_list in nested_splits_lists[1:] + for (s1, s2) in zip(nested_splits_lists[0], splits_list) + ] + + +# Note: imported here to avoid circular dependency of array_ops. +get_positive_axis = array_ops.get_positive_axis +convert_to_int_tensor = array_ops.convert_to_int_tensor +repeat = array_ops.repeat_with_axis + + +def lengths_to_splits(lengths): + """Returns splits corresponding to the given lengths.""" + return array_ops.concat([[0], math_ops.cumsum(lengths)], axis=-1) + + +def repeat_ranges(params, splits, repeats): + """Repeats each range of `params` (as specified by `splits`) `repeats` times. + + Let the `i`th range of `params` be defined as + `params[splits[i]:splits[i + 1]]`. 
Then this function returns a tensor + containing range 0 repeated `repeats[0]` times, followed by range 1 repeated + `repeats[1]`, ..., followed by the last range repeated `repeats[-1]` times. + + Args: + params: The `Tensor` whose values should be repeated. + splits: A splits tensor indicating the ranges of `params` that should be + repeated. Elements should be non-negative integers. + repeats: The number of times each range should be repeated. Supports + broadcasting from a scalar value. Elements should be non-negative + integers. + + Returns: + A `Tensor` with the same rank and type as `params`. + + #### Example: + + >>> print(repeat_ranges( + ... params=tf.constant(['a', 'b', 'c']), + ... splits=tf.constant([0, 2, 3]), + ... repeats=tf.constant(3))) + tf.Tensor([b'a' b'b' b'a' b'b' b'a' b'b' b'c' b'c' b'c'], + shape=(9,), dtype=string) + """ + # Check if the input is valid + splits_checks = [ + check_ops.assert_non_negative( + splits, message="Input argument 'splits' must be non-negative" + ), + check_ops.assert_integer( + splits, + message=( + "Input argument 'splits' must be integer, but got" + f" {splits.dtype} instead" + ), + ), + ] + repeats_checks = [ + check_ops.assert_non_negative( + repeats, message="Input argument 'repeats' must be non-negative" + ), + check_ops.assert_integer( + repeats, + message=( + "Input argument 'repeats' must be integer, but got" + f" {repeats.dtype} instead" + ), + ), + ] + splits = control_flow_ops.with_dependencies(splits_checks, splits) + repeats = control_flow_ops.with_dependencies(repeats_checks, repeats) + + # Divide `splits` into starts and limits, and repeat them `repeats` times. + if repeats.shape.ndims != 0: + repeated_starts = repeat(splits[:-1], repeats, axis=0) + repeated_limits = repeat(splits[1:], repeats, axis=0) + else: + # Optimization: we can just call repeat once, and then slice the result. 
+ repeated_splits = repeat(splits, repeats, axis=0) + n_splits = array_ops.shape(repeated_splits, out_type=repeats.dtype)[0] + repeated_starts = repeated_splits[:n_splits - repeats] + repeated_limits = repeated_splits[repeats:] + + # Get indices for each range from starts to limits, and use those to gather + # the values in the desired repetition pattern. + one = array_ops.ones((), repeated_starts.dtype) + offsets = gen_ragged_math_ops.ragged_range( + repeated_starts, repeated_limits, one) + return array_ops.gather(params, offsets.rt_dense_values)